diff --git a/doc/specs/stdlib_linalg.md b/doc/specs/stdlib_linalg.md index 29eb9878d..671cfee2f 100644 --- a/doc/specs/stdlib_linalg.md +++ b/doc/specs/stdlib_linalg.md @@ -160,6 +160,37 @@ Returns a rank-2 array equal to `u v^T` (where `u, v` are considered column vect {!example/linalg/example_outer_product.f90!} ``` +## `kronecker_product` - Computes the Kronecker product of two rank-2 arrays + +### Status + +Experimental + +### Description + +Computes the Kronecker product of two rank-2 arrays + +### Syntax + +`C = [[stdlib_linalg(module):kronecker_product(interface)]](A, B)` + +### Arguments + +`A`: Shall be a rank-2 array with dimensions M1, N1 + +`B`: Shall be a rank-2 array with dimensions M2, N2 + +### Return value + +Returns a rank-2 array equal to `A \otimes B`. The shape of the returned array is `[M1*M2, N1*N2]`. + +### Example + +```fortran +{!example/linalg/example_kronecker_product.f90!} +``` + + ## `cross_product` - Computes the cross product of two vectors ### Status diff --git a/example/linalg/example_kronecker_product.f90 b/example/linalg/example_kronecker_product.f90 new file mode 100644 index 000000000..98bab0eca --- /dev/null +++ b/example/linalg/example_kronecker_product.f90 @@ -0,0 +1,26 @@ +program example_kronecker_product + use stdlib_linalg, only: kronecker_product + implicit none + integer, parameter :: m1 = 1, n1 = 2, m2 = 2, n2 = 3 + integer :: i, j + real :: A(m1, n1), B(m2,n2) + real, allocatable :: C(:,:) + + do j = 1, n1 + do i = 1, m1 + A(i,j) = i*j ! A = [1, 2] + end do + end do + + do j = 1, n2 + do i = 1, m2 ! B = [ 1, 2, 3 ] + B(i,j) = i*j ! [ 2, 4, 6 ] + end do + end do + + C = kronecker_product(A, B) + ! C = [ a(1,1) * B(:,:) | a(1,2) * B(:,:) ] + ! or in other words, + ! C = [ 1.00 2.00 3.00 2.00 4.00 6.00 ] + ! [ 2.00 4.00 6.00 4.00 8.00 12.00 ] +end program example_kronecker_product diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index 8f512af56..ceb1bd2b9 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -22,6 +22,7 @@ set(fppFiles stdlib_linalg.fypp stdlib_linalg_diag.fypp stdlib_linalg_outer_product.fypp + stdlib_linalg_kronecker.fypp stdlib_linalg_cross_product.fypp stdlib_optval.fypp stdlib_selection.fypp diff --git a/src/stdlib_linalg.fypp b/src/stdlib_linalg.fypp index cfa43d3d9..3ed905c56 100644 --- a/src/stdlib_linalg.fypp +++ b/src/stdlib_linalg.fypp @@ -14,6 +14,7 @@ module stdlib_linalg public :: eye public :: trace public :: outer_product + public :: kronecker_product public :: cross_product public :: is_square public :: is_diagonal @@ -93,6 +94,20 @@ module stdlib_linalg #:endfor end interface outer_product + interface kronecker_product + !! version: experimental + !! + !! Computes the Kronecker product of two arrays of size M1xN1, and of M2xN2, returning an (M1*M2)x(N1*N2) array + !! ([Specification](../page/specs/stdlib_linalg.html# + !! kronecker_product-computes-the-kronecker-product-of-two-matrices)) + #:for k1, t1 in RCI_KINDS_TYPES + pure module function kronecker_product_${t1[0]}$${k1}$(A, B) result(C) + ${t1}$, intent(in) :: A(:,:), B(:,:) + ${t1}$ :: C(size(A,dim=1)*size(B,dim=1),size(A,dim=2)*size(B,dim=2)) + end function kronecker_product_${t1[0]}$${k1}$ + #:endfor + end interface kronecker_product + ! Cross product (of two vectors) interface cross_product diff --git a/src/stdlib_linalg_kronecker.fypp b/src/stdlib_linalg_kronecker.fypp new file mode 100644 index 000000000..38895f73e --- /dev/null +++ b/src/stdlib_linalg_kronecker.fypp @@ -0,0 +1,30 @@ +#:include "common.fypp" +#:set RCI_KINDS_TYPES = REAL_KINDS_TYPES + CMPLX_KINDS_TYPES + INT_KINDS_TYPES +submodule (stdlib_linalg) stdlib_linalg_kronecker + + implicit none + +contains + + #:for k1, t1 in RCI_KINDS_TYPES + pure module function kronecker_product_${t1[0]}$${k1}$(A, B) result(C) + ${t1}$, intent(in) :: A(:,:), B(:,:) + ${t1}$ :: C(size(A,dim=1)*size(B,dim=1),size(A,dim=2)*size(B,dim=2)) + integer :: m1, n1, maxM1, maxN1, maxM2, maxN2 + + maxM1 = size(A, dim=1) + maxN1 = size(A, dim=2) + maxM2 = size(B, dim=1) + maxN2 = size(B, dim=2) + + + do n1 = 1, maxN1 + do m1 = 1, maxM1 + ! We use the Wikipedia convention for ordering of the matrix elements + ! https://en.wikipedia.org/wiki/Kronecker_product + C((m1-1)*maxM2+1:m1*maxM2, (n1-1)*maxN2+1:n1*maxN2) = A(m1, n1) * B(:,:) + end do + end do + end function kronecker_product_${t1[0]}$${k1}$ + #:endfor +end submodule stdlib_linalg_kronecker diff --git a/test/linalg/test_linalg.fypp b/test/linalg/test_linalg.fypp index 2ffd2d7de..6fdf7f17d 100644 --- a/test/linalg/test_linalg.fypp +++ b/test/linalg/test_linalg.fypp @@ -1,9 +1,10 @@ #:include "common.fypp" +#:set RCI_KINDS_TYPES = REAL_KINDS_TYPES + CMPLX_KINDS_TYPES + INT_KINDS_TYPES module test_linalg use testdrive, only : new_unittest, unittest_type, error_type, check, skip_test use stdlib_kinds, only: sp, dp, xdp, qp, int8, int16, int32, int64 - use stdlib_linalg, only: diag, eye, trace, outer_product, cross_product + use stdlib_linalg, only: diag, eye, trace, outer_product, cross_product, kronecker_product implicit none @@ -48,6 +49,9 @@ contains new_unittest("trace_int16", test_trace_int16), & new_unittest("trace_int32", test_trace_int32), & new_unittest("trace_int64", test_trace_int64), & + #:for k1, t1 in RCI_KINDS_TYPES + new_unittest("kronecker_product_${t1[0]}$${k1}$", test_kronecker_product_${t1[0]}$${k1}$), & + #:endfor new_unittest("outer_product_rsp", test_outer_product_rsp), & new_unittest("outer_product_rdp", test_outer_product_rdp), & new_unittest("outer_product_rqp", test_outer_product_rqp), & @@ -554,6 +558,43 @@ contains end subroutine test_trace_int64 + + #:for k1, t1 in RCI_KINDS_TYPES + subroutine test_kronecker_product_${t1[0]}$${k1}$(error) + !> Error handling + type(error_type), allocatable, intent(out) :: error + integer, parameter :: m1 = 1, n1 = 2, m2 = 2, n2 = 3 + ${t1}$, dimension(m1*m2,n1*n2), parameter :: expected & + = transpose(reshape([1,2,3, 2,4,6, 2,4,6, 4,8,12], [m2*n2, m1*n1])) + ${t1}$, parameter :: tol = 1.e-6 + + ${t1}$ :: A(m1,n1), B(m2,n2) + ${t1}$ :: C(m1*m2,n1*n2), diff(m1*m2,n1*n2) + + integer :: i,j + + do j = 1, n1 + do i = 1, m1 + A(i,j) = i*j ! A = [1, 2] + end do + end do + + do j = 1, n2 + do i = 1, m2 + B(i,j) = i*j ! B = [[1, 2, 3], [2, 4, 6]] + end do + end do + + C = kronecker_product(A,B) + + diff = C - expected + + call check(error, all(abs(diff) .le. abs(tol)), "all(abs(diff) .le. abs(tol)) failed") + ! Expected: C = [1*B, 2*B] = [[1,2,3, 2,4,6], [2,4,6, 4, 8, 12]] + + end subroutine test_kronecker_product_${t1[0]}$${k1}$ + #:endfor + subroutine test_outer_product_rsp(error) !> Error handling type(error_type), allocatable, intent(out) :: error