stdlib_specialmatrices_tridiagonal.fypp Source File


Source Code

#:include "common.fypp"
#:set RANKS = range(1, 2+1)
#:set R_KINDS_TYPES = list(zip(REAL_KINDS, REAL_TYPES, REAL_SUFFIX))
#:set C_KINDS_TYPES = list(zip(CMPLX_KINDS, CMPLX_TYPES, CMPLX_SUFFIX))
#:set KINDS_TYPES = R_KINDS_TYPES+C_KINDS_TYPES
submodule (stdlib_specialmatrices) tridiagonal_matrices
    use stdlib_linalg_lapack, only: lagtm

    character(len=*), parameter :: this = "tridiagonal matrices"
    contains

    !--------------------------------
    !-----                      -----
    !-----     CONSTRUCTORS     -----
    !-----                      -----
    !--------------------------------

    #:for k1, t1, s1 in (KINDS_TYPES)
    pure module function initialize_tridiagonal_pure_${s1}$(dl, dv, du) result(A)
        !! Construct a `tridiagonal` matrix from the rank-1 arrays
        !! `dl` (sub-diagonal, length n-1), `dv` (diagonal, length n)
        !! and `du` (super-diagonal, length n-1).
        !! Inconsistent sizes raise a `LINALG_VALUE_ERROR` (error stop).
        ${t1}$, intent(in) :: dl(:), dv(:), du(:)
        !! tridiagonal matrix elements.
        type(tridiagonal_${s1}$_type) :: A
        !! Corresponding tridiagonal matrix.

        ! Internal variables.
        integer(ilp) :: n
        type(linalg_state_type) :: err0

        ! Sanity checks: only the first failing check is reported.
        n = size(dv, kind=ilp)
        if (n <= 0) then
            err0 = linalg_state_type(this, LINALG_VALUE_ERROR, "Matrix size needs to be positive, n = ", n, ".")
            call linalg_error_handling(err0)
        else if (size(dl, kind=ilp) /= n-1) then
            err0 = linalg_state_type(this, LINALG_VALUE_ERROR, "Vector dl does not have the correct length.")
            call linalg_error_handling(err0)
        else if (size(du, kind=ilp) /= n-1) then
            err0 = linalg_state_type(this, LINALG_VALUE_ERROR, "Vector du does not have the correct length.")
            call linalg_error_handling(err0)
        endif

        ! Description of the matrix.
        A%n = n
        ! Matrix elements.
        A%dl = dl ; A%dv = dv ; A%du = du
    end function

    pure module function initialize_constant_tridiagonal_pure_${s1}$(dl, dv, du, n) result(A)
        !! Construct an `n x n` `tridiagonal` matrix whose sub-diagonal,
        !! diagonal and super-diagonal are filled with the constants
        !! `dl`, `dv` and `du` respectively.
        !! A non-positive `n` raises a `LINALG_VALUE_ERROR` (error stop).
        ${t1}$, intent(in) :: dl, dv, du
        !! tridiagonal matrix elements.
        integer(ilp), intent(in) :: n
        !! Matrix dimension.
        type(tridiagonal_${s1}$_type) :: A
        !! Corresponding tridiagonal matrix.

        ! Internal variables.
        integer(ilp) :: i
        type(linalg_state_type) :: err0

        ! Description of the matrix.
        A%n = n
        if (n <= 0) then
            err0 = linalg_state_type(this, LINALG_VALUE_ERROR, "Matrix size needs to be positive, n = ", n, ".")
            call linalg_error_handling(err0)
        endif
        ! Matrix elements.
        A%dl = [(dl, i = 1, n-1)]
        A%dv = [(dv, i = 1, n)]
        A%du = [(du, i = 1, n-1)]
    end function

    module function initialize_tridiagonal_impure_${s1}$(dl, dv, du, err) result(A)
        !! Construct a `tridiagonal` matrix from the rank-1 arrays
        !! `dl` (sub-diagonal, length n-1), `dv` (diagonal, length n)
        !! and `du` (super-diagonal, length n-1).
        !! Inconsistent sizes are reported through `err` (no error stop);
        !! only the first failing check is reported.
        ${t1}$, intent(in) :: dl(:), dv(:), du(:)
        !! tridiagonal matrix elements.
        type(linalg_state_type), intent(out) :: err
        !! Error handling.
        type(tridiagonal_${s1}$_type) :: A
        !! Corresponding tridiagonal matrix.

        ! Internal variables.
        integer(ilp) :: n
        type(linalg_state_type) :: err0

        ! Sanity checks. They are chained so that a later check cannot
        ! overwrite a more relevant earlier error (e.g. n <= 0 would
        ! otherwise always be reported as a wrong `dl` length).
        n = size(dv, kind=ilp)
        if (n <= 0) then
            err0 = linalg_state_type(this, LINALG_VALUE_ERROR, "Matrix size needs to be positive, n = ", n, ".")
            call linalg_error_handling(err0, err)
        else if (size(dl, kind=ilp) /= n-1) then
            err0 = linalg_state_type(this, LINALG_VALUE_ERROR, "Vector dl does not have the correct length.")
            call linalg_error_handling(err0, err)
        else if (size(du, kind=ilp) /= n-1) then
            err0 = linalg_state_type(this, LINALG_VALUE_ERROR, "Vector du does not have the correct length.")
            call linalg_error_handling(err0, err)
        endif

        ! Description of the matrix.
        A%n = n
        ! Matrix elements.
        A%dl = dl ; A%dv = dv ; A%du = du
    end function

    module function initialize_constant_tridiagonal_impure_${s1}$(dl, dv, du, n, err) result(A)
        !! Construct an `n x n` `tridiagonal` matrix whose sub-diagonal,
        !! diagonal and super-diagonal are filled with the constants
        !! `dl`, `dv` and `du` respectively.
        !! A non-positive `n` is reported through `err` (no error stop).
        ${t1}$, intent(in) :: dl, dv, du
        !! tridiagonal matrix elements.
        integer(ilp), intent(in) :: n
        !! Matrix dimension.
        type(linalg_state_type), intent(out) :: err
        !! Error handling
        type(tridiagonal_${s1}$_type) :: A
        !! Corresponding tridiagonal matrix.

        ! Internal variables.
        integer(ilp) :: i
        type(linalg_state_type) :: err0

        ! Description of the matrix.
        A%n = n
        if (n <= 0) then
            err0 = linalg_state_type(this, LINALG_VALUE_ERROR, "Matrix size needs to be positive, n = ", n, ".")
            call linalg_error_handling(err0, err)
        endif
        ! Matrix elements (zero-sized arrays when n <= 0).
        A%dl = [(dl, i = 1, n-1)]
        A%dv = [(dv, i = 1, n)]
        A%du = [(du, i = 1, n-1)]
    end function
    #:endfor

    !-----------------------------------------
    !-----                               -----
    !-----     MATRIX-VECTOR PRODUCT     -----
    !-----                               -----
    !-----------------------------------------

    !! spmv_tridiag
    #:for k1, t1, s1 in (KINDS_TYPES)
    #:for rank in RANKS
    module subroutine spmv_tridiag_${rank}$d_${s1}$(A, x, y, alpha, beta, op)
        !! Compute `y = alpha * op(A) @ x + beta * y` for a tridiagonal matrix `A`.
        !! `op` selects `A` ("N", default), its transpose ("T") or, for complex
        !! types, its conjugate transpose ("C"). `alpha` defaults to 1 and
        !! `beta` to 0 (in which case the input content of `y` is ignored).
        type(tridiagonal_${s1}$_type), intent(in) :: A
        ${t1}$, intent(in), contiguous, target :: x${ranksuffix(rank)}$
        ${t1}$, intent(inout), contiguous, target :: y${ranksuffix(rank)}$
        real(${k1}$), intent(in), optional :: alpha
        real(${k1}$), intent(in), optional :: beta
        character(1), intent(in), optional :: op

        ! Internal variables.
        real(${k1}$) :: alpha_, beta_
        integer(ilp) :: n, nrhs, ldx, ldy
        character(1) :: op_
        ${t1}$, pointer :: xmat(:, :), ymat(:, :)
        ${t1}$, allocatable :: work(:, :)

        ! Deal with optional arguments.
        alpha_ = 1.0_${k1}$ ; if (present(alpha)) alpha_ = alpha
        beta_  = 0.0_${k1}$ ; if (present(beta))  beta_  = beta
        op_    = "N"        ; if (present(op))    op_    = op

        ! Prepare Lapack arguments.
        n = A%n ; ldx = n ; ldy = n
        nrhs = #{if rank==1}# 1 #{else}# size(x, dim=2, kind=ilp) #{endif}#

        ! Rank-2 views of x and y so both ranks share the same code path.
        #:if rank == 1
        xmat(1:n, 1:nrhs) => x ; ymat(1:n, 1:nrhs) => y
        #:else
        xmat => x ; ymat => y
        #:endif

        ! xLAGTM only honours alpha, beta in {-1, 0, 1}: any other alpha is
        ! treated as 0 and any other beta as 1. Apply beta here and always
        ! call lagtm with beta = 1 so that arbitrary scalars are supported.
        if (beta_ == 0.0_${k1}$) then
            ! Overwrite rather than multiply so that NaN/Inf in y cannot leak through.
            ymat = 0.0_${k1}$
        else if (beta_ /= 1.0_${k1}$) then
            ymat = beta_*ymat
        endif

        if (alpha_ == 1.0_${k1}$ .or. alpha_ == -1.0_${k1}$) then
            ! Supported alpha: accumulate alpha * op(A) @ x directly into y.
            call lagtm(op_, n, nrhs, alpha_, A%dl, A%dv, A%du, xmat, ldx, 1.0_${k1}$, ymat, ldy)
        else if (alpha_ /= 0.0_${k1}$) then
            ! General alpha: form op(A) @ x in a workspace, then scale and add.
            allocate(work(n, nrhs))
            call lagtm(op_, n, nrhs, 1.0_${k1}$, A%dl, A%dv, A%du, xmat, ldx, 0.0_${k1}$, work, n)
            ymat = ymat + alpha_*work
        endif
        ! alpha == 0: y = beta * y, already done above.
    end subroutine
    #:endfor
    #:endfor

    !-------------------------------------
    !-----                           -----
    !-----     UTILITY FUNCTIONS     -----
    !-----                           -----
    !-------------------------------------

    #:for k1, t1, s1 in (KINDS_TYPES)
    pure module function tridiagonal_to_dense_${s1}$(A) result(B)
        !! Convert a `tridiagonal` matrix to its dense representation.
        type(tridiagonal_${s1}$_type), intent(in) :: A
        !! Input tridiagonal matrix.
        ${t1}$, allocatable :: B(:, :)
        !! Corresponding dense matrix.

        ! Internal variables.
        integer(ilp) :: i

        associate (n => A%n)
        #:if t1.startswith('complex')
        allocate(B(n, n), source=zero_c${k1}$)
        #:else
        allocate(B(n, n), source=zero_${k1}$)
        #:endif
        ! Main diagonal.
        do concurrent (i=1:n)
            B(i, i) = A%dv(i)
        enddo
        ! Sub- and super-diagonals. This loop is empty for a 1x1 matrix,
        ! so no out-of-bounds access happens when n == 1.
        do concurrent (i=1:n-1)
            B(i+1, i) = A%dl(i)
            B(i, i+1) = A%du(i)
        enddo
        end associate
    end function
    #:endfor

    #:for k1, t1, s1 in (KINDS_TYPES)
    pure module function transpose_tridiagonal_${s1}$(A) result(B)
        !! Transpose of a `tridiagonal` matrix: the sub- and super-diagonals
        !! are swapped, the diagonal is unchanged.
        type(tridiagonal_${s1}$_type), intent(in) :: A
        !! Input matrix.
        type(tridiagonal_${s1}$_type) :: B
        !! Transposed matrix.
        B = tridiagonal(A%du, A%dv, A%dl)
    end function
    #:endfor

    #:for k1, t1, s1 in (KINDS_TYPES)
    pure module function hermitian_tridiagonal_${s1}$(A) result(B)
        !! Conjugate transpose of a `tridiagonal` matrix. For real types this
        !! is identical to the plain transpose.
        type(tridiagonal_${s1}$_type), intent(in) :: A
        !! Input matrix.
        type(tridiagonal_${s1}$_type) :: B
        !! Conjugate-transposed matrix.
        #:if t1.startswith("complex")
        B = tridiagonal(conjg(A%du), conjg(A%dv), conjg(A%dl))
        #:else
        B = tridiagonal(A%du, A%dv, A%dl)
        #:endif
    end function
    #:endfor

    #:for k1, t1, s1 in (KINDS_TYPES)
    pure module function scalar_multiplication_tridiagonal_${s1}$(alpha, A) result(B)
        !! Scalar-matrix product `B = alpha * A`.
        ${t1}$, intent(in) :: alpha
        !! Scaling factor.
        type(tridiagonal_${s1}$_type), intent(in) :: A
        !! Input matrix.
        type(tridiagonal_${s1}$_type) :: B
        !! Scaled matrix.
        ! A is already a valid tridiagonal matrix: copy it instead of
        ! re-running the constructor's validation, then scale in place.
        B = A
        B%dl = alpha*B%dl ; B%dv = alpha*B%dv ; B%du = alpha*B%du
    end function

    pure module function scalar_multiplication_bis_tridiagonal_${s1}$(A, alpha) result(B)
        !! Matrix-scalar product `B = A * alpha` (same as `alpha * A`).
        type(tridiagonal_${s1}$_type), intent(in) :: A
        !! Input matrix.
        ${t1}$, intent(in) :: alpha
        !! Scaling factor.
        type(tridiagonal_${s1}$_type) :: B
        !! Scaled matrix.
        B = scalar_multiplication_tridiagonal_${s1}$(alpha, A)
    end function
    #:endfor

    #:for k1, t1, s1 in (KINDS_TYPES)
    pure module function matrix_add_tridiagonal_${s1}$(A, B) result(C)
        !! Sum `C = A + B` of two `tridiagonal` matrices of the same size.
        !! Mismatched dimensions raise a `LINALG_VALUE_ERROR` (error stop).
        type(tridiagonal_${s1}$_type), intent(in) :: A
        type(tridiagonal_${s1}$_type), intent(in) :: B
        type(tridiagonal_${s1}$_type) :: C

        ! Internal variables.
        type(linalg_state_type) :: err0

        ! Element-wise operations on the diagonals require equal sizes.
        if (A%n /= B%n) then
            err0 = linalg_state_type(this, LINALG_VALUE_ERROR, &
                   "Matrices A and B must have the same dimension, got ", A%n, " and ", B%n, ".")
            call linalg_error_handling(err0)
        endif

        C = A
        C%dl = C%dl + B%dl ; C%dv = C%dv + B%dv ; C%du = C%du + B%du
    end function

    pure module function matrix_sub_tridiagonal_${s1}$(A, B) result(C)
        !! Difference `C = A - B` of two `tridiagonal` matrices of the same size.
        !! Mismatched dimensions raise a `LINALG_VALUE_ERROR` (error stop).
        type(tridiagonal_${s1}$_type), intent(in) :: A
        type(tridiagonal_${s1}$_type), intent(in) :: B
        type(tridiagonal_${s1}$_type) :: C

        ! Internal variables.
        type(linalg_state_type) :: err0

        ! Element-wise operations on the diagonals require equal sizes.
        if (A%n /= B%n) then
            err0 = linalg_state_type(this, LINALG_VALUE_ERROR, &
                   "Matrices A and B must have the same dimension, got ", A%n, " and ", B%n, ".")
            call linalg_error_handling(err0)
        endif

        C = A
        C%dl = C%dl - B%dl ; C%dv = C%dv - B%dv ; C%du = C%du - B%du
    end function
    #:endfor

end submodule