stdlib_specialmatrices_sym_tridiagonal.fypp Source File


Source Code

#:include "common.fypp"
#:set RANKS = range(1, 2+1)
#:set R_KINDS_TYPES = list(zip(REAL_KINDS, REAL_TYPES, REAL_SUFFIX))
#:set C_KINDS_TYPES = list(zip(CMPLX_KINDS, CMPLX_TYPES, CMPLX_SUFFIX))
#:set KINDS_TYPES = R_KINDS_TYPES + C_KINDS_TYPES
submodule (stdlib_specialmatrices) sym_tridiagonal_matrices
    use stdlib_linalg_lapack, only: lagtm
    use stdlib_optval, only: optval

    character(len=*), parameter :: this = "symmetric tridiagonal matrices"
    contains

    !--------------------------------
    !-----                      -----
    !-----     CONSTRUCTORS     -----
    !-----                      -----
    !--------------------------------

    #:for k1, t1, s1 in KINDS_TYPES
    pure module function initialize_sym_tridiagonal_pure_${s1}$(du, dv) result(A)
        !! Build a `symmetric tridiagonal` matrix whose off-diagonal and diagonal
        !! entries are taken from the rank-1 arrays `du` and `dv`, respectively.
        !! The work (including input validation) is delegated to
        !! `build_sym_tridiagonal`; no error state is handed back to the caller.
        ${t1}$, intent(in) :: du(:)
        !! Off-diagonal elements, shared by the sub- and super-diagonal.
        ${t1}$, intent(in) :: dv(:)
        !! Diagonal elements.
        type(sym_tridiagonal_${s1}$_type) :: A
        !! Resulting symmetric tridiagonal matrix.

        call build_sym_tridiagonal(du, dv, A)
    end function initialize_sym_tridiagonal_pure_${s1}$

    pure module function initialize_constant_sym_tridiagonal_pure_${s1}$(du, dv, n) result(A)
        !! Build an `n`-by-`n` `symmetric tridiagonal` matrix in which every
        !! off-diagonal entry equals the scalar `du` and every diagonal entry
        !! equals the scalar `dv`. The work is delegated to
        !! `build_sym_tridiagonal`; no error state is handed back to the caller.
        ${t1}$, intent(in) :: du
        !! Constant value of the off-diagonal elements.
        ${t1}$, intent(in) :: dv
        !! Constant value of the diagonal elements.
        integer(ilp), intent(in) :: n
        !! Dimension of the matrix.
        type(sym_tridiagonal_${s1}$_type) :: A
        !! Resulting symmetric tridiagonal matrix.

        call build_sym_tridiagonal(du, dv, n, A)
    end function initialize_constant_sym_tridiagonal_pure_${s1}$

    module function initialize_sym_tridiagonal_impure_${s1}$(du, dv, err) result(A)
        !! Build a `symmetric tridiagonal` matrix from the rank-1 arrays `du`
        !! (off-diagonal) and `dv` (diagonal), reporting the outcome of the
        !! construction through `err`. The work is delegated to
        !! `build_sym_tridiagonal`.
        ${t1}$, intent(in) :: du(:)
        !! Off-diagonal elements, shared by the sub- and super-diagonal.
        ${t1}$, intent(in) :: dv(:)
        !! Diagonal elements.
        type(linalg_state_type), intent(out) :: err
        !! State object receiving any error raised during construction.
        type(sym_tridiagonal_${s1}$_type) :: A
        !! Resulting symmetric tridiagonal matrix.

        call build_sym_tridiagonal(du, dv, A, err)
    end function initialize_sym_tridiagonal_impure_${s1}$

    module function initialize_constant_sym_tridiagonal_impure_${s1}$(du, dv, n, err) result(A)
        !! Build an `n`-by-`n` `symmetric tridiagonal` matrix with constant
        !! off-diagonal value `du` and constant diagonal value `dv`, reporting
        !! the outcome of the construction through `err`. The work is delegated
        !! to `build_sym_tridiagonal`.
        ${t1}$, intent(in) :: du
        !! Constant value of the off-diagonal elements.
        ${t1}$, intent(in) :: dv
        !! Constant value of the diagonal elements.
        integer(ilp), intent(in) :: n
        !! Dimension of the matrix.
        type(linalg_state_type), intent(out) :: err
        !! State object receiving any error raised during construction.
        type(sym_tridiagonal_${s1}$_type) :: A
        !! Resulting symmetric tridiagonal matrix.

        call build_sym_tridiagonal(du, dv, n, A, err)
    end function initialize_constant_sym_tridiagonal_impure_${s1}$
    #:endfor

    #:for k1, t1, s1 in KINDS_TYPES
    pure module subroutine build_sym_tridiagonal_from_arrays_${s1}$(du, dv, A, err)
        !! Fill the `symmetric tridiagonal` matrix `A` from the rank-1 arrays
        !! `du` (off-diagonal) and `dv` (diagonal).
        !!
        !! The two sanity checks are mutually exclusive so that the first
        !! failure is the one reported through `err`: an empty `dv` yields the
        !! "Matrix size needs to be positive" error rather than having it
        !! overwritten by the `du` length error (which would always fire as
        !! well, since no array has size `-1`).
        ${t1}$, intent(in) :: du(:), dv(:)
        !! Off-diagonal elements (expected length `size(dv) - 1`) and diagonal elements.
        type(sym_tridiagonal_${s1}$_type), intent(out) :: A
        !! Matrix being built; its components are only set when every check passes.
        type(linalg_state_type), intent(out), optional :: err
        !! Optional error state, forwarded to `linalg_error_handling`
        !! (which presumably aborts when `err` is absent -- confirm against
        !! its definition).

        ! Internal variables.
        integer(ilp) :: n
        type(linalg_state_type) :: err0

        ! The matrix dimension is dictated by the length of the diagonal.
        n = size(dv, kind=ilp)

        if (n <= 0) then
            err0 = linalg_state_type(this, LINALG_VALUE_ERROR, "Matrix size needs to be positive, n = ", n, ".")
            call linalg_error_handling(err0, err)
        else if (size(du, kind=ilp) /= n-1) then
            err0 = linalg_state_type(this, LINALG_VALUE_ERROR, "Vector du does not have the correct length.")
            call linalg_error_handling(err0, err)
        else
            ! Description of the matrix.
            A%n = n
            ! Matrix elements.
            A%du = du
            A%dv = dv
        endif
    end subroutine

    pure module subroutine build_sym_tridiagonal_from_constants_${s1}$(du, dv, n, A, err)
        !! Fill the `n`-by-`n` `symmetric tridiagonal` matrix `A` with the
        !! constant off-diagonal value `du` and the constant diagonal value `dv`.
        ${t1}$, intent(in) :: du, dv
        !! Constant off-diagonal and diagonal values.
        integer(ilp), intent(in) :: n
        !! Requested matrix dimension; must be strictly positive.
        type(sym_tridiagonal_${s1}$_type), intent(out) :: A
        !! Matrix being built; its components are only set when `n > 0`.
        type(linalg_state_type), intent(out), optional :: err
        !! Optional error state, forwarded to `linalg_error_handling`.

        ! Local error state used to report an invalid dimension.
        type(linalg_state_type) :: status

        if (n > 0) then
            ! Valid dimension: record it and allocate both bands.
            A%n = n
            allocate(A%dv(n), source=dv)
            allocate(A%du(n-1), source=du)
        else
            status = linalg_state_type(this, LINALG_VALUE_ERROR, &
                                       "Matrix size needs to be positive, n = ", n, ".")
            call linalg_error_handling(status, err)
        endif
    end subroutine build_sym_tridiagonal_from_constants_${s1}$
    #:endfor

    !-----------------------------------------
    !-----                               -----
    !-----     MATRIX-VECTOR PRODUCT     -----
    !-----                               -----
    !-----------------------------------------

    !! spmv_sym_tridiag
    #:for k1, t1, s1 in KINDS_TYPES
    #:for rank in RANKS
    module subroutine spmv_sym_tridiag_${rank}$d_${s1}$(A, x, y, alpha, beta, op)
        !! Matrix-vector (rank-1 `x`, `y`) or matrix-matrix (rank-2 `x`, `y`)
        !! product driver for a `symmetric tridiagonal` matrix. The arguments
        !! are forwarded to `lagtm` / `glagtm`, with `A%du` passed as both the
        !! sub- and the super-diagonal. LAPACK documents `lagtm` as computing
        !! `y := alpha*op(A)*x + beta*y`; `glagtm` is presumably its
        !! generalisation to arbitrary `alpha` and `beta`.
        type(sym_tridiagonal_${s1}$_type), intent(in) :: A
        !! Input matrix.
        ${t1}$, intent(in), contiguous, target :: x${ranksuffix(rank)}$
        !! Input vector(s).
        ${t1}$, intent(inout), contiguous, target :: y${ranksuffix(rank)}$
        !! Output vector(s), updated in place.
        ${t1}$, intent(in), optional :: alpha
        !! Scaling of the product (defaults to one).
        ${t1}$, intent(in), optional :: beta
        !! Scaling of the incoming `y` (defaults to zero).
        character(1), intent(in), optional :: op
        !! Operation selector: 'N', 'T', 'C' or 'H' (defaults to 'N').

        ! Resolved values of the optional arguments.
        character(1) :: trans
        ${t1}$ :: a_coef, b_coef
        ! Dimensions handed to the LAPACK-style kernel.
        integer(ilp) :: order, ncols, ld_x, ld_y
        ! Rank-2 views over x and y.
        ${t1}$, pointer :: xview(:, :), yview(:, :)
        type(linalg_state_type) :: status
        #:if t1.startswith('real')
        ! True when both scalars belong to {-1, 0, 1}.
        logical :: scalars_are_special
        #:endif

        ! Reject any operation flag outside the supported set.
        if (present(op)) then
            if (op /= "N" .and. op /= "T" .and. op /= "C" .and. op /= "H") then
                status = linalg_state_type(this, LINALG_VALUE_ERROR, &
                                           "Invalid matrix operation; expected 'N', 'T', 'C' or 'H'.")
                call linalg_error_handling(status)
            end if
        end if

        ! Resolve defaults; 'H' is mapped onto 'C' before reaching the kernel.
        a_coef = optval(alpha, one_${s1}$)
        b_coef = optval(beta, zero_${s1}$)
        trans = optval(op, "N")
        if (trans == "H") trans = "C"

        #:if t1.startswith('real')
        scalars_are_special = any(a_coef == [1.0_${k1}$, 0.0_${k1}$, -1.0_${k1}$]) .and. &
                              any(b_coef == [1.0_${k1}$, 0.0_${k1}$, -1.0_${k1}$])
        #:endif

        ! Problem dimensions.
        order = A%n
        ld_x = order
        ld_y = order
        #:if rank == 1
        ncols = 1
        #:else
        ncols = size(x, dim=2, kind=ilp)
        #:endif

        ! Bounds-remapped pointers give the kernel a rank-2 view in every case.
        xview(1:order, 1:ncols) => x
        yview(1:order, 1:ncols) => y

        #:if t1.startswith('complex')
        call glagtm(trans, order, ncols, a_coef, A%du, A%dv, A%du, xview, ld_x, b_coef, yview, ld_y)
        #:else
        ! `lagtm` is only selected when both scalars are in {-1, 0, 1};
        ! every other combination goes through `glagtm`.
        if (scalars_are_special) then
            call lagtm(trans, order, ncols, a_coef, A%du, A%dv, A%du, xview, ld_x, b_coef, yview, ld_y)
        else
            call glagtm(trans, order, ncols, a_coef, A%du, A%dv, A%du, xview, ld_x, b_coef, yview, ld_y)
        end if
        #:endif
    end subroutine spmv_sym_tridiag_${rank}$d_${s1}$
    #:endfor
    #:endfor

    !-------------------------------------
    !-----                           -----
    !-----     UTILITY FUNCTIONS     -----
    !-----                           -----
    !-------------------------------------
    #:for k1, t1, s1 in KINDS_TYPES
    pure module function sym_tridiagonal_to_dense_${s1}$(A) result(B)
        !! Convert a `symmetric tridiagonal` matrix to its dense representation.
        !!
        !! The diagonal and the two off-diagonals are filled by uniform loops,
        !! so no special case is needed for `n == 1`, and a non-positive `A%n`
        !! yields an empty matrix instead of writing outside the bounds of `B`.
        type(sym_tridiagonal_${s1}$_type), intent(in) :: A
        !! Input matrix.
        ${t1}$, allocatable :: B(:, :)
        !! Corresponding dense `n`-by-`n` matrix.

        ! Internal variables.
        integer(ilp) :: i

        associate (n => A%n)
            ! Start from an all-zero matrix (zero-sized when n <= 0).
            allocate(B(n, n), source=zero_${s1}$)

            ! Main diagonal.
            do concurrent (i = 1:n)
                B(i, i) = A%dv(i)
            enddo

            ! `du` is mirrored onto the sub- and super-diagonal; the loop is
            ! empty when n <= 1, so `A%du` is never referenced in that case.
            do concurrent (i = 1:n - 1)
                B(i + 1, i) = A%du(i)
                B(i, i + 1) = A%du(i)
            enddo
        end associate
    end function
    #:endfor

    #:for k1, t1, s1 in KINDS_TYPES
    pure module function transpose_sym_tridiagonal_${s1}$(A) result(B)
        !! Transpose of a `symmetric tridiagonal` matrix. The result is rebuilt
        !! from the same off-diagonal and diagonal arrays as the input.
        type(sym_tridiagonal_${s1}$_type), intent(in) :: A
        !! Input matrix.
        type(sym_tridiagonal_${s1}$_type) :: B
        !! Transposed matrix.

        associate (offdiag => A%du, diag => A%dv)
            B = sym_tridiagonal(offdiag, diag)
        end associate
    end function transpose_sym_tridiagonal_${s1}$
    #:endfor

    #:for k1, t1, s1 in KINDS_TYPES
    pure module function hermitian_sym_tridiagonal_${s1}$(A) result(B)
        !! Conjugate transpose of a `symmetric tridiagonal` matrix. For real
        !! kinds the matrix is rebuilt unchanged; for complex kinds both bands
        !! are conjugated.
        type(sym_tridiagonal_${s1}$_type), intent(in) :: A
        !! Input matrix.
        type(sym_tridiagonal_${s1}$_type) :: B
        !! Conjugate-transposed matrix.

        #:if not t1.startswith("complex")
        ! Real entries: nothing to conjugate.
        B = sym_tridiagonal(A%du, A%dv)
        #:else
        ! Complex entries: conjugate the off-diagonal and the diagonal.
        B = sym_tridiagonal(conjg(A%du), conjg(A%dv))
        #:endif
    end function hermitian_sym_tridiagonal_${s1}$
    #:endfor

    #:for k1, t1, s1 in KINDS_TYPES
    pure module function scalar_multiplication_sym_tridiagonal_${s1}$(alpha, A) result(B)
        !! Product `alpha * A` of a scalar with a `symmetric tridiagonal`
        !! matrix: both the off-diagonal and the diagonal are scaled by `alpha`.
        ${t1}$, intent(in) :: alpha
        !! Scalar factor.
        type(sym_tridiagonal_${s1}$_type), intent(in) :: A
        !! Input matrix.
        type(sym_tridiagonal_${s1}$_type) :: B
        !! Scaled matrix.

        ! Scale the bands on the fly and build the result in one step.
        B = sym_tridiagonal(alpha*A%du, alpha*A%dv)
    end function scalar_multiplication_sym_tridiagonal_${s1}$

    pure module function scalar_multiplication_bis_sym_tridiagonal_${s1}$(A, alpha) result(B)
        !! Product `A * alpha` of a `symmetric tridiagonal` matrix with a
        !! scalar: both the off-diagonal and the diagonal are scaled by `alpha`.
        type(sym_tridiagonal_${s1}$_type), intent(in) :: A
        !! Input matrix.
        ${t1}$, intent(in) :: alpha
        !! Scalar factor.
        type(sym_tridiagonal_${s1}$_type) :: B
        !! Scaled matrix.

        ! Scale the bands on the fly and build the result in one step.
        B = sym_tridiagonal(alpha*A%du, alpha*A%dv)
    end function scalar_multiplication_bis_sym_tridiagonal_${s1}$
    #:endfor

    #:for k1, t1, s1 in KINDS_TYPES
    pure module function matrix_add_sym_tridiagonal_${s1}$(A, B) result(C)
        !! Sum `A + B` of two `symmetric tridiagonal` matrices.
        !!
        !! The operands must have the same dimension; a mismatch is reported as
        !! a `LINALG_VALUE_ERROR` through `linalg_error_handling` instead of
        !! falling through to non-conforming array arithmetic.
        type(sym_tridiagonal_${s1}$_type), intent(in) :: A
        !! Left operand.
        type(sym_tridiagonal_${s1}$_type), intent(in) :: B
        !! Right operand.
        type(sym_tridiagonal_${s1}$_type) :: C
        !! Resulting matrix.

        ! Internal variables.
        type(linalg_state_type) :: err0

        if (A%n /= B%n) then
            err0 = linalg_state_type(this, LINALG_VALUE_ERROR, "Matrices A and B must have the same dimension.")
            call linalg_error_handling(err0)
        else
            ! Start from a copy of A, then accumulate B band by band.
            C = sym_tridiagonal(A%du, A%dv)
            C%du = C%du + B%du
            C%dv = C%dv + B%dv
        endif
    end function

    pure module function matrix_sub_sym_tridiagonal_${s1}$(A, B) result(C)
        !! Difference `A - B` of two `symmetric tridiagonal` matrices.
        !!
        !! The operands must have the same dimension; a mismatch is reported as
        !! a `LINALG_VALUE_ERROR` through `linalg_error_handling` instead of
        !! falling through to non-conforming array arithmetic.
        type(sym_tridiagonal_${s1}$_type), intent(in) :: A
        !! Left operand (minuend).
        type(sym_tridiagonal_${s1}$_type), intent(in) :: B
        !! Right operand (subtrahend).
        type(sym_tridiagonal_${s1}$_type) :: C
        !! Resulting matrix.

        ! Internal variables.
        type(linalg_state_type) :: err0

        if (A%n /= B%n) then
            err0 = linalg_state_type(this, LINALG_VALUE_ERROR, "Matrices A and B must have the same dimension.")
            call linalg_error_handling(err0)
        else
            ! Start from a copy of A, then subtract B band by band.
            C = sym_tridiagonal(A%du, A%dv)
            C%du = C%du - B%du
            C%dv = C%dv - B%dv
        endif
    end function
    #:endfor

end submodule