stdlib_linalg_iterative_solvers_bicgstab.fypp Source File


Source Code

#:include "common.fypp"
#:set R_KINDS_TYPES = list(zip(REAL_KINDS, REAL_TYPES, REAL_SUFFIX))
#:set C_KINDS_TYPES = list(zip(CMPLX_KINDS, CMPLX_TYPES, CMPLX_SUFFIX))
#:set MATRIX_TYPES = ["dense", "CSR"]
#:set RANKS = range(1, 2+1)

submodule(stdlib_linalg_iterative_solvers) stdlib_linalg_iterative_bicgstab
    use stdlib_kinds
    use stdlib_sparse
    use stdlib_constants
    use stdlib_linalg_iterative_solvers
    use stdlib_optval, only: optval
    implicit none

contains

    #:for k, t, s in R_KINDS_TYPES
    module subroutine stdlib_solve_bicgstab_kernel_${s}$(A,M,b,x,rtol,atol,maxiter,workspace)
        ! Preconditioned BiCGSTAB iteration on abstract linear operators.
        !
        ! A         : system operator; its matvec and inner_product components are used
        ! M         : preconditioner; M%matvec(P, Pt) is applied as Pt = M^{-1}*P
        ! b         : right-hand side
        ! x         : initial guess on entry, approximate solution on exit
        ! rtol/atol : iteration stops once (r,r) < max(rtol**2*(b,b), atol**2)
        ! maxiter   : maximum number of iterations
        ! workspace : tmp(:,1:8) is used as scratch vectors; if a callback is associated
        !             it is called as callback(x, norm_sq, iter) - at iter = 0 it receives (b,b),
        !             afterwards the squared residual norm of the current iterate.
        !
        ! On breakdown (|rho| or |(Rt,V)| below rhotol, |omega| below omegatol) the loop
        ! exits early and x holds the last completed iterate.
        class(stdlib_linop_${s}$_type), intent(in) :: A
        class(stdlib_linop_${s}$_type), intent(in) :: M 
        ${t}$, intent(in) :: b(:), rtol, atol
        ${t}$, intent(inout) :: x(:)
        integer, intent(in) :: maxiter
        type(stdlib_solver_workspace_${s}$_type), intent(inout) :: workspace
        !-------------------------
        integer :: iter
        ${t}$ :: norm_sq, norm_sq0, tolsq
        ${t}$ :: rho, rho_prev, alpha, beta, omega, rv, tt
        ${t}$, parameter :: rhotol = epsilon(one_${s}$)**2
        ${t}$, parameter :: omegatol = epsilon(one_${s}$)**2
        !-------------------------
        iter = 0
        associate(  R  => workspace%tmp(:,1), &
                    Rt => workspace%tmp(:,2), &
                    P  => workspace%tmp(:,3), &
                    Pt => workspace%tmp(:,4), &
                    V  => workspace%tmp(:,5), &
                    S  => workspace%tmp(:,6), &
                    St => workspace%tmp(:,7), &
                    T  => workspace%tmp(:,8))

        norm_sq0 = A%inner_product( b, b )
        norm_sq  = norm_sq0
        if(associated(workspace%callback)) call workspace%callback(x, norm_sq, iter)
        
        if ( norm_sq0 > zero_${s}$ ) then
            
            ! Initial residual: R = b - A*x
            R = b
            call A%matvec(x, R, alpha= -one_${s}$, beta=one_${s}$, op='N')

            ! Test convergence on the true initial residual rather than on (b,b),
            ! so an initial guess that already meets the tolerance costs no iteration.
            norm_sq = A%inner_product( R, R )
            
            ! Shadow residual: the usual choice Rt = R0
            Rt = R
            
            tolsq = max(rtol*rtol * norm_sq0, atol*atol)
            
            rho_prev = one_${s}$
            alpha = one_${s}$
            omega = one_${s}$
            
            do while ( (iter < maxiter) .and. (norm_sq >= tolsq) )

                rho = A%inner_product( Rt, R )
                if (abs(rho) < rhotol) exit ! rho breakdown
                
                if (iter > 0) then
                    if (abs(omega) < omegatol) exit ! omega breakdown
                    beta = (rho / rho_prev) * (alpha / omega)
                    P = R + beta * (P - omega * V)
                else
                    P = R
                end if
                
                call M%matvec(P, Pt, alpha=one_${s}$, beta=zero_${s}$, op='N') ! Pt = M^{-1}*P
                call A%matvec(Pt, V, alpha=one_${s}$, beta=zero_${s}$, op='N') ! V = A*Pt
                
                rv = A%inner_product( Rt, V )
                ! Use the same absolute scale as the rho check: a fixed epsilon() threshold
                ! here flagged breakdown on well-posed systems with small-magnitude b.
                if (abs(rv) < rhotol) exit ! (Rt,V) breakdown
                
                alpha = rho / rv
                
                ! Intermediate residual: S = R - alpha*V
                S = R - alpha * V
                
                norm_sq = A%inner_product( S, S )
                if (norm_sq < tolsq) then
                    ! Converged at the half step: finish the iterate and report it.
                    x = x + alpha * Pt
                    iter = iter + 1
                    if(associated(workspace%callback)) call workspace%callback(x, norm_sq, iter)
                    exit
                end if
                
                call M%matvec(S, St, alpha=one_${s}$, beta=zero_${s}$, op='N') ! St = M^{-1}*S
                call A%matvec(St, T, alpha=one_${s}$, beta=zero_${s}$, op='N') ! T = A*St
                
                ! omega = (T,S)/(T,T). If T vanishes, take omega = 0 instead of producing
                ! 0/0 = NaN in x; the omega-breakdown check then ends the next iteration.
                tt = A%inner_product( T, T )
                if (tt > zero_${s}$) then
                    omega = A%inner_product( T, S ) / tt
                else
                    omega = zero_${s}$
                end if
                
                ! Update solution and residual
                x = x + alpha * Pt + omega * St
                R = S - omega * T
                
                norm_sq = A%inner_product( R, R )
                rho_prev = rho
                iter = iter + 1
                if(associated(workspace%callback)) call workspace%callback(x, norm_sq, iter)
            end do
        end if
        end associate
    end subroutine stdlib_solve_bicgstab_kernel_${s}$
    #:endfor

    #:for matrix in MATRIX_TYPES
    #:for k, t, s in R_KINDS_TYPES
    module subroutine stdlib_solve_bicgstab_${matrix}$_${s}$(A,b,x,di,rtol,atol,maxiter,restart,precond,M,workspace)
        ! Solve A*x = b with the preconditioned BiCGSTAB method.
        !
        ! Front-end for stdlib_solve_bicgstab_kernel: wraps A (dense or CSR) into a linear
        ! operator, sets up the preconditioner, the Dirichlet mask and the workspace, then
        ! runs the kernel and releases every resource it allocated itself.
        !
        ! A         : system matrix (dense array or CSR matrix)
        ! b         : right-hand side; on rows flagged in di it carries the imposed values
        ! x         : initial guess on entry (reset to zero when restart=.true.), solution on exit
        ! di        : optional Dirichlet mask; flagged rows get x = b before solving and the
        !             operator / built-in preconditioner outputs are zeroed on those rows
        ! rtol/atol : relative / absolute tolerance (defaults 1e-5 and epsilon)
        ! maxiter   : maximum number of iterations (default size(b))
        ! restart   : reset x to zero before solving (default .true.)
        ! precond   : built-in preconditioner (pc_jacobi; anything else means none);
        !             only used when M is absent
        ! M         : optional user preconditioner; its matvec should apply M^{-1}
        ! workspace : optional reusable workspace; tmp is (re)allocated when missing or sized
        !             for a different problem
        #:if matrix == "dense"
        use stdlib_linalg, only: diag
        ${t}$, intent(in) :: A(:,:)
        #:else 
        type(${matrix}$_${s}$_type), intent(in) :: A
        #:endif
        ${t}$, intent(in) :: b(:)
        ${t}$, intent(inout) :: x(:)
        ${t}$, intent(in), optional :: rtol, atol
        logical(int8), intent(in), optional, target  :: di(:)
        integer, intent(in), optional  :: maxiter
        logical, intent(in), optional  :: restart
        integer, intent(in), optional  :: precond 
        class(stdlib_linop_${s}$_type), optional , intent(in), target :: M 
        type(stdlib_solver_workspace_${s}$_type), optional, intent(inout), target :: workspace
        !-------------------------
        type(stdlib_linop_${s}$_type) :: op
        ! Pointers are not initialized in their declaration: `=> null()` there would
        ! give them the SAVE attribute. Every path below associates them before use.
        class(stdlib_linop_${s}$_type), pointer :: M_
        type(stdlib_solver_workspace_${s}$_type), pointer :: workspace_
        integer :: n, maxiter_
        ${t}$ :: rtol_, atol_
        logical :: restart_
        logical(int8), pointer :: di_(:)
        !-------------------------
        ! working data for preconditioner
        integer :: precond_
        ${t}$, allocatable :: diagonal(:)

        !-------------------------
        n = size(b)
        maxiter_ = optval(x=maxiter, default=n)
        restart_ = optval(x=restart, default=.true.)
        rtol_    = optval(x=rtol,    default=1.e-5_${s}$)
        atol_    = optval(x=atol,    default=epsilon(one_${s}$))
        precond_ = optval(x=precond, default=pc_none)
        !-------------------------
        ! internal memory setup
        ! preconditioner
        if(present(M)) then
            M_ => M
        else 
            allocate( M_ )
            allocate(diagonal(n),source=zero_${s}$)

            select case(precond_)
            case(pc_jacobi)
                #:if matrix == "dense"
                diagonal = diag(A)
                #:else 
                call diag(A,diagonal)
                #:endif
                ! Store the inverted diagonal. Entries too small to invert safely act as 1,
                ! so the preconditioner never zeroes a component of the search direction.
                where(abs(diagonal) > epsilon(zero_${s}$))
                    diagonal = one_${s}$/diagonal
                elsewhere
                    diagonal = one_${s}$
                end where
                M_%matvec => precond_jacobi
            case default
                M_%matvec => precond_none
            end select
        end if
        ! matvec for the operator
        op%matvec => matvec
        
        ! Dirichlet boundary conditions mask
        if(present(di))then
            di_ => di
        else 
            allocate(di_(n),source=.false._int8)
        end if
        
        ! workspace for the solver
        if(present(workspace)) then
            workspace_ => workspace
        else
            allocate( workspace_ )
        end if
        ! A reused workspace may have been sized for another problem: the kernel needs
        ! exactly n rows and stdlib_size_wksp_bicgstab columns.
        if(allocated(workspace_%tmp)) then
            if(size(workspace_%tmp,1) /= n .or. size(workspace_%tmp,2) < stdlib_size_wksp_bicgstab) &
                deallocate( workspace_%tmp )
        end if
        if(.not.allocated(workspace_%tmp)) allocate( workspace_%tmp(n,stdlib_size_wksp_bicgstab) , source = zero_${s}$ )
        !-------------------------
        ! main call to the solver
        if(restart_) x = zero_${s}$
        x = merge( b, x, di_ ) ! copy dirichlet load conditions encoded in B and indicated by di
        call stdlib_solve_bicgstab_kernel(op,M_,b,x,rtol_,atol_,maxiter_,workspace_)

        !-------------------------
        ! internal memory cleanup
        if(.not.present(di)) deallocate(di_)
        nullify(di_)
        
        if(.not.present(workspace)) then
            deallocate( workspace_%tmp )
            deallocate( workspace_ )
        end if
        nullify(workspace_)

        ! M_ was allocated above when no user preconditioner was given: release it.
        if(.not.present(M)) deallocate( M_ )
        nullify(M_)
        contains
        
        ! Operator action y = alpha*op(A)*x + beta*y, zeroed on Dirichlet rows.
        subroutine matvec(x,y,alpha,beta,op)
            #:if matrix == "dense"
            use stdlib_linalg_blas, only: gemv
            #:endif 
            ${t}$, intent(in)  :: x(:)
            ${t}$, intent(inout) :: y(:)
            ${t}$, intent(in) :: alpha
            ${t}$, intent(in) :: beta
            character(1), intent(in) :: op
            #:if matrix == "dense"
            call gemv(op,m=size(A,1),n=size(A,2),alpha=alpha,a=A,lda=size(A,1),x=x,incx=1,beta=beta,y=y,incy=1)
            #:else 
            call spmv( A , x, y , alpha, beta , op)
            #:endif
            y = merge( zero_${s}$, y, di_ )
        end subroutine

        ! Identity preconditioner (alpha, beta and op are ignored): y = x, zeroed on Dirichlet rows.
        subroutine precond_none(x,y,alpha,beta,op)
            ${t}$, intent(in)  :: x(:)
            ${t}$, intent(inout) :: y(:)
            ${t}$, intent(in) :: alpha
            ${t}$, intent(in) :: beta
            character(1), intent(in) :: op
            y = merge( zero_${s}$, x, di_ )
        end subroutine

        ! Jacobi preconditioner (alpha, beta and op are ignored): y = diagonal*x, where
        ! diagonal holds the inverted matrix diagonal; zeroed on Dirichlet rows.
        subroutine precond_jacobi(x,y,alpha,beta,op)
            ${t}$, intent(in)  :: x(:)
            ${t}$, intent(inout) :: y(:)
            ${t}$, intent(in) :: alpha
            ${t}$, intent(in) :: beta
            character(1), intent(in) :: op
            y = merge( zero_${s}$, diagonal * x, di_ ) ! inverted diagonal
        end subroutine
    end subroutine stdlib_solve_bicgstab_${matrix}$_${s}$

    #:endfor
    #:endfor

end submodule stdlib_linalg_iterative_bicgstab