stdlib_linalg_solve.fypp Source File


Source Code

#:include "common.fypp"
#:set RC_KINDS_TYPES = REAL_KINDS_TYPES + CMPLX_KINDS_TYPES
#:set RHS_SUFFIX = ["one","many"]
#:set RHS_SYMBOL = [ranksuffix(r) for r in [1,2]]
#:set RHS_EMPTY  = [emptyranksuffix(r) for r in [1,2]]
#:set ALL_RHS    = list(zip(RHS_SYMBOL,RHS_SUFFIX,RHS_EMPTY))
submodule (stdlib_linalg) stdlib_linalg_solve
!! Solve linear system Ax=b
     use stdlib_linalg_constants
     use stdlib_linalg_lapack, only: gesv, potrs, posv
     use stdlib_linalg_lapack_aux, only: handle_gesv_info, handle_potrs_info, handle_posv_info
     use stdlib_linalg_state, only: linalg_state_type, linalg_error_handling, LINALG_ERROR, &
          LINALG_INTERNAL_ERROR, LINALG_VALUE_ERROR     
     implicit none
     
     character(*), parameter :: this = 'solve'

     contains
     

     #:for nd,ndsuf,nde in ALL_RHS
     #:for rk,rt,ri in RC_KINDS_TYPES
     ! Compute the solution to a real system of linear equations A * X = B
     module function stdlib_linalg_${ri}$_solve_${ndsuf}$(a,b,overwrite_a,err) result(x)
         !> Input matrix a[n,n]
         ${rt}$, intent(inout), target :: a(:,:)
         !> Right hand side vector or array, b[n] or b[n,nrhs]
         ${rt}$, intent(in) :: b${nd}$
         !> [optional] Can A data be overwritten and destroyed?
         logical(lk), optional, intent(in) :: overwrite_a
         !> [optional] state return flag. On error if not requested, the code will stop
         type(linalg_state_type), intent(out) :: err
         !> Result array/matrix x[n] or x[n,nrhs]
         ${rt}$, allocatable, target :: x${nd}$
                  
         ! Initialize solution shape from the rhs array
         allocate(x,mold=b)         
         
         call stdlib_linalg_${ri}$_solve_lu_${ndsuf}$(a,b,x,overwrite_a=overwrite_a,err=err)
            
     end function stdlib_linalg_${ri}$_solve_${ndsuf}$

     !> Compute the solution to a real system of linear equations A * X = B (pure interface)
     pure module function stdlib_linalg_${ri}$_pure_solve_${ndsuf}$(a,b) result(x)
         !> Input matrix a[n,n]
         ${rt}$, intent(in) :: a(:,:)
         !> Right hand side vector or array, b[n] or b[n,nrhs]
         ${rt}$, intent(in) :: b${nd}$
         !> Result array/matrix x[n] or x[n,nrhs]
         ${rt}$, allocatable, target :: x${nd}$
         
         ! Local variables
         ${rt}$, allocatable :: amat(:,:)
         
         ! Copy `a` so it can be intent(in)
         allocate(amat,source=a)
         
         ! Initialize solution shape from the rhs array
         allocate(x,mold=b)         
         
         call stdlib_linalg_${ri}$_solve_lu_${ndsuf}$(amat,b,x,overwrite_a=.true.)

     end function stdlib_linalg_${ri}$_pure_solve_${ndsuf}$
     
     !> Compute the solution to a real system of linear equations A * X = B (pure interface)
     pure module subroutine stdlib_linalg_${ri}$_solve_lu_${ndsuf}$(a,b,x,pivot,overwrite_a,err)
         !> Input matrix a[n,n]
         ${rt}$, intent(inout), target :: a(:,:)
         !> Right hand side vector or array, b[n] or b[n,nrhs]
         ${rt}$, intent(in) :: b${nd}$
         !> Result array/matrix x[n] or x[n,nrhs]
         ${rt}$, intent(inout), contiguous, target :: x${nd}$
         !> [optional] Storage array for the diagonal pivot indices
         integer(ilp), optional, intent(inout), target :: pivot(:)
         !> [optional] Can A data be overwritten and destroyed?
         logical(lk), optional, intent(in) :: overwrite_a
         !> [optional] state return flag. On error if not requested, the code will stop
         type(linalg_state_type), optional, intent(out) :: err
         
         ! Local variables
         type(linalg_state_type) :: err0
         integer(ilp) :: lda,n,ldb,ldx,nrhsx,nrhs,info,npiv
         integer(ilp), pointer :: ipiv(:)
         logical(lk) :: copy_a
         ${rt}$, pointer :: xmat(:,:),amat(:,:)         

         ! Problem sizes
         lda   = size(a,1,kind=ilp)
         n     = size(a,2,kind=ilp)
         ldb   = size(b,1,kind=ilp)
         nrhs  = size(b  ,kind=ilp)/ldb
         ldx   = size(x,1,kind=ilp)
         nrhsx = size(x  ,kind=ilp)/ldx

         ! Has a pre-allocated pivots storage array been provided? 
         if (present(pivot)) then 
            ipiv => pivot
         else
            allocate(ipiv(n))
         endif
         npiv  = size(ipiv,kind=ilp)

         ! Can A be overwritten? By default, do not overwrite
         if (present(overwrite_a)) then
            copy_a = .not.overwrite_a
         else
            copy_a = .true._lk
         endif

         if (any([lda,n,ldb]<1) .or. any([lda,ldb,ldx]/=n) .or. nrhsx/=nrhs .or. npiv/=n) then
            err0 = linalg_state_type(this,LINALG_VALUE_ERROR,'invalid sizes: a=',[lda,n], &
                                                             'b=',[ldb,nrhs],' x=',[ldx,nrhsx], &
                                                             'pivot=',n)
            call linalg_error_handling(err0,err)
            return
         end if

         ! Initialize a matrix temporary
         if (copy_a) then
            allocate(amat(lda,n),source=a)
         else
            amat => a
         endif

         ! Initialize solution with the rhs
         x = b
         xmat(1:n,1:nrhs) => x

         ! Solve system
         call gesv(n,nrhs,amat,lda,ipiv,xmat,ldb,info)

         ! Process output
         call handle_gesv_info(this,info,lda,n,nrhs,err0)

         if (copy_a) deallocate(amat)
         if (.not.present(pivot)) deallocate(ipiv)

         ! Process output and return
         call linalg_error_handling(err0,err)

     end subroutine stdlib_linalg_${ri}$_solve_lu_${ndsuf}$     
     
    #:endfor
    #:endfor

    !---------------------------------------------------------------------------
    !> solve_chol: One-shot factorize + solve (POSV)
    !---------------------------------------------------------------------------
    
    #:for nd,ndsuf,nde in ALL_RHS
    #:for rk,rt,ri in RC_KINDS_TYPES
    !> Factorize and solve A*x = b in one call (uses LAPACK POSV)
    pure module subroutine stdlib_linalg_${ri}$_solve_chol_${ndsuf}$(a,b,x,lower,overwrite_a,err)
        !> Input SPD matrix a[n,n]
        ${rt}$, intent(inout), target :: a(:,:)
        !> Right hand side vector or array, b[n] or b[n,nrhs]
        ${rt}$, intent(in) :: b${nd}$
        !> Result array/matrix x[n] or x[n,nrhs]
        ${rt}$, intent(inout), contiguous, target :: x${nd}$
        !> [optional] Use lower triangular factorization? Default = .true.
        logical(lk), optional, intent(in) :: lower
        !> [optional] Can A data be overwritten and destroyed? Default = .false.
        logical(lk), optional, intent(in) :: overwrite_a
        !> [optional] State return flag. On error if not requested, the code will stop
        type(linalg_state_type), optional, intent(out) :: err
        
        ! Local variables
        type(linalg_state_type) :: err0
        integer(ilp) :: lda,n,ldb,ldx,nrhs,nrhsx,info
        logical(lk) :: lower_,copy_a
        character :: uplo
        ${rt}$, pointer :: xmat(:,:),amat(:,:)

        ! Problem sizes
        lda   = size(a,1,kind=ilp)
        n     = size(a,2,kind=ilp)
        ldb   = size(b,1,kind=ilp)
        nrhs  = size(b,kind=ilp)/ldb
        ldx   = size(x,1,kind=ilp)
        nrhsx = size(x,kind=ilp)/ldx
        
        ! Default: use lower triangular
        lower_ = optval(lower, .true._lk)
        uplo = merge('L','U',lower_)
        
        ! Can A be overwritten? By default, do not overwrite
        copy_a = .not. optval(overwrite_a, .false._lk)
        
        ! Validate dimensions
        if (any([lda,n,ldb]<1) .or. any([lda,ldb,ldx]/=n) .or. nrhsx/=nrhs) then
            err0 = linalg_state_type(this,LINALG_VALUE_ERROR,'invalid sizes: a=',[lda,n], &
                                                             'b=',[ldb,nrhs],' x=',[ldx,nrhsx])
            call linalg_error_handling(err0,err)
            return
        end if
        
        ! Initialize a matrix temporary
        if (copy_a) then
           allocate(amat(lda,n),source=a)
        else
           amat => a
        endif
        
        ! Copy RHS to solution array (POSV overwrites with solution)
        x = b
        
        ! Create 2D pointer for LAPACK call
        xmat(1:n,1:nrhs) => x
        
        ! Factorize AND solve using LAPACK POSV
        call posv(uplo,n,nrhs,amat,lda,xmat,n,info)
        
        ! Handle errors using standard handler
        call handle_posv_info(this,info,uplo,n,nrhs,lda,n,err0)
        
        if (copy_a) deallocate(amat)
        
        ! Process output and return
        call linalg_error_handling(err0,err)

    end subroutine stdlib_linalg_${ri}$_solve_chol_${ndsuf}$
    
    #:endfor
    #:endfor

    !---------------------------------------------------------------------------
    !> Private driver: Solve using pre-computed Cholesky factor (POTRS)
    !> Not exported - used internally by solve_lower_chol and solve_upper_chol
    !---------------------------------------------------------------------------
    
    #:for nd,ndsuf,nde in ALL_RHS
    #:for rk,rt,ri in RC_KINDS_TYPES
    !> Low-level driver for solving A*x = b using pre-computed Cholesky factor
    pure subroutine solve_chol_${ri}$_${ndsuf}$_driver(a,b,x,uplo,err)
        !> Cholesky factor (L or U)[n,n] from cholesky(...)
        ${rt}$, intent(in) :: a(:,:)
        !> Right hand side vector or array, b[n] or b[n,nrhs]
        ${rt}$, intent(in) :: b${nd}$
        !> Result array/matrix x[n] or x[n,nrhs]
        ${rt}$, intent(inout), contiguous, target :: x${nd}$
        !> Triangle selector: 'L' for lower, 'U' for upper
        character, intent(in) :: uplo
        !> [optional] State return flag. On error if not requested, the code will stop
        type(linalg_state_type), optional, intent(out) :: err
        
        ! Local variables
        type(linalg_state_type) :: err0
        integer(ilp) :: lda,n,ldb,ldx,nrhs,nrhsx,info
        ${rt}$, pointer :: xmat(:,:)

        ! Problem sizes
        lda   = size(a,1,kind=ilp)
        n     = size(a,2,kind=ilp)
        ldb   = size(b,1,kind=ilp)
        nrhs  = size(b,kind=ilp)/ldb
        ldx   = size(x,1,kind=ilp)
        nrhsx = size(x,kind=ilp)/ldx
        
        ! Validate dimensions
        if (any([lda,n,ldb]<1) .or. any([lda,ldb,ldx]/=n) .or. nrhsx/=nrhs) then
            err0 = linalg_state_type(this,LINALG_VALUE_ERROR,'invalid sizes: a=',[lda,n], &
                                                             'b=',[ldb,nrhs],' x=',[ldx,nrhsx])
            call linalg_error_handling(err0,err)
            return
        end if
        
        ! Copy RHS to solution array (POTRS overwrites with solution)
        x = b
        
        ! Create 2D pointer for LAPACK call
        xmat(1:n,1:nrhs) => x
        
        ! Solve the system using LAPACK POTRS
        call potrs(uplo,n,nrhs,a,lda,xmat,n,info)
        
        ! Handle errors using standard handler
        call handle_potrs_info(this,info,uplo,n,nrhs,lda,n,err0)
        
        ! Process output and return
        call linalg_error_handling(err0,err)

    end subroutine solve_chol_${ri}$_${ndsuf}$_driver
    
    #:endfor
    #:endfor

    !---------------------------------------------------------------------------
    !> solve_lower_chol: Solve using PRE-COMPUTED LOWER Cholesky factor (POTRS)
    !---------------------------------------------------------------------------
    
    #:for nd,ndsuf,nde in ALL_RHS
    #:for rk,rt,ri in RC_KINDS_TYPES
    !> Solve the linear system A*x = b using pre-computed lower Cholesky factor
    pure module subroutine stdlib_linalg_${ri}$_solve_lower_chol_${ndsuf}$(l,b,x,err)
        !> Lower Cholesky factor l[n,n] from cholesky(...,lower=.true.)
        ${rt}$, intent(in) :: l(:,:)
        !> Right hand side vector or array, b[n] or b[n,nrhs]
        ${rt}$, intent(in) :: b${nd}$
        !> Result array/matrix x[n] or x[n,nrhs]
        ${rt}$, intent(inout), contiguous, target :: x${nd}$
        !> [optional] State return flag. On error if not requested, the code will stop
        type(linalg_state_type), optional, intent(out) :: err
        
        call solve_chol_${ri}$_${ndsuf}$_driver(l,b,x,'L',err)

    end subroutine stdlib_linalg_${ri}$_solve_lower_chol_${ndsuf}$
    
    #:endfor
    #:endfor

    !---------------------------------------------------------------------------
    !> solve_upper_chol: Solve using PRE-COMPUTED UPPER Cholesky factor (POTRS)
    !---------------------------------------------------------------------------
    
    #:for nd,ndsuf,nde in ALL_RHS
    #:for rk,rt,ri in RC_KINDS_TYPES
    !> Solve the linear system A*x = b using pre-computed upper Cholesky factor
    pure module subroutine stdlib_linalg_${ri}$_solve_upper_chol_${ndsuf}$(u,b,x,err)
        !> Upper Cholesky factor u[n,n] from cholesky(...,lower=.false.)
        ${rt}$, intent(in) :: u(:,:)
        !> Right hand side vector or array, b[n] or b[n,nrhs]
        ${rt}$, intent(in) :: b${nd}$
        !> Result array/matrix x[n] or x[n,nrhs]
        ${rt}$, intent(inout), contiguous, target :: x${nd}$
        !> [optional] State return flag. On error if not requested, the code will stop
        type(linalg_state_type), optional, intent(out) :: err
        
        call solve_chol_${ri}$_${ndsuf}$_driver(u,b,x,'U',err)

    end subroutine stdlib_linalg_${ri}$_solve_upper_chol_${ndsuf}$
    
    #:endfor
    #:endfor

end submodule stdlib_linalg_solve