stdlib_linalg_solve.fypp Source File


Source Code

#:include "common.fypp"
#:set RC_KINDS_TYPES = REAL_KINDS_TYPES + CMPLX_KINDS_TYPES
#:set RHS_SUFFIX = ["one","many"]
#:set RHS_SYMBOL = [ranksuffix(r) for r in [1,2]]
#:set RHS_EMPTY  = [emptyranksuffix(r) for r in [1,2]]
#:set ALL_RHS    = list(zip(RHS_SYMBOL,RHS_SUFFIX,RHS_EMPTY))
submodule (stdlib_linalg) stdlib_linalg_solve
!! Solve linear system Ax=b
     use stdlib_linalg_constants
     use stdlib_linalg_lapack, only: gesv
     use stdlib_linalg_state, only: linalg_state_type, linalg_error_handling, LINALG_ERROR, &
          LINALG_INTERNAL_ERROR, LINALG_VALUE_ERROR     
     implicit none(type,external)
     
     character(*), parameter :: this = 'solve'

     contains
     
     !> Translate the `info` flag returned by `gesv` into a `linalg_state_type`.
     !>
     !> Negative values are mapped by argument position in the call used in this file,
     !> `gesv(n,nrhs,a,lda,ipiv,b,ldb,info)`: -1 -> n, -2 -> nrhs, -4 -> lda, -7 -> ldb.
     !> All of them describe an invalid input value and are therefore reported with
     !> the same `LINALG_VALUE_ERROR` code.
     elemental subroutine handle_gesv_info(info,lda,n,nrhs,err)
         !> `info`: flag returned by `gesv`; `lda`,`n`,`nrhs`: problem sizes used in the messages
         integer(ilp), intent(in) :: info,lda,n,nrhs
         !> Resulting state
         type(linalg_state_type), intent(out) :: err

         ! Process output
         select case (info)
            case (0)
                ! Success: nothing is assigned, `err` keeps whatever state an
                ! `intent(out)` linalg_state_type has on entry
                ! (presumably its default-initialized "success" state; the type is defined elsewhere)
            case (-1)
                err = linalg_state_type(this,LINALG_VALUE_ERROR,'invalid problem size n=',n)
            case (-2)
                err = linalg_state_type(this,LINALG_VALUE_ERROR,'invalid rhs size n=',nrhs)
            case (-4)
                err = linalg_state_type(this,LINALG_VALUE_ERROR,'invalid matrix size a=',[lda,n])
            case (-7)
                ! Illegal leading dimension of the rhs: an input value error like the cases above
                err = linalg_state_type(this,LINALG_VALUE_ERROR,'invalid matrix size a=',[lda,n])
            case (1:)
                ! Positive info: the factorization hit an exactly zero pivot
                err = linalg_state_type(this,LINALG_ERROR,'singular matrix')
            case default
                ! Any other negative value is not expected from this call
                err = linalg_state_type(this,LINALG_INTERNAL_ERROR,'catastrophic error')
         end select
         
     end subroutine handle_gesv_info

     #:for nd,ndsuf,nde in ALL_RHS
     #:for rk,rt,ri in RC_KINDS_TYPES
     !> Solve the (real or complex) linear system A * X = B and return X as the function result.
     !> This is the interface with a mandatory state flag; the work is delegated to the
     !> `solve_lu` subroutine of the same kind and rhs rank.
     module function stdlib_linalg_${ri}$_solve_${ndsuf}$(a,b,overwrite_a,err) result(x)
         !> Coefficient matrix a[n,n]
         ${rt}$, intent(inout), target :: a(:,:)
         !> Right hand side: vector b[n] or matrix b[n,nrhs]
         ${rt}$, intent(in) :: b${nd}$
         !> [optional] Can the data in `a` be overwritten and destroyed?
         logical(lk), optional, intent(in) :: overwrite_a
         !> State return flag (not optional in this interface)
         type(linalg_state_type), intent(out) :: err
         !> Solution: x[n] or x[n,nrhs]
         ${rt}$, allocatable, target :: x${nd}$

         ! The solution has the same shape as the right hand side
         allocate(x, mold=b)

         ! An absent `overwrite_a` is forwarded as absent, so the callee applies its own default
         call stdlib_linalg_${ri}$_solve_lu_${ndsuf}$(a=a, b=b, x=x, &
                                                      overwrite_a=overwrite_a, err=err)

     end function stdlib_linalg_${ri}$_solve_${ndsuf}$

     !> Solve the (real or complex) linear system A * X = B, pure function interface.
     !> `a` is left untouched: the factorization runs on a private copy, and no state
     !> flag is passed to the underlying `solve_lu` subroutine.
     pure module function stdlib_linalg_${ri}$_pure_solve_${ndsuf}$(a,b) result(x)
         !> Coefficient matrix a[n,n]
         ${rt}$, intent(in) :: a(:,:)
         !> Right hand side: vector b[n] or matrix b[n,nrhs]
         ${rt}$, intent(in) :: b${nd}$
         !> Solution: x[n] or x[n,nrhs]
         ${rt}$, allocatable, target :: x${nd}$

         ! Scratch copy of the coefficient matrix
         ${rt}$, allocatable :: a_work(:,:)

         ! The solution has the same shape as the right hand side
         allocate(x, mold=b)

         ! `a` is intent(in) here, so hand a writable duplicate to the solver
         allocate(a_work, source=a)

         ! The duplicate is ours to destroy: let the solver work on it in place
         call stdlib_linalg_${ri}$_solve_lu_${ndsuf}$(a_work, b, x, overwrite_a=.true.)

     end function stdlib_linalg_${ri}$_pure_solve_${ndsuf}$
     
     !> Solve the (real or complex) linear system A * X = B with `gesv`, subroutine interface.
     !> The solution is written to the caller-provided array `x`; optional pivot storage
     !> and in-place factorization of `a` avoid internal allocations.
     pure module subroutine stdlib_linalg_${ri}$_solve_lu_${ndsuf}$(a,b,x,pivot,overwrite_a,err)
         !> Input matrix a[n,n]
         ${rt}$, intent(inout), target :: a(:,:)
         !> Right hand side vector or array, b[n] or b[n,nrhs]
         ${rt}$, intent(in) :: b${nd}$
         !> Result array/matrix x[n] or x[n,nrhs]
         ${rt}$, intent(inout), contiguous, target :: x${nd}$
         !> [optional] Storage array for the diagonal pivot indices, must have size n
         integer(ilp), optional, intent(inout), target :: pivot(:)
         !> [optional] Can A data be overwritten and destroyed?
         logical(lk), optional, intent(in) :: overwrite_a
         !> [optional] state return flag. On error if not requested, the code will stop
         type(linalg_state_type), optional, intent(out) :: err
         
         ! Local variables
         type(linalg_state_type) :: err0
         integer(ilp) :: lda,n,ldb,ldx,nrhsx,nrhs,info,npiv
         integer(ilp), pointer :: ipiv(:)
         logical(lk) :: copy_a
         ${rt}$, pointer :: xmat(:,:),amat(:,:)         

         ! Problem sizes. The number of columns is obtained as total size / leading
         ! dimension so the same code serves rank-1 and rank-2 arrays; the divisor is
         ! clamped to 1 so that an empty `b` or `x` cannot trigger an integer division
         ! by zero before the size validation below rejects it.
         lda   = size(a,1,kind=ilp)
         n     = size(a,2,kind=ilp)
         ldb   = size(b,1,kind=ilp)
         nrhs  = size(b  ,kind=ilp)/max(1_ilp,ldb)
         ldx   = size(x,1,kind=ilp)
         nrhsx = size(x  ,kind=ilp)/max(1_ilp,ldx)

         ! Size of the pivot storage: the user array if provided, otherwise the
         ! internal one that will be allocated with n elements after validation
         if (present(pivot)) then 
            npiv = size(pivot,kind=ilp)
         else
            npiv = n
         endif

         ! Can A be overwritten? By default, do not overwrite
         if (present(overwrite_a)) then
            copy_a = .not.overwrite_a
         else
            copy_a = .true._lk
         endif

         ! Validate before allocating anything, so the error path has nothing to release
         if (any([lda,n,ldb]<1) .or. any([lda,ldb,ldx]/=n) .or. nrhsx/=nrhs .or. npiv/=n) then
            err0 = linalg_state_type(this,LINALG_VALUE_ERROR,'invalid sizes: a=',[lda,n], &
                                                             'b=',[ldb,nrhs],' x=',[ldx,nrhsx], &
                                                             'pivot=',npiv)
            call linalg_error_handling(err0,err)
            return
         end if

         ! Has a pre-allocated pivots storage array been provided? 
         if (present(pivot)) then 
            ipiv => pivot
         else
            allocate(ipiv(n))
         endif

         ! Initialize a matrix temporary
         if (copy_a) then
            allocate(amat(lda,n),source=a)
         else
            amat => a
         endif

         ! Initialize solution with the rhs; `xmat` is a rank-2 view of `x`
         ! (x is contiguous) so vector and matrix rhs share the same solver call
         x = b
         xmat(1:n,1:nrhs) => x

         ! Solve system in place: on exit `xmat` (hence `x`) holds the solution.
         ! ldb == ldx == n has been checked above.
         call gesv(n,nrhs,amat,lda,ipiv,xmat,ldb,info)

         ! Process output
         call handle_gesv_info(info,lda,n,nrhs,err0)

         ! Release internally allocated storage only; user-provided targets are kept
         if (copy_a) deallocate(amat)
         if (.not.present(pivot)) deallocate(ipiv)

         ! Process output and return
         call linalg_error_handling(err0,err)

     end subroutine stdlib_linalg_${ri}$_solve_lu_${ndsuf}$     
     
     #:endfor
     #:endfor

end submodule stdlib_linalg_solve