Skip to content

Commit fed4d73

Browse files
committed
add implementation for 3,4,5 matrices
1 parent f06f556 commit fed4d73

File tree

2 files changed

+122
-2
lines changed

2 files changed

+122
-2
lines changed

src/CMakeLists.txt

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ set(fppFiles
1919
stdlib_hash_64bit_spookyv2.fypp
2020
stdlib_intrinsics_dot_product.fypp
2121
stdlib_intrinsics_sum.fypp
22+
stdlib_intrinsics_matmul.fypp
2223
stdlib_intrinsics.fypp
2324
stdlib_io.fypp
2425
stdlib_io_npy.fypp
@@ -32,14 +33,14 @@ set(fppFiles
3233
stdlib_linalg_kronecker.fypp
3334
stdlib_linalg_cross_product.fypp
3435
stdlib_linalg_eigenvalues.fypp
35-
stdlib_linalg_solve.fypp
36+
stdlib_linalg_solve.fypp
3637
stdlib_linalg_determinant.fypp
3738
stdlib_linalg_qr.fypp
3839
stdlib_linalg_inverse.fypp
3940
stdlib_linalg_pinv.fypp
4041
stdlib_linalg_norms.fypp
4142
stdlib_linalg_state.fypp
42-
stdlib_linalg_svd.fypp
43+
stdlib_linalg_svd.fypp
4344
stdlib_linalg_cholesky.fypp
4445
stdlib_linalg_schur.fypp
4546
stdlib_optval.fypp

src/stdlib_intrinsics_matmul.fypp

Lines changed: 119 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,119 @@
1+
#:include "common.fypp"
2+
#:set I_KINDS_TYPES = list(zip(INT_KINDS, INT_TYPES, INT_KINDS))
3+
#:set R_KINDS_TYPES = list(zip(REAL_KINDS, REAL_TYPES, REAL_SUFFIX))
4+
#:set C_KINDS_TYPES = list(zip(CMPLX_KINDS, CMPLX_TYPES, CMPLX_SUFFIX))
5+
6+
submodule (stdlib_intrinsics) stdlib_intrinsics_matmul
7+
implicit none
8+
9+
contains
10+
11+
! Algorithm for the optimal bracketization of matrices
12+
! Reference: Cormen, "Introduction to Algorithms", 4ed, ch-14, section-2
13+
! Internal use only!
14+
pure function matmul_chain_order(n, p) result(s)
15+
integer, intent(in) :: n, p(:)
16+
integer :: s(1:n-1, 2:n), m(1:n, 1:n), l, i, j, k, q
17+
m(:,:) = 0
18+
s(:,:) = 0
19+
20+
do l = 2, n
21+
do i = 1, n - l + 1
22+
j = i + l - 1
23+
m(i,j) = huge(1)
24+
25+
do k = i, j - 1
26+
q = m(i,k) + m(k+1,j) + p(i)*p(k+1)*p(j+1)
27+
28+
if (q < m(i, j)) then
29+
m(i,j) = q
30+
s(i,j) = k
31+
end if
32+
end do
33+
end do
34+
end do
35+
end function matmul_chain_order
36+
37+
#:for k, t, s in I_KINDS_TYPES + R_KINDS_TYPES + C_KINDS_TYPES
38+
39+
pure module function stdlib_matmul_${s}$_3 (a, b, c) result(d)
40+
${t}$, intent(in) :: a(:,:), b(:,:), c(:,:)
41+
${t}$, allocatable :: d(:,:)
42+
integer :: sa(2), sb(2), sc(2), cost1, cost2
43+
sa = shape(a)
44+
sb = shape(b)
45+
sc = shape(c)
46+
47+
if ((sa(2) /= sb(1)) .or. (sb(2) /= sc(1))) then
48+
error stop "stdlib_matmul: Incompatible array shapes"
49+
end if
50+
51+
! computes the cost (number of scalar multiplications required)
52+
! cost(A, B) = shape(A)(1) * shape(A)(2) * shape(B)(2)
53+
cost1 = sa(1) * sa(2) * sb(2) + sa(1) * sb(2) * sc(2) ! ((AB)C)
54+
cost2 = sb(1) * sb(2) * sc(2) + sa(1) * sa(2) * sc(2) ! (A(BC))
55+
56+
if (cost1 < cost2) then
57+
d = matmul(matmul(a, b), c)
58+
else
59+
d = matmul(a, matmul(b, c))
60+
end if
61+
end function stdlib_matmul_${s}$_3
62+
63+
pure module function stdlib_matmul_${s}$_4 (a, b, c, d) result(e)
64+
${t}$, intent(in) :: a(:,:), b(:,:), c(:,:), d(:,:)
65+
${t}$, allocatable :: e(:,:)
66+
integer :: p(5), i
67+
integer :: s(3,2:4)
68+
69+
p(1) = size(a, 1)
70+
p(2) = size(b, 1)
71+
p(3) = size(c, 1)
72+
p(4) = size(d, 1)
73+
p(5) = size(d, 2)
74+
75+
s = matmul_chain_order(4, p)
76+
77+
select case (s(1,4))
78+
case (1)
79+
e = matmul(a, stdlib_matmul(b, c, d))
80+
case (2)
81+
e = matmul(matmul(a, b), matmul(c, d))
82+
case (3)
83+
e = matmul(stdlib_matmul(a, b ,c), d)
84+
case default
85+
error stop "stdlib_matmul: unexpected error unexpected s(i,j)"
86+
end select
87+
end function stdlib_matmul_${s}$_4
88+
89+
pure module function stdlib_matmul_${s}$_5 (a, b, c, d, e) result(f)
90+
${t}$, intent(in) :: a(:,:), b(:,:), c(:,:), d(:,:), e(:,:)
91+
${t}$, allocatable :: f(:,:)
92+
integer :: p(6), i
93+
integer :: s(4,2:5)
94+
95+
p(1) = size(a, 1)
96+
p(2) = size(b, 1)
97+
p(3) = size(c, 1)
98+
p(4) = size(d, 1)
99+
p(5) = size(e, 1)
100+
p(6) = size(e, 2)
101+
102+
s = matmul_chain_order(5, p)
103+
104+
select case (s(1,5))
105+
case (1)
106+
f = matmul(a, stdlib_matmul(b, c, d, e))
107+
case (2)
108+
f = matmul(matmul(a, b), stdlib_matmul(c, d, e))
109+
case (3)
110+
f = matmul(stdlib_matmul(a, b ,c), matmul(d, e))
111+
case (4)
112+
f = matmul(stdlib_matmul(a, b, c, d), e)
113+
case default
114+
error stop "stdlib_matmul: unexpected error unexpected s(i,j)"
115+
end select
116+
end function stdlib_matmul_${s}$_5
117+
118+
#:endfor
119+
end submodule stdlib_intrinsics_matmul

0 commit comments

Comments
 (0)