Skip to content

Commit ae02481

Browse files
committed
simplify sellc spmv kernel
1 parent 89a993e commit ae02481

File tree

1 file changed

+29
-54
lines changed

1 file changed

+29
-54
lines changed

src/stdlib_sparse_spmv.fypp

Lines changed: 29 additions & 54 deletions
Original file line number | Diff line number | Diff line change
@@ -418,7 +418,7 @@ contains
418418
character(1), intent(in), optional :: op
419419
${t1}$ :: alpha_
420420
character(1) :: op_
421-
integer(ilp) :: i, nz, rowidx, num_chunks, rm
421+
integer(ilp) :: i, j, k, nz, rowidx, num_chunks, rm
422422

423423
op_ = sparse_op_none; if(present(op)) op_ = op
424424
alpha_ = one_${s1}$
@@ -447,7 +447,12 @@ contains
447447
do i = 1, num_chunks
448448
nz = ia(i+1) - ia(i)
449449
rowidx = (i - 1)*${chunk}$ + 1
450-
call chunk_kernel_${chunk}$(nz,data(:,ia(i)),ja(:,ia(i)),vec_x,vec_y(rowidx:))
450+
associate(col => ja(1:${chunk}$,ia(i):ia(i)+nz-1), mat => data(1:${chunk}$,ia(i):ia(i)+nz-1), &
451+
& x => vec_x, y => vec_y(rowidx:rowidx+${chunk}$-1) )
452+
do j = 1, nz
453+
where(col(:,j) > 0) y = y + alpha_ * mat(:,j) * x(col(:,j))
454+
end do
455+
end associate
451456
end do
452457
#:endfor
453458
end select
@@ -457,7 +462,12 @@ contains
457462
i = num_chunks + 1
458463
nz = ia(i+1) - ia(i)
459464
rowidx = (i - 1)*cs + 1
460-
call chunk_kernel_remainder(nz,cs,rm,data(:,ia(i)),ja(:,ia(i)),vec_x,vec_y(rowidx:))
465+
associate(col => ja(1:cs,ia(i):ia(i)+nz-1), mat => data(1:cs,ia(i):ia(i)+nz-1), &
466+
& x => vec_x, y => vec_y(rowidx:rowidx+rm-1) )
467+
do j = 1, nz
468+
where(col(1:rm,j) > 0) y = y + alpha_ * mat(1:rm,j) * x(col(1:rm,j))
469+
end do
470+
end associate
461471
end if
462472

463473
else if( storage == sparse_full .and. op_==sparse_op_transpose ) then
@@ -468,7 +478,14 @@ contains
468478
do i = 1, num_chunks
469479
nz = ia(i+1) - ia(i)
470480
rowidx = (i - 1)*${chunk}$ + 1
471-
call chunk_kernel_trans_${chunk}$(nz,data(:,ia(i)),ja(:,ia(i)),vec_x(rowidx:),vec_y)
481+
associate(col => ja(1:${chunk}$,ia(i):ia(i)+nz-1), mat => data(1:${chunk}$,ia(i):ia(i)+nz-1), &
482+
& x => vec_x(rowidx:rowidx+${chunk}$-1), y => vec_y )
483+
do j = 1, nz
484+
do k = 1, ${chunk}$
485+
if(col(k,j) > 0) y(col(k,j)) = y(col(k,j)) + alpha_ * mat(k,j) * x(k)
486+
end do
487+
end do
488+
end associate
472489
end do
473490
#:endfor
474491
end select
@@ -478,63 +495,21 @@ contains
478495
i = num_chunks + 1
479496
nz = ia(i+1) - ia(i)
480497
rowidx = (i - 1)*cs + 1
481-
call chunk_kernel_remainder_trans(nz,cs,rm,data(:,ia(i)),ja(:,ia(i)),vec_x(rowidx:),vec_y)
498+
associate(col => ja(1:cs,ia(i):ia(i)+nz-1), mat => data(1:cs,ia(i):ia(i)+nz-1), &
499+
& x => vec_x(rowidx:rowidx+rm-1), y => vec_y )
500+
do j = 1, nz
501+
do k = 1, rm
502+
if(col(k,j) > 0) y(col(k,j)) = y(col(k,j)) + alpha_ * mat(k,j) * x(k)
503+
end do
504+
end do
505+
end associate
482506
end if
483507
else
484508
print *, "error: sellc format for spmv operation not yet supported."
485509
return
486510
end if
487511
end associate
488512

489-
contains
490-
#:for chunk in CHUNKS
491-
pure subroutine chunk_kernel_${chunk}$(n,a,col,x,y)
492-
integer, value :: n
493-
${t1}$, intent(in) :: a(${chunk}$,n), x(*)
494-
integer(ilp), intent(in) :: col(${chunk}$,n)
495-
${t1}$, intent(inout) :: y(${chunk}$)
496-
integer :: j
497-
do j = 1, n
498-
where(col(:,j) > 0) y = y + alpha_ * a(:,j) * x(col(:,j))
499-
end do
500-
end subroutine
501-
pure subroutine chunk_kernel_trans_${chunk}$(n,a,col,x,y)
502-
integer, value :: n
503-
${t1}$, intent(in) :: a(${chunk}$,n), x(${chunk}$)
504-
integer(ilp), intent(in) :: col(${chunk}$,n)
505-
${t1}$, intent(inout) :: y(*)
506-
integer :: j, k
507-
do j = 1, n
508-
do k = 1, ${chunk}$
509-
if(col(k,j) > 0) y(col(k,j)) = y(col(k,j)) + alpha_ * a(k,j) * x(k)
510-
end do
511-
end do
512-
end subroutine
513-
#:endfor
514-
515-
pure subroutine chunk_kernel_remainder(n,cs,rm,a,col,x,y)
516-
integer, value :: n, cs, rm
517-
${t1}$, intent(in) :: a(cs,n), x(*)
518-
integer(ilp), intent(in) :: col(cs,n)
519-
${t1}$, intent(inout) :: y(rm)
520-
integer :: j
521-
do j = 1, n
522-
where(col(1:rm,j) > 0) y = y + alpha_ * a(1:rm,j) * x(col(1:rm,j))
523-
end do
524-
end subroutine
525-
pure subroutine chunk_kernel_remainder_trans(n,cs,rm,a,col,x,y)
526-
integer, value :: n, cs, rm
527-
${t1}$, intent(in) :: a(cs,n), x(rm)
528-
integer(ilp), intent(in) :: col(cs,n)
529-
${t1}$, intent(inout) :: y(*)
530-
integer :: j, k
531-
do j = 1, n
532-
do k = 1, rm
533-
if(col(k,j) > 0) y(col(k,j)) = y(col(k,j)) + alpha_ * a(k,j) * x(k)
534-
end do
535-
end do
536-
end subroutine
537-
538513
end subroutine
539514

540515
#:endfor

0 commit comments

Comments
 (0)