#:include "common.fypp"
#:set RANKS = range(1, 2+1)
#:set R_KINDS_TYPES = list(zip(REAL_KINDS, REAL_TYPES, REAL_SUFFIX))
#:set C_KINDS_TYPES = list(zip(CMPLX_KINDS, CMPLX_TYPES, CMPLX_SUFFIX))
#:set KINDS_TYPES = R_KINDS_TYPES+C_KINDS_TYPES
submodule (stdlib_specialmatrices) tridiagonal_matrices
    !! Implementation of the constructors, the matrix-vector product and the
    !! utility/operator procedures for the `tridiagonal` matrix types.
    use stdlib_linalg_lapack, only: lagtm
    character(len=*), parameter :: this = "tridiagonal matrices"
    contains

    !--------------------------------
    !-----                      -----
    !-----     CONSTRUCTORS     -----
    !-----                      -----
    !--------------------------------

    #:for k1, t1, s1 in (KINDS_TYPES)
    pure module function initialize_tridiagonal_pure_${s1}$(dl, dv, du) result(A)
        !! Construct a `tridiagonal` matrix from the rank-1 arrays
        !! `dl`, `dv` and `du`. Errors are passed to `linalg_error_handling`
        !! without an output state argument.
        ${t1}$, intent(in) :: dl(:), dv(:), du(:)
        !! tridiagonal matrix elements (sub-, main and super-diagonal).
        type(tridiagonal_${s1}$_type) :: A
        !! Corresponding tridiagonal matrix.

        ! Internal variables.
        integer(ilp) :: n
        type(linalg_state_type) :: err0

        ! Sanity check: the main diagonal fixes the dimension, the two
        ! off-diagonals must hold exactly n-1 entries. Only the first failing
        ! check is reported.
        n = size(dv, kind=ilp)
        if (n <= 0) then
            err0 = linalg_state_type(this, LINALG_VALUE_ERROR, "Matrix size needs to be positive, n = ", n, ".")
            call linalg_error_handling(err0)
        else if (size(dl, kind=ilp) /= n-1) then
            err0 = linalg_state_type(this, LINALG_VALUE_ERROR, "Vector dl does not have the correct length.")
            call linalg_error_handling(err0)
        else if (size(du, kind=ilp) /= n-1) then
            err0 = linalg_state_type(this, LINALG_VALUE_ERROR, "Vector du does not have the correct length.")
            call linalg_error_handling(err0)
        endif

        ! Description of the matrix.
        A%n = n
        ! Matrix elements.
        A%dl = dl
        A%dv = dv
        A%du = du
    end function

    pure module function initialize_constant_tridiagonal_pure_${s1}$(dl, dv, du, n) result(A)
        !! Construct a `tridiagonal` matrix with constant elements along each
        !! of its three diagonals.
        ${t1}$, intent(in) :: dl, dv, du
        !! tridiagonal matrix elements (one constant per diagonal).
        integer(ilp), intent(in) :: n
        !! Matrix dimension.
        type(tridiagonal_${s1}$_type) :: A
        !! Corresponding tridiagonal matrix.

        ! Internal variables.
        integer(ilp) :: i
        type(linalg_state_type) :: err0

        ! Description of the matrix.
        A%n = n
        if (n <= 0) then
            err0 = linalg_state_type(this, LINALG_VALUE_ERROR, "Matrix size needs to be positive, n = ", n, ".")
            call linalg_error_handling(err0)
        endif

        ! Matrix elements: n-1 copies on the off-diagonals, n on the diagonal.
        A%dl = [(dl, i = 1, n-1)]
        A%dv = [(dv, i = 1, n)]
        A%du = [(du, i = 1, n-1)]
    end function

    module function initialize_tridiagonal_impure_${s1}$(dl, dv, du, err) result(A)
        !! Construct a `tridiagonal` matrix from the rank-1 arrays
        !! `dl`, `dv` and `du`, reporting problems through `err`.
        ${t1}$, intent(in) :: dl(:), dv(:), du(:)
        !! tridiagonal matrix elements (sub-, main and super-diagonal).
        type(linalg_state_type), intent(out) :: err
        !! Error handling.
        type(tridiagonal_${s1}$_type) :: A
        !! Corresponding tridiagonal matrix.

        ! Internal variables.
        integer(ilp) :: n
        type(linalg_state_type) :: err0

        ! Sanity check. The checks are chained so that `err` carries the first
        ! failure instead of being overwritten by a later, less specific one.
        n = size(dv, kind=ilp)
        if (n <= 0) then
            err0 = linalg_state_type(this, LINALG_VALUE_ERROR, "Matrix size needs to be positive, n = ", n, ".")
            call linalg_error_handling(err0, err)
        else if (size(dl, kind=ilp) /= n-1) then
            err0 = linalg_state_type(this, LINALG_VALUE_ERROR, "Vector dl does not have the correct length.")
            call linalg_error_handling(err0, err)
        else if (size(du, kind=ilp) /= n-1) then
            err0 = linalg_state_type(this, LINALG_VALUE_ERROR, "Vector du does not have the correct length.")
            call linalg_error_handling(err0, err)
        endif

        ! Description of the matrix.
        A%n = n
        ! Matrix elements.
        A%dl = dl
        A%dv = dv
        A%du = du
    end function

    module function initialize_constant_tridiagonal_impure_${s1}$(dl, dv, du, n, err) result(A)
        !! Construct a `tridiagonal` matrix with constant elements along each
        !! of its three diagonals, reporting problems through `err`.
        ${t1}$, intent(in) :: dl, dv, du
        !! tridiagonal matrix elements (one constant per diagonal).
        integer(ilp), intent(in) :: n
        !! Matrix dimension.
        type(linalg_state_type), intent(out) :: err
        !! Error handling
        type(tridiagonal_${s1}$_type) :: A
        !! Corresponding tridiagonal matrix.

        ! Internal variables.
        integer(ilp) :: i
        type(linalg_state_type) :: err0

        ! Description of the matrix.
        A%n = n
        if (n <= 0) then
            err0 = linalg_state_type(this, LINALG_VALUE_ERROR, "Matrix size needs to be positive, n = ", n, ".")
            call linalg_error_handling(err0, err)
        endif

        ! Matrix elements: n-1 copies on the off-diagonals, n on the diagonal.
        A%dl = [(dl, i = 1, n-1)]
        A%dv = [(dv, i = 1, n)]
        A%du = [(du, i = 1, n-1)]
    end function
    #:endfor

    !-----------------------------------------
    !-----                               -----
    !-----     MATRIX-VECTOR PRODUCT     -----
    !-----                               -----
    !-----------------------------------------

    !! spmv_tridiag
    #:for k1, t1, s1 in (KINDS_TYPES)
    #:for rank in RANKS
    module subroutine spmv_tridiag_${rank}$d_${s1}$(A, x, y, alpha, beta, op)
        !! Matrix-vector (or matrix-matrix) product with a `tridiagonal`
        !! matrix, delegated to `lagtm`:
        !! `y = alpha * op(A) * x + beta * y`.
        type(tridiagonal_${s1}$_type), intent(in) :: A
        !! Input tridiagonal matrix.
        ${t1}$, intent(in), contiguous, target :: x${ranksuffix(rank)}$
        !! Right-hand side(s).
        ${t1}$, intent(inout), contiguous, target :: y${ranksuffix(rank)}$
        !! On entry, the array scaled by `beta`; on exit, the result.
        real(${k1}$), intent(in), optional :: alpha
        !! Scaling of `op(A) * x`. Defaults to 1.
        real(${k1}$), intent(in), optional :: beta
        !! Scaling of the incoming `y`. Defaults to 0.
        character(1), intent(in), optional :: op
        !! Operation selector forwarded to `lagtm`. Defaults to "N".

        ! Internal variables.
        real(${k1}$) :: alpha_, beta_
        integer(ilp) :: n, nrhs, ldx, ldy
        character(1) :: op_
        #:if rank == 1
        ${t1}$, pointer :: xmat(:, :), ymat(:, :)
        #:endif

        ! Deal with optional arguments.
        alpha_ = 1.0_${k1}$
        if (present(alpha)) alpha_ = alpha
        beta_ = 0.0_${k1}$
        if (present(beta)) beta_ = beta
        op_ = "N"
        if (present(op)) op_ = op

        ! Prepare Lapack arguments.
        ! `y` is deliberately NOT cleared here: zeroing it would discard the
        ! caller's data and make `beta` ineffective. Handling of `beta == 0`
        ! is left to `lagtm`.
        ! NOTE(review): reference LAPACK xLAGTM only honours alpha and beta in
        ! {0, 1, -1}; confirm whether the stdlib `lagtm` shares this limit.
        n = A%n
        ldx = n
        ldy = n
        nrhs = #{if rank==1}# 1 #{else}# size(x, dim=2, kind=ilp) #{endif}#

        #:if rank == 1
        ! Pointer trick: view the rank-1 arrays as n-by-1 matrices so that
        ! they match the rank-2 arguments passed to `lagtm`.
        xmat(1:n, 1:nrhs) => x
        ymat(1:n, 1:nrhs) => y
        call lagtm(op_, n, nrhs, alpha_, A%dl, A%dv, A%du, xmat, ldx, beta_, ymat, ldy)
        #:else
        call lagtm(op_, n, nrhs, alpha_, A%dl, A%dv, A%du, x, ldx, beta_, y, ldy)
        #:endif
    end subroutine
    #:endfor
    #:endfor

    !-------------------------------------
    !-----                           -----
    !-----     UTILITY FUNCTIONS     -----
    !-----                           -----
    !-------------------------------------

    #:for k1, t1, s1 in (KINDS_TYPES)
    pure module function tridiagonal_to_dense_${s1}$(A) result(B)
        !! Convert a `tridiagonal` matrix to its dense representation.
        type(tridiagonal_${s1}$_type), intent(in) :: A
        !! Input tridiagonal matrix.
        ${t1}$, allocatable :: B(:, :)
        !! Corresponding dense matrix.

        ! Internal variables.
        integer(ilp) :: i

        associate (n => A%n)
        #:if t1.startswith('complex')
        allocate(B(n, n), source=zero_c${k1}$)
        #:else
        allocate(B(n, n), source=zero_${k1}$)
        #:endif
        ! Main diagonal.
        do concurrent (i=1:n)
            B(i, i) = A%dv(i)
        enddo
        ! Sub- and super-diagonal. The range is empty for n <= 1, so a 1-by-1
        ! matrix never touches B(1, 2), B(1, 0) or the empty off-diagonals.
        do concurrent (i=1:n-1)
            B(i+1, i) = A%dl(i)
            B(i, i+1) = A%du(i)
        enddo
        end associate
    end function
    #:endfor

    #:for k1, t1, s1 in (KINDS_TYPES)
    pure module function transpose_tridiagonal_${s1}$(A) result(B)
        !! Transpose of a `tridiagonal` matrix: the sub- and super-diagonals
        !! are swapped.
        type(tridiagonal_${s1}$_type), intent(in) :: A
        !! Input matrix.
        type(tridiagonal_${s1}$_type) :: B
        !! Transposed matrix.
        B = tridiagonal(A%du, A%dv, A%dl)
    end function
    #:endfor

    #:for k1, t1, s1 in (KINDS_TYPES)
    pure module function hermitian_tridiagonal_${s1}$(A) result(B)
        !! Conjugate transpose of a `tridiagonal` matrix. For real kinds this
        !! is the plain transpose.
        type(tridiagonal_${s1}$_type), intent(in) :: A
        !! Input matrix.
        type(tridiagonal_${s1}$_type) :: B
        !! Conjugate-transposed matrix.
        #:if t1.startswith("complex")
        B = tridiagonal(conjg(A%du), conjg(A%dv), conjg(A%dl))
        #:else
        B = tridiagonal(A%du, A%dv, A%dl)
        #:endif
    end function
    #:endfor

    #:for k1, t1, s1 in (KINDS_TYPES)
    pure module function scalar_multiplication_tridiagonal_${s1}$(alpha, A) result(B)
        !! Scalar-matrix product `alpha * A`.
        ${t1}$, intent(in) :: alpha
        !! Scalar factor.
        type(tridiagonal_${s1}$_type), intent(in) :: A
        !! Input matrix.
        type(tridiagonal_${s1}$_type) :: B
        !! Scaled matrix.
        B = tridiagonal(alpha*A%dl, alpha*A%dv, alpha*A%du)
    end function

    pure module function scalar_multiplication_bis_tridiagonal_${s1}$(A, alpha) result(B)
        !! Matrix-scalar product `A * alpha`.
        type(tridiagonal_${s1}$_type), intent(in) :: A
        !! Input matrix.
        ${t1}$, intent(in) :: alpha
        !! Scalar factor.
        type(tridiagonal_${s1}$_type) :: B
        !! Scaled matrix.
        B = tridiagonal(alpha*A%dl, alpha*A%dv, alpha*A%du)
    end function
    #:endfor

    #:for k1, t1, s1 in (KINDS_TYPES)
    pure module function matrix_add_tridiagonal_${s1}$(A, B) result(C)
        !! Element-wise sum `A + B` of two `tridiagonal` matrices.
        !! NOTE(review): no check is made that `A` and `B` have the same
        !! dimension; the diagonals are assumed to be conformable.
        type(tridiagonal_${s1}$_type), intent(in) :: A
        !! Left operand.
        type(tridiagonal_${s1}$_type), intent(in) :: B
        !! Right operand.
        type(tridiagonal_${s1}$_type) :: C
        !! Sum of the two matrices.
        C = tridiagonal(A%dl + B%dl, A%dv + B%dv, A%du + B%du)
    end function

    pure module function matrix_sub_tridiagonal_${s1}$(A, B) result(C)
        !! Element-wise difference `A - B` of two `tridiagonal` matrices.
        !! NOTE(review): no check is made that `A` and `B` have the same
        !! dimension; the diagonals are assumed to be conformable.
        type(tridiagonal_${s1}$_type), intent(in) :: A
        !! Left operand.
        type(tridiagonal_${s1}$_type), intent(in) :: B
        !! Right operand.
        type(tridiagonal_${s1}$_type) :: C
        !! Difference of the two matrices.
        C = tridiagonal(A%dl - B%dl, A%dv - B%dv, A%du - B%du)
    end function
    #:endfor

end submodule