stdlib_specialmatrices_tridiagonal.fypp Source File


Source Code

#:include "common.fypp"
#:set RANKS = range(1, 2+1)
#:set R_KINDS_TYPES = list(zip(REAL_KINDS, REAL_TYPES, REAL_SUFFIX))
#:set C_KINDS_TYPES = list(zip(CMPLX_KINDS, CMPLX_TYPES, CMPLX_SUFFIX))
#:set KINDS_TYPES = R_KINDS_TYPES+C_KINDS_TYPES
submodule (stdlib_specialmatrices) tridiagonal_matrices
    use stdlib_linalg_lapack, only: lagtm
    use stdlib_optval, only: optval

    character(len=*), parameter :: this = "tridiagonal matrices"
    contains

    !--------------------------------
    !-----                      -----
    !-----     CONSTRUCTORS     -----
    !-----                      -----
    !--------------------------------

    #:for k1, t1, s1 in KINDS_TYPES
    pure module function initialize_tridiagonal_pure_${s1}$(dl, dv, du) result(A)
        !! Construct a `tridiagonal` matrix from the rank-1 arrays
        !! `dl`, `dv` and `du`.
        ${t1}$, intent(in) :: dl(:), dv(:), du(:)
        !! tridiagonal matrix elements.
        type(tridiagonal_${s1}$_type) :: A
        !! Corresponding tridiagonal matrix.
        call build_tridiagonal(dl, dv, du, A)
    end function

    pure module function initialize_constant_tridiagonal_pure_${s1}$(dl, dv, du, n) result(A)
        !! Construct a `tridiagonal` matrix with scalar elements.
        ${t1}$, intent(in) :: dl, dv, du
        !! tridiagonal matrix elements.
        integer(ilp), intent(in) :: n
        !! Matrix dimension.
        type(tridiagonal_${s1}$_type) :: A
        !! Corresponding tridiagonal matrix.
        call build_tridiagonal(dl, dv, du, n, A)
    end function

    module function initialize_tridiagonal_impure_${s1}$(dl, dv, du, err) result(A)
        !! Construct a `tridiagonal` matrix from the rank-1 arrays
        !! `dl`, `dv` and `du`.
        ${t1}$, intent(in) :: dl(:), dv(:), du(:)
        !! tridiagonal matrix elements.
        type(linalg_state_type), intent(out) :: err
        !! Error handling.
        type(tridiagonal_${s1}$_type) :: A
        !! Corresponding tridiagonal matrix.

        call build_tridiagonal(dl, dv, du, A, err)
    end function

    module function initialize_constant_tridiagonal_impure_${s1}$(dl, dv, du, n, err) result(A)
        !! Construct a `tridiagonal` matrix with scalar elements.
        ${t1}$, intent(in) :: dl, dv, du
        !! tridiagonal matrix elements.
        integer(ilp), intent(in) :: n
        !! Matrix dimension.
        type(linalg_state_type), intent(out) :: err
        !! Error handling
        type(tridiagonal_${s1}$_type) :: A
        !! Corresponding tridiagonal matrix.
        call build_tridiagonal(dl, dv, du, n, A, err)
    end function
    #:endfor

    #:for k1, t1, s1 in KINDS_TYPES
    pure module subroutine build_tridiagonal_from_arrays_${s1}$(dl, dv, du, A, err)
        ${t1}$, intent(in) :: dl(:), dv(:), du(:)
        type(tridiagonal_${s1}$_type), intent(out) :: A
        type(linalg_state_type), intent(out), optional :: err

        ! Internal variables.
        integer(ilp) :: n
        type(linalg_state_type) :: err0

        ! Sanity check.
        n = size(dv, kind=ilp)
        if (n <= 0) then
            err0 = linalg_state_type(this, LINALG_VALUE_ERROR, "Matrix size needs to be positive, n = ", n, ".")
            call linalg_error_handling(err0, err)
        endif
        if (size(dl, kind=ilp) /= n-1) then
            err0 = linalg_state_type(this, LINALG_VALUE_ERROR, "Vector dl does not have the correct length.")
            call linalg_error_handling(err0, err)
        endif
        if (size(du, kind=ilp) /= n-1) then
            err0 = linalg_state_type(this, LINALG_VALUE_ERROR, "Vector du does not have the correct length.")
            call linalg_error_handling(err0, err)
        endif

        if(err0%ok()) then
            ! Description of the matrix.
            A%n = n
            ! Matrix elements.
            A%dl = dl
            A%dv = dv
            A%du = du
        endif
    end subroutine

    pure module subroutine build_tridiagonal_from_constants_${s1}$(dl, dv, du, n, A, err)
        ${t1}$, intent(in) :: dl, dv, du
        integer(ilp), intent(in) :: n
        type(tridiagonal_${s1}$_type), intent(out) :: A
        type(linalg_state_type), intent(out), optional :: err

        ! Internal variables.
        type(linalg_state_type) :: err0

        if (n <= 0) then
            err0 = linalg_state_type(this, LINALG_VALUE_ERROR, "Matrix size needs to be positive, n = ", n, ".")
            call linalg_error_handling(err0, err)
        endif

        if(err0%ok()) then
            ! Description of the matrix.
            A%n = n
            ! Matrix elements.
            allocate( A%dl(n-1), source = dl )
            allocate( A%dv(n), source= dv )
            allocate( A%du(n-1), source = du )
        endif
    end subroutine
    #:endfor

    !-----------------------------------------
    !-----                               -----
    !-----     MATRIX-VECTOR PRODUCT     -----
    !-----                               -----
    !-----------------------------------------

    !! spmv_tridiag
    #:for k1, t1, s1 in KINDS_TYPES
    #:for rank in RANKS
    module subroutine spmv_tridiag_${rank}$d_${s1}$(A, x, y, alpha, beta, op)
        type(tridiagonal_${s1}$_type), intent(in) :: A
        ${t1}$, intent(in), contiguous, target :: x${ranksuffix(rank)}$
        ${t1}$, intent(inout), contiguous, target :: y${ranksuffix(rank)}$
        ${t1}$, intent(in), optional :: alpha
        ${t1}$, intent(in), optional :: beta
        character(1), intent(in), optional :: op

        ! Internal variables.
        ${t1}$ :: alpha_, beta_
        integer(ilp) :: n, nrhs, ldx, ldy
        character(1) :: op_

        type(linalg_state_type) :: err0
        #:if t1.startswith('real')
        logical :: is_alpha_special, is_beta_special
        #:endif

        ${t1}$, pointer :: xmat(:, :), ymat(:, :)

        if(present(op)) then
            if(.not.(op == "N" .or. op == "T" .or. op == "C" .or. op == "H")) then
                err0 = linalg_state_type(this, LINALG_VALUE_ERROR, "Invalid matrix operation; expected 'N', 'T', 'C' or 'H'.")
                call linalg_error_handling(err0)
            end if
        end if

        ! Deal with optional arguments.
        alpha_ = optval(alpha, one_${s1}$)
        beta_ = optval(beta, zero_${s1}$)
        op_ = optval(op, "N")
        if (op_ == "H") op_ = "C"

        #:if t1.startswith('real')
        is_alpha_special = (alpha_ ==  1.0_${k1}$  .or. alpha_ ==  0.0_${k1}$  .or. alpha_ == -1.0_${k1}$)
        is_beta_special  = (beta_  ==  1.0_${k1}$  .or. beta_  ==  0.0_${k1}$  .or. beta_  == -1.0_${k1}$)
        #:endif

        ! Prepare Lapack arguments.
        n = A%n
        ldx = n
        ldy = n
        nrhs = #{if rank==1}# 1 #{else}# size(x, dim=2, kind=ilp) #{endif}#

        ! Pointer trick.
        xmat(1:n, 1:nrhs) => x
        ymat(1:n, 1:nrhs) => y
        #:if t1.startswith('complex')
        call glagtm(op_, n, nrhs, alpha_, A%dl, A%dv, A%du, xmat, ldx, beta_, ymat, ldy)
        #:else
        if(is_alpha_special .and. is_beta_special) then
            call lagtm(op_, n, nrhs, alpha_, A%dl, A%dv, A%du, xmat, ldx, beta_, ymat, ldy)
        else
            call glagtm(op_, n, nrhs, alpha_, A%dl, A%dv, A%du, xmat, ldx, beta_, ymat, ldy)
        end if
        #:endif
    end subroutine
    #:endfor
    #:endfor

    !-------------------------------------
    !-----                           -----
    !-----     UTILITY FUNCTIONS     -----
    !-----                           -----
    !-------------------------------------

    #:for k1, t1, s1 in KINDS_TYPES
    pure module function tridiagonal_to_dense_${s1}$(A) result(B)
        !! Convert a `tridiagonal` matrix to its dense representation.
        type(tridiagonal_${s1}$_type), intent(in) :: A
        !! Input tridiagonal matrix.
        ${t1}$, allocatable :: B(:, :)
        !! Corresponding dense matrix.

        ! Internal variables.
        integer(ilp) :: i

        associate (n => A%n)
        #:if t1.startswith('complex')
        allocate(B(n, n), source=zero_c${k1}$)
        #:else
        allocate(B(n, n), source=zero_${k1}$)
        #:endif
        
        if (n == 1) then
            B(1, 1) = A%dv(1)
        else
            B(1, 1) = A%dv(1)
            B(1, 2) = A%du(1)
            do concurrent (i=2:n-1)
                B(i, i-1) = A%dl(i-1)
                B(i, i) = A%dv(i)
                B(i, i+1) = A%du(i)
            enddo
            B(n, n-1) = A%dl(n-1)
            B(n, n) = A%dv(n)
        end if
        end associate

    end function
    #:endfor

    #:for k1, t1, s1 in KINDS_TYPES
    pure module function transpose_tridiagonal_${s1}$(A) result(B)
        type(tridiagonal_${s1}$_type), intent(in) :: A
        !! Input matrix.
        type(tridiagonal_${s1}$_type) :: B
        B = tridiagonal(A%du, A%dv, A%dl)
    end function
    #:endfor

    #:for k1, t1, s1 in KINDS_TYPES
    pure module function hermitian_tridiagonal_${s1}$(A) result(B)
        type(tridiagonal_${s1}$_type), intent(in) :: A
        !! Input matrix.
        type(tridiagonal_${s1}$_type) :: B
        #:if t1.startswith("complex")
        B = tridiagonal(conjg(A%du), conjg(A%dv), conjg(A%dl))
        #:else
        B = tridiagonal(A%du, A%dv, A%dl)
        #:endif
    end function
    #:endfor

    #:for k1, t1, s1 in KINDS_TYPES
    pure module function scalar_multiplication_tridiagonal_${s1}$(alpha, A) result(B)
        ${t1}$, intent(in) :: alpha
        type(tridiagonal_${s1}$_type), intent(in) :: A
        type(tridiagonal_${s1}$_type) :: B
        B = tridiagonal(A%dl, A%dv, A%du)
        B%dl = alpha*B%dl
        B%dv = alpha*B%dv
        B%du = alpha*B%du
    end function

    pure module function scalar_multiplication_bis_tridiagonal_${s1}$(A, alpha) result(B)
        type(tridiagonal_${s1}$_type), intent(in) :: A
        ${t1}$, intent(in) :: alpha
        type(tridiagonal_${s1}$_type) :: B
        B = tridiagonal(A%dl, A%dv, A%du)
        B%dl = alpha*B%dl
        B%dv = alpha*B%dv
        B%du = alpha*B%du
    end function
    #:endfor

    #:for k1, t1, s1 in KINDS_TYPES
    pure module function matrix_add_tridiagonal_${s1}$(A, B) result(C)
        type(tridiagonal_${s1}$_type), intent(in) :: A
        type(tridiagonal_${s1}$_type), intent(in) :: B
        type(tridiagonal_${s1}$_type) :: C
        
        ! Internal variables.
        type(linalg_state_type) :: err0
        
        if (A%n /= B%n) then
            err0 = linalg_state_type(this, LINALG_VALUE_ERROR, "tridiagonal matrices must have the same dimension to be added")
            call linalg_error_handling(err0)
        end if
        
        C = tridiagonal(A%dl, A%dv, A%du)
        C%dl = C%dl + B%dl
        C%dv = C%dv + B%dv
        C%du = C%du + B%du
    end function

    pure module function matrix_sub_tridiagonal_${s1}$(A, B) result(C)
        type(tridiagonal_${s1}$_type), intent(in) :: A
        type(tridiagonal_${s1}$_type), intent(in) :: B
        type(tridiagonal_${s1}$_type) :: C
        
        ! Internal variables. 
        type(linalg_state_type) :: err0
        
        if (A%n /= B%n) then
            err0 = linalg_state_type(this, LINALG_VALUE_ERROR, "tridiagonal matrices must have the same dimension to be subtracted")
            call linalg_error_handling(err0)
        end if
        
        C = tridiagonal(A%dl, A%dv, A%du)
        C%dl = C%dl - B%dl
        C%dv = C%dv - B%dv
        C%du = C%du - B%du
    end function
    #:endfor

end submodule