Skip to content

Commit 89a993e

Browse files
committed
add in place transpose spmv for SELLC
1 parent 3dfcecd commit 89a993e

File tree

2 files changed

+79
-18
lines changed

2 files changed

+79
-18
lines changed

src/stdlib_sparse_spmv.fypp

Lines changed: 67 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -428,11 +428,18 @@ contains
428428
else
429429
vec_y = zero_${s1}$
430430
endif
431+
431432
associate( data => matrix%data, ia => matrix%rowptr , ja => matrix%col, cs => matrix%chunk_size, &
432433
& nnz => matrix%nnz, nrows => matrix%nrows, ncols => matrix%ncols, storage => matrix%storage )
434+
435+
if( .not.any( ${CHUNKS}$ == cs ) ) then
436+
print *, "error: sellc chunk size not supported."
437+
return
438+
end if
439+
433440
num_chunks = nrows / cs
434441
rm = nrows - num_chunks * cs
435-
if( storage == sparse_full) then
442+
if( storage == sparse_full .and. op_==sparse_op_none ) then
436443

437444
select case(cs)
438445
#:for chunk in CHUNKS
@@ -443,9 +450,6 @@ contains
443450
call chunk_kernel_${chunk}$(nz,data(:,ia(i)),ja(:,ia(i)),vec_x,vec_y(rowidx:))
444451
end do
445452
#:endfor
446-
case default
447-
print *, "error: chunk size not supported."
448-
return
449453
end select
450454

451455
! remainder
@@ -455,32 +459,79 @@ contains
455459
rowidx = (i - 1)*cs + 1
456460
call chunk_kernel_remainder(nz,cs,rm,data(:,ia(i)),ja(:,ia(i)),vec_x,vec_y(rowidx:))
457461
end if
462+
463+
else if( storage == sparse_full .and. op_==sparse_op_transpose ) then
458464

465+
select case(cs)
466+
#:for chunk in CHUNKS
467+
case(${chunk}$)
468+
do i = 1, num_chunks
469+
nz = ia(i+1) - ia(i)
470+
rowidx = (i - 1)*${chunk}$ + 1
471+
call chunk_kernel_trans_${chunk}$(nz,data(:,ia(i)),ja(:,ia(i)),vec_x(rowidx:),vec_y)
472+
end do
473+
#:endfor
474+
end select
475+
476+
! remainder
477+
if(rm>0)then
478+
i = num_chunks + 1
479+
nz = ia(i+1) - ia(i)
480+
rowidx = (i - 1)*cs + 1
481+
call chunk_kernel_remainder_trans(nz,cs,rm,data(:,ia(i)),ja(:,ia(i)),vec_x(rowidx:),vec_y)
482+
end if
483+
else
484+
print *, "error: sellc format for spmv operation not yet supported."
485+
return
459486
end if
460487
end associate
461488

462489
contains
463490
#:for chunk in CHUNKS
464-
pure subroutine chunk_kernel_${chunk}$(nz,a,ja,x,y)
465-
integer, value :: nz
466-
${t1}$, intent(in) :: a(${chunk}$,nz), x(*)
467-
integer(ilp), intent(in) :: ja(${chunk}$,nz)
491+
pure subroutine chunk_kernel_${chunk}$(n,a,col,x,y)
492+
integer, value :: n
493+
${t1}$, intent(in) :: a(${chunk}$,n), x(*)
494+
integer(ilp), intent(in) :: col(${chunk}$,n)
468495
${t1}$, intent(inout) :: y(${chunk}$)
469496
integer :: j
470-
do j = 1, nz
471-
where(ja(:,j) > 0) y = y + alpha_ * a(:,j) * x(ja(:,j))
497+
do j = 1, n
498+
where(col(:,j) > 0) y = y + alpha_ * a(:,j) * x(col(:,j))
499+
end do
500+
end subroutine
501+
pure subroutine chunk_kernel_trans_${chunk}$(n,a,col,x,y)
502+
integer, value :: n
503+
${t1}$, intent(in) :: a(${chunk}$,n), x(${chunk}$)
504+
integer(ilp), intent(in) :: col(${chunk}$,n)
505+
${t1}$, intent(inout) :: y(*)
506+
integer :: j, k
507+
do j = 1, n
508+
do k = 1, ${chunk}$
509+
if(col(k,j) > 0) y(col(k,j)) = y(col(k,j)) + alpha_ * a(k,j) * x(k)
510+
end do
472511
end do
473512
end subroutine
474513
#:endfor
475514

476-
pure subroutine chunk_kernel_remainder(nz,cs,rm,a,ja,x,y)
477-
integer, value :: nz, cs, rm
478-
${t1}$, intent(in) :: a(cs,nz), x(*)
479-
integer(ilp), intent(in) :: ja(cs,nz)
515+
pure subroutine chunk_kernel_remainder(n,cs,rm,a,col,x,y)
516+
integer, value :: n, cs, rm
517+
${t1}$, intent(in) :: a(cs,n), x(*)
518+
integer(ilp), intent(in) :: col(cs,n)
480519
${t1}$, intent(inout) :: y(rm)
481520
integer :: j
482-
do j = 1, nz
483-
where(ja(1:rm,j) > 0) y = y + alpha_ * a(1:rm,j) * x(ja(1:rm,j))
521+
do j = 1, n
522+
where(col(1:rm,j) > 0) y = y + alpha_ * a(1:rm,j) * x(col(1:rm,j))
523+
end do
524+
end subroutine
525+
pure subroutine chunk_kernel_remainder_trans(n,cs,rm,a,col,x,y)
526+
integer, value :: n, cs, rm
527+
${t1}$, intent(in) :: a(cs,n), x(rm)
528+
integer(ilp), intent(in) :: col(cs,n)
529+
${t1}$, intent(inout) :: y(*)
530+
integer :: j, k
531+
do j = 1, n
532+
do k = 1, rm
533+
if(col(k,j) > 0) y(col(k,j)) = y(col(k,j)) + alpha_ * a(k,j) * x(k)
534+
end do
484535
end do
485536
end subroutine
486537

test/linalg/test_linalg_sparse.fypp

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -208,7 +208,7 @@ contains
208208
type(SELLC_${s1}$_type) :: SELLC
209209
type(CSR_${s1}$_type) :: CSR
210210
${t1}$, allocatable :: vec_x(:)
211-
${t1}$, allocatable :: vec_y(:)
211+
${t1}$, allocatable :: vec_y(:), vec_y2(:)
212212
integer :: i
213213

214214
call CSR%malloc(6,6,17)
@@ -226,11 +226,21 @@ contains
226226

227227
allocate( vec_x(6) , source = 1._wp )
228228
allocate( vec_y(6) , source = 0._wp )
229-
229+
230230
call spmv( SELLC, vec_x, vec_y )
231231

232232
call check(error, all(vec_y == real([6,22,27,23,27,48],kind=wp)) )
233233
if (allocated(error)) return
234+
235+
! Test in-place transpose
236+
vec_x = real( [1,2,3,4,5,6] , kind=wp )
237+
call spmv( CSR, vec_x, vec_y , op = sparse_op_transpose )
238+
allocate( vec_y2(6) , source = 0._wp )
239+
call spmv( SELLC, vec_x, vec_y2 , op = sparse_op_transpose )
240+
241+
call check(error, all(vec_y == vec_y2))
242+
if (allocated(error)) return
243+
234244
end block
235245
#:endfor
236246
end subroutine

0 commit comments

Comments
 (0)