#:include "common.fypp"
#:set R_KINDS_TYPES = list(zip(REAL_KINDS, REAL_TYPES, REAL_SUFFIX))
#:set C_KINDS_TYPES = list(zip(CMPLX_KINDS, CMPLX_TYPES, CMPLX_SUFFIX))
#:set MATRIX_TYPES = ["dense", "CSR"]
#:set RANKS = range(1, 2+1)

! Conjugate gradient (CG) solvers: a kernel working on an abstract linear
! operator, plus convenience drivers for dense and CSR matrices.
submodule(stdlib_linalg_iterative_solvers) stdlib_linalg_iterative_cg
    use stdlib_kinds
    use stdlib_sparse
    use stdlib_constants
    use stdlib_linalg_iterative_solvers
    use stdlib_optval, only: optval
    implicit none

contains

    #:for k, t, s in R_KINDS_TYPES
    ! CG iteration for A*x = b, driven through the operator's `matvec` and
    ! `inner_product` procedures.
    !
    ! Arguments:
    !   A         : linear operator providing `matvec` and `inner_product`.
    !   b         : right-hand side.
    !   x         : initial guess on entry, updated iterate on exit.
    !   rtol,atol : relative / absolute tolerances. Iterations stop once
    !               ||r||^2 <= max(rtol^2 * ||b||^2, atol^2), norms being taken
    !               with `A%inner_product`.
    !   maxiter   : upper bound on the number of iterations.
    !   workspace : columns 1..3 of `workspace%tmp` are used as the search
    !               direction P, the residual R and the product A*P. They must
    !               already be allocated by the caller. If `workspace%callback`
    !               is associated it is called as callback(x, norm_sq, iter).
    module subroutine stdlib_solve_cg_kernel_${s}$(A,b,x,rtol,atol,maxiter,workspace)
        class(stdlib_linop_${s}$_type), intent(in) :: A
        ${t}$, intent(in) :: b(:), rtol, atol
        ${t}$, intent(inout) :: x(:)
        integer, intent(in) :: maxiter
        type(stdlib_solver_workspace_${s}$_type), intent(inout) :: workspace
        !-------------------------
        integer :: iter
        ${t}$ :: norm_sq, norm_sq_old, norm_sq0
        ${t}$ :: alpha, beta, tolsq
        !-------------------------
        iter = 0
        associate(  P  => workspace%tmp(:,1), &
                    R  => workspace%tmp(:,2), &
                    Ap => workspace%tmp(:,3))

        ! Squared norm of the right-hand side: reference for the relative tolerance.
        norm_sq0 = A%inner_product(B, B)
        if(associated(workspace%callback)) call workspace%callback(x, norm_sq0, iter)

        R = B
        call A%matvec(X, R, alpha= -one_${s}$, beta=one_${s}$, op='N') ! R = B - A*X
        norm_sq = A%inner_product(R, R)

        P = R

        ! The test is done on squared quantities to avoid square roots.
        tolsq = max(rtol*rtol * norm_sq0, atol*atol)
        beta = zero_${s}$
        ! Second call at iter = 0, this time with the initial squared residual.
        if(associated(workspace%callback)) call workspace%callback(x, norm_sq, iter)

        ! Strict comparison on purpose: when the residual already equals the
        ! tolerance (in particular when both are exactly zero) the solution is
        ! accepted. Entering the loop with norm_sq == 0 would evaluate 0/0 for
        ! alpha and fill X with NaN.
        do while( norm_sq > tolsq .and. iter < maxiter)
            call A%matvec(P,Ap, alpha= one_${s}$, beta=zero_${s}$, op='N') ! Ap = A*P

            ! Step length along the current search direction.
            alpha = norm_sq / A%inner_product(P, Ap)

            X = X + alpha * P
            R = R - alpha * Ap

            norm_sq_old = norm_sq
            norm_sq = A%inner_product(R, R)
            beta = norm_sq / norm_sq_old

            ! New search direction from the updated residual.
            P = R + beta * P

            iter = iter + 1

            if(associated(workspace%callback)) call workspace%callback(x, norm_sq, iter)
        end do
        end associate
    end subroutine
    #:endfor

    #:for matrix in MATRIX_TYPES
    #:for k, t, s in R_KINDS_TYPES
    ! Driver: wraps the matrix A in a linear operator and calls the CG kernel.
    !
    ! Arguments:
    !   A         : system matrix (dense array or sparse type, per template).
    !   b         : right-hand side.
    !   x         : initial guess on entry (ignored when `restart` is true),
    !               solution estimate on exit.
    !   di        : optional mask. Where true, x is set to the matching entry
    !               of b before the solve and the operator output is zeroed.
    !               Defaults to all false.
    !   rtol      : optional relative tolerance, default 1.e-5.
    !   atol      : optional absolute tolerance, default epsilon(1).
    !   maxiter   : optional iteration cap, default size(b).
    !   restart   : optional, default true: x is reset to zero before solving.
    !   workspace : optional workspace. When absent, a temporary one is
    !               allocated here and released before returning. When present
    !               with `tmp` unallocated, `tmp` is allocated here and left
    !               allocated for the caller.
    module subroutine stdlib_solve_cg_${matrix}$_${s}$(A,b,x,di,rtol,atol,maxiter,restart,workspace)
        #:if matrix == "dense"
        ${t}$, intent(in) :: A(:,:)
        #:else
        type(${matrix}$_${s}$_type), intent(in) :: A
        #:endif
        ${t}$, intent(in) :: b(:)
        ${t}$, intent(inout) :: x(:)
        ${t}$, intent(in), optional :: rtol, atol
        logical(int8), intent(in), optional, target :: di(:)
        integer, intent(in), optional :: maxiter
        logical, intent(in), optional :: restart
        type(stdlib_solver_workspace_${s}$_type), optional, intent(inout), target :: workspace
        !-------------------------
        type(stdlib_linop_${s}$_type) :: op
        type(stdlib_solver_workspace_${s}$_type), pointer :: workspace_
        integer :: n, maxiter_
        ${t}$ :: rtol_, atol_
        logical :: restart_
        logical(int8), pointer :: di_(:)
        !-------------------------
        n = size(b)
        maxiter_ = optval(x=maxiter, default=n)
        restart_ = optval(x=restart, default=.true.)
        rtol_    = optval(x=rtol,    default=1.e-5_${s}$)
        atol_    = optval(x=atol,    default=epsilon(one_${s}$))

        !-------------------------
        ! internal memory setup
        op%matvec => matvec
        ! NOTE(review): inner_product is left untouched here; presumably the
        ! operator type supplies a default - confirm in the parent module.
        ! op%inner_product => default_dot

        ! di_ aliases the caller's mask, or a locally owned all-false mask.
        if(present(di))then
            di_ => di
        else
            allocate(di_(n),source=.false._int8)
        end if

        ! workspace_ aliases the caller's workspace, or a locally owned one.
        if(present(workspace)) then
            workspace_ => workspace
        else
            allocate( workspace_ )
        end if
        if(.not.allocated(workspace_%tmp)) allocate( workspace_%tmp(n,stdlib_size_wksp_cg), source = zero_${s}$ )

        !-------------------------
        ! main call to the solver
        if(restart_) x = zero_${s}$
        x = merge( b, x, di_ ) ! copy dirichlet load conditions encoded in B and indicated by di
        call stdlib_solve_cg_kernel(op,b,x,rtol_,atol_,maxiter_,workspace_)

        !-------------------------
        ! internal memory cleanup: only release what was allocated locally
        if(.not.present(di)) deallocate(di_)
        di_ => null()

        if(.not.present(workspace)) then
            deallocate( workspace_%tmp )
            deallocate( workspace_ )
        end if
        workspace_ => null()

    contains

        ! Operator product handed to the kernel: y = alpha*op(A)*x + beta*y,
        ! followed by zeroing the entries of y flagged in di_.
        subroutine matvec(x,y,alpha,beta,op)
            #:if matrix == "dense"
            use stdlib_linalg_blas, only: gemv
            #:endif
            ${t}$, intent(in)  :: x(:)
            ${t}$, intent(inout) :: y(:)
            ${t}$, intent(in) :: alpha
            ${t}$, intent(in) :: beta
            character(1), intent(in) :: op
            #:if matrix == "dense"
            call gemv(op,m=size(A,1),n=size(A,2),alpha=alpha,a=A,lda=size(A,1),x=x,incx=1,beta=beta,y=y,incy=1)
            #:else
            call spmv( A , x, y , alpha, beta , op)
            #:endif
            ! Masked rows contribute nothing to residuals and search directions.
            y = merge( zero_${s}$, y, di_ )
        end subroutine
    end subroutine

    #:endfor
    #:endfor

end submodule stdlib_linalg_iterative_cg