stdlib_linalg_iterative_solvers_cg.fypp Source File


Source Code

#:include "common.fypp"
#:set R_KINDS_TYPES = list(zip(REAL_KINDS, REAL_TYPES, REAL_SUFFIX))
#:set C_KINDS_TYPES = list(zip(CMPLX_KINDS, CMPLX_TYPES, CMPLX_SUFFIX))
#:set MATRIX_TYPES = ["dense", "CSR"]
#:set RANKS = range(1, 2+1)

submodule(stdlib_linalg_iterative_solvers) stdlib_linalg_iterative_cg
    use stdlib_kinds
    use stdlib_sparse
    use stdlib_constants
    use stdlib_linalg_iterative_solvers
    use stdlib_optval, only: optval
    implicit none

contains

    #:for k, t, s in R_KINDS_TYPES
    module subroutine stdlib_solve_cg_kernel_${s}$(A,b,x,rtol,atol,maxiter,workspace)
        ! Conjugate-gradient iteration on an abstract linear operator.
        !
        !   A         : operator supplying `matvec` and `inner_product`
        !               (CG presumes a symmetric positive-definite operator;
        !               this is not verified here)
        !   b         : right-hand side
        !   x         : initial guess on entry, last computed iterate on exit
        !   rtol,atol : relative / absolute tolerances; iterations stop as soon as
        !               <r,r> <= max(rtol**2 * <b,b>, atol**2)
        !   maxiter   : upper bound on the number of iterations
        !   workspace : columns 1, 2 and 3 of workspace%tmp are used as the search
        !               direction P, the residual R and the product A*P.
        !               When workspace%callback is associated it is invoked once
        !               with <b,b>, once with the initial <r,r> (both with iter=0)
        !               and then after every completed iteration.
        class(stdlib_linop_${s}$_type), intent(in) :: A
        ${t}$, intent(in) :: b(:), rtol, atol
        ${t}$, intent(inout) :: x(:)
        integer, intent(in) :: maxiter
        type(stdlib_solver_workspace_${s}$_type), intent(inout) :: workspace
        !-------------------------
        integer :: iter
        ${t}$ :: norm_sq, norm_sq_old, norm_sq0
        ${t}$ :: alpha, beta, tolsq, p_ap
        !-------------------------
        iter = 0
        associate(  P  => workspace%tmp(:,1), &
                    R  => workspace%tmp(:,2), &
                    Ap => workspace%tmp(:,3))
            
            norm_sq0 = A%inner_product(B, B)
            if(associated(workspace%callback)) call workspace%callback(x, norm_sq0, iter)

            R = B
            call A%matvec(X, R, alpha= -one_${s}$, beta=one_${s}$, op='N') ! R = B - A*X
            norm_sq = A%inner_product(R, R)
   
            P = R
            
            ! Squared tolerance, so that it can be compared against <r,r> directly
            tolsq = max(rtol*rtol * norm_sq0, atol*atol)
            beta = zero_${s}$
            if(associated(workspace%callback)) call workspace%callback(x, norm_sq, iter)
            ! Strict comparison: a residual that already meets the tolerance
            ! (in particular an exactly zero residual with a zero tolerance)
            ! must not start another iteration, which would evaluate 0/0 below.
            do while( norm_sq > tolsq .and. iter < maxiter)
                call A%matvec(P,Ap, alpha= one_${s}$, beta=zero_${s}$, op='N') ! Ap = A*P

                p_ap = A%inner_product(P, Ap)
                ! Exact-zero guard against division by zero: leave with the
                ! last finite iterate instead of filling x with Inf/NaN.
                if(p_ap == zero_${s}$) exit
                alpha = norm_sq / p_ap
                
                X = X + alpha * P
                R = R - alpha * Ap

                norm_sq_old = norm_sq
                norm_sq = A%inner_product(R, R)
                beta = norm_sq / norm_sq_old
                
                P = R + beta * P

                iter = iter + 1

                if(associated(workspace%callback)) call workspace%callback(x, norm_sq, iter)
            end do
        end associate
    end subroutine
    #:endfor

    #:for matrix in MATRIX_TYPES
    #:for k, t, s in R_KINDS_TYPES
    module subroutine stdlib_solve_cg_${matrix}$_${s}$(A,b,x,di,rtol,atol,maxiter,restart,workspace)
        ! Front end of the conjugate-gradient solver: resolves the optional
        ! arguments, wraps A in a linear operator whose matvec is the internal
        ! procedure below, and delegates to stdlib_solve_cg_kernel.
        !
        !   A         : system matrix
        !   b         : right-hand side; entries flagged by di are copied into x
        !   x         : initial guess on entry (zeroed first when restart is true),
        !               solver iterate on exit
        !   di        : optional mask of constrained entries (default: none flagged)
        !   rtol,atol : optional tolerances (defaults: 1.e-5 and epsilon(one))
        !   maxiter   : optional iteration cap (default: size(b))
        !   restart   : optional flag, zero x before solving (default: .true.)
        !   workspace : optional caller-owned scratch storage; its tmp component
        !               is allocated here when it is not allocated yet and is
        !               then left allocated for the caller
        #:if matrix == "dense"
        ${t}$, intent(in) :: A(:,:)
        #:else 
        type(${matrix}$_${s}$_type), intent(in) :: A
        #:endif
        ${t}$, intent(in) :: b(:)
        ${t}$, intent(inout) :: x(:)
        ${t}$, intent(in), optional :: rtol, atol
        logical(int8), intent(in), optional, target  :: di(:)
        integer, intent(in), optional :: maxiter
        logical, intent(in), optional  :: restart
        type(stdlib_solver_workspace_${s}$_type), optional, intent(inout), target :: workspace
        !-------------------------
        type(stdlib_linop_${s}$_type) :: linop
        type(stdlib_solver_workspace_${s}$_type), pointer :: wksp
        logical(int8), pointer :: fixed(:)
        integer :: ndof, itmax
        ${t}$ :: tol_rel, tol_abs
        logical :: from_zero
        !-------------------------
        ! resolve optional arguments
        ndof      = size(b)
        tol_rel   = optval(rtol,    1.e-5_${s}$)
        tol_abs   = optval(atol,    epsilon(one_${s}$))
        itmax     = optval(maxiter, ndof)
        from_zero = optval(restart, .true.)

        !-------------------------
        ! operator: only matvec is set here; inner_product is left untouched
        ! (presumably the type provides a default - confirm in the parent module)
        linop%matvec => matvec

        ! scratch storage: borrow the caller's workspace or create a private one
        if (present(workspace)) then
            wksp => workspace
        else
            allocate(wksp)
        end if
        if (.not. allocated(wksp%tmp)) then
            allocate(wksp%tmp(ndof, stdlib_size_wksp_cg), source=zero_${s}$)
        end if

        ! constraint mask: borrow the caller's mask or create an all-false one
        if (present(di)) then
            fixed => di
        else
            allocate(fixed(ndof), source=.false._int8)
        end if

        !-------------------------
        ! solve
        if (from_zero) x = zero_${s}$
        where (fixed) x = b   ! constrained entries take their value from b
        call stdlib_solve_cg_kernel(linop, b, x, tol_rel, tol_abs, itmax, wksp)

        !-------------------------
        ! release whatever was created locally; borrowed targets are only detached
        if (present(di)) then
            nullify(fixed)
        else
            deallocate(fixed)
        end if

        if (present(workspace)) then
            nullify(wksp)
        else
            deallocate(wksp%tmp)
            deallocate(wksp)
        end if
        contains
        
        subroutine matvec(x,y,alpha,beta,op)
            ! Apply the host matrix through gemv (dense) or spmv (sparse) with
            ! the given alpha, beta and op, then zero the entries of y that are
            ! flagged by the constraint mask.
            #:if matrix == "dense"
            use stdlib_linalg_blas, only: gemv
            #:endif 
            ${t}$, intent(in)  :: x(:)
            ${t}$, intent(inout) :: y(:)
            ${t}$, intent(in) :: alpha
            ${t}$, intent(in) :: beta
            character(1), intent(in) :: op
            #:if matrix == "dense"
            call gemv(op, m=size(A,1), n=size(A,2), alpha=alpha, a=A, lda=size(A,1), &
                      x=x, incx=1, beta=beta, y=y, incy=1)
            #:else 
            call spmv(A, x, y, alpha, beta, op)
            #:endif
            where (fixed) y = zero_${s}$
        end subroutine
    end subroutine

    #:endfor
    #:endfor

end submodule stdlib_linalg_iterative_cg