#:include "common.fypp"
#:set RANKS = range(1, 2+1)
#:set R_KINDS_TYPES = list(zip(REAL_KINDS, REAL_TYPES, REAL_SUFFIX))
#:set C_KINDS_TYPES = list(zip(CMPLX_KINDS, CMPLX_TYPES, CMPLX_SUFFIX))
#:set KINDS_TYPES = R_KINDS_TYPES+C_KINDS_TYPES
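#! rksfx2(rank): emit one ":," per leading dimension (rank of them, nothing when rank is 0), without parentheses, for use as an index prefix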
#:def rksfx2(rank)
#{if rank > 0}#${":," + ":," * (rank - 1)}$#{endif}#
#:enddef
submodule (stdlib_sparse_spmv) stdlib_sparse_spmv_csr
contains

    !! spmv_csr
    #:for k1, t1, s1 in (KINDS_TYPES)
    #:for rank in RANKS
    module subroutine spmv_csr_${rank}$d_${s1}$(matrix,vec_x,vec_y,alpha,beta,op)
        !! Sparse matrix-vector product for a CSR matrix:
        !!     vec_y = alpha * op(matrix) * vec_x + beta * vec_y
        !!
        !! - matrix : CSR matrix; its data, col, rowptr, nrows and storage components are read.
        !! - vec_x  : input vector(s); the last dimension is the one indexed by the matrix.
        !! - vec_y  : result vector(s), updated in place; the last dimension is the one indexed by the matrix.
        !! - alpha  : optional scaling of the product, taken as one when absent.
        !! - beta   : optional scaling of the incoming vec_y; when absent vec_y is zeroed first.
        !! - op     : optional operation selector, sparse_op_none when absent.
        !!
        !! For sparse_lower / sparse_upper storage only one triangle is stored and the
        !! other one is applied by symmetry. Those branches treat the last (lower) or
        !! first (upper) stored entry of every non-empty row as the diagonal entry.
        !! With sparse_full storage, an op value matching none of the handled selectors
        !! leaves vec_y with only the beta scaling applied.
        type(CSR_${s1}$_type), intent(in) :: matrix
        ${t1}$, intent(in)    :: vec_x${ranksuffix(rank)}$
        ${t1}$, intent(inout) :: vec_y${ranksuffix(rank)}$
        ${t1}$, intent(in), optional :: alpha
        ${t1}$, intent(in), optional :: beta
        character(1), intent(in), optional :: op
        ${t1}$ :: alpha_
        character(1) :: op_
        integer(ilp) :: i, j
        #:if rank == 1
        ${t1}$ :: aux, aux2
        #:else
        ${t1}$ :: aux(size(vec_x,dim=1)), aux2(size(vec_x,dim=1))
        #:endif
        
        op_ = sparse_op_none
        if(present(op)) op_ = op
        #:if not t1.startswith('complex')
        ! Real kinds: the conjugate transpose coincides with the plain transpose.
        ! Without this mapping no branch below would match and the product would be dropped.
        if( op_ == sparse_op_hermitian ) op_ = sparse_op_transpose
        #:endif
        alpha_ = one_${k1}$
        if(present(alpha)) alpha_ = alpha
        if(present(beta)) then
            vec_y = beta * vec_y
        else 
            vec_y = zero_${s1}$
        endif

        associate( data => matrix%data, col => matrix%col, rowptr => matrix%rowptr, &
            & nrows => matrix%nrows, storage => matrix%storage )
    
            if( storage == sparse_full .and. op_==sparse_op_none ) then
                ! y(i) += alpha * sum_j A(i,j) * x(j): row-wise gather.
                do i = 1, nrows
                    aux = zero_${s1}$
                    do j = rowptr(i), rowptr(i+1)-1
                        aux = aux + data(j) * vec_x(${rksfx2(rank-1)}$col(j))
                    end do
                    vec_y(${rksfx2(rank-1)}$i) = vec_y(${rksfx2(rank-1)}$i) + alpha_ * aux
                end do

            else if( storage == sparse_full .and. op_==sparse_op_transpose ) then
                ! y(col) += A(i,col) * alpha * x(i): row-wise scatter.
                do i = 1, nrows
                    aux = alpha_ * vec_x(${rksfx2(rank-1)}$i)
                    do j = rowptr(i), rowptr(i+1)-1
                        vec_y(${rksfx2(rank-1)}$col(j)) = vec_y(${rksfx2(rank-1)}$col(j)) + data(j) * aux
                    end do
                end do
                
            else if( storage == sparse_lower .and. op_/=sparse_op_hermitian )then
                do i = 1 , nrows
                    ! A row without stored entries contributes nothing; skipping it keeps the
                    ! diagonal lookup below from reading an entry that belongs to another row.
                    if( rowptr(i+1) <= rowptr(i) ) cycle
                    aux  = zero_${s1}$
                    aux2 = alpha_ * vec_x(${rksfx2(rank-1)}$i)
                    ! Off-diagonal entries: gather for row i and scatter the mirrored entry.
                    do j = rowptr(i), rowptr(i+1)-2
                        aux = aux + data(j) * vec_x(${rksfx2(rank-1)}$col(j))
                        vec_y(${rksfx2(rank-1)}$col(j)) = vec_y(${rksfx2(rank-1)}$col(j)) + data(j) * aux2
                    end do
                    ! Last stored entry of the row is used as the diagonal, applied once.
                    aux = alpha_ * aux + data(rowptr(i+1)-1) * aux2
                    vec_y(${rksfx2(rank-1)}$i) = vec_y(${rksfx2(rank-1)}$i) + aux
                end do

            else if( storage == sparse_upper .and. op_/=sparse_op_hermitian )then
                do i = 1 , nrows
                    ! Skip rows without stored entries (see the lower-triangular branch).
                    if( rowptr(i+1) <= rowptr(i) ) cycle
                    ! First stored entry of the row is used as the diagonal.
                    aux  = vec_x(${rksfx2(rank-1)}$i) * data(rowptr(i))
                    aux2 = alpha_ * vec_x(${rksfx2(rank-1)}$i)
                    do j = rowptr(i)+1, rowptr(i+1)-1
                        aux = aux + data(j) * vec_x(${rksfx2(rank-1)}$col(j))
                        vec_y(${rksfx2(rank-1)}$col(j)) = vec_y(${rksfx2(rank-1)}$col(j)) + data(j) * aux2
                    end do
                    vec_y(${rksfx2(rank-1)}$i) = vec_y(${rksfx2(rank-1)}$i) + alpha_ * aux
                end do
                
            #:if t1.startswith('complex')
            else if( storage == sparse_full .and. op_==sparse_op_hermitian) then
                ! Same scatter as the transpose branch, with conjugated matrix entries.
                do i = 1, nrows
                    aux = alpha_ * vec_x(${rksfx2(rank-1)}$i)
                    do j = rowptr(i), rowptr(i+1)-1
                        vec_y(${rksfx2(rank-1)}$col(j)) = vec_y(${rksfx2(rank-1)}$col(j)) + conjg(data(j)) * aux
                    end do
                end do

            else if( storage == sparse_lower .and. op_==sparse_op_hermitian )then
                do i = 1 , nrows
                    ! Skip rows without stored entries (see the lower-triangular branch).
                    if( rowptr(i+1) <= rowptr(i) ) cycle
                    aux  = zero_${s1}$
                    aux2 = alpha_ * vec_x(${rksfx2(rank-1)}$i)
                    do j = rowptr(i), rowptr(i+1)-2
                        aux = aux + conjg(data(j)) * vec_x(${rksfx2(rank-1)}$col(j))
                        vec_y(${rksfx2(rank-1)}$col(j)) = vec_y(${rksfx2(rank-1)}$col(j)) + conjg(data(j)) * aux2
                    end do
                    ! Last stored entry of the row is used as the diagonal, applied once.
                    aux = alpha_ * aux + conjg(data(rowptr(i+1)-1)) * aux2
                    vec_y(${rksfx2(rank-1)}$i) = vec_y(${rksfx2(rank-1)}$i) + aux
                end do

            else if( storage == sparse_upper .and. op_==sparse_op_hermitian )then
                do i = 1 , nrows
                    ! Skip rows without stored entries (see the lower-triangular branch).
                    if( rowptr(i+1) <= rowptr(i) ) cycle
                    ! First stored entry of the row is used as the diagonal.
                    aux  = vec_x(${rksfx2(rank-1)}$i) * conjg(data(rowptr(i)))
                    aux2 = alpha_ * vec_x(${rksfx2(rank-1)}$i)
                    do j = rowptr(i)+1, rowptr(i+1)-1
                        aux = aux + conjg(data(j)) * vec_x(${rksfx2(rank-1)}$col(j))
                        vec_y(${rksfx2(rank-1)}$col(j)) = vec_y(${rksfx2(rank-1)}$col(j)) + conjg(data(j)) * aux2
                    end do
                    vec_y(${rksfx2(rank-1)}$i) = vec_y(${rksfx2(rank-1)}$i) + alpha_ * aux
                end do
            #:endif
            end if
        end associate
    end subroutine
    
    #:endfor
    #:endfor

end submodule stdlib_sparse_spmv_csr