|
#:include "common.fypp"
#:set I_KINDS_TYPES = list(zip(INT_KINDS, INT_TYPES, INT_KINDS))
#:set R_KINDS_TYPES = list(zip(REAL_KINDS, REAL_TYPES, REAL_SUFFIX))
#:set C_KINDS_TYPES = list(zip(CMPLX_KINDS, CMPLX_TYPES, CMPLX_SUFFIX))

! Implementation of the multi-matrix `stdlib_matmul_<suffix>_{3,4,5}` module
! functions: each multiplies a chain of rank-2 arrays, choosing the
! parenthesization with the fewest scalar multiplications.
submodule (stdlib_intrinsics) stdlib_intrinsics_matmul
    implicit none

    ! Integer kind used for scalar-multiplication counts. A product of three
    ! extents exceeds huge(1) for a default 32-bit integer as soon as the
    ! dimensions reach roughly 1300, so costs are accumulated in a kind that
    ! holds at least 18 decimal digits.
    integer, parameter :: cost_kind = selected_int_kind(18)

contains

    ! Algorithm for the optimal bracketization of matrices
    ! Reference: Cormen, "Introduction to Algorithms", 4ed, ch-14, section-2
    ! Internal use only!
    !
    ! Arguments:
    !   n - number of matrices in the chain
    !   p - extents of the chain; matrix i is p(i) x p(i+1), so elements
    !       p(1) .. p(n+1) are read (the caller must supply at least n+1)
    ! Result:
    !   s(i,j) - index k at which the sub-chain i..j is split into
    !            (i..k)(k+1..j); entries that are never visited stay 0
    pure function matmul_chain_order(n, p) result(s)
        integer, intent(in) :: n, p(:)
        integer :: s(1:n-1, 2:n)
        ! m(i,j): cheapest known cost of the sub-chain i..j; q: candidate cost
        integer(cost_kind) :: m(1:n, 1:n), q
        integer :: l, i, j, k

        m(:,:) = 0_cost_kind
        s(:,:) = 0

        ! l is the length of the sub-chain being solved; shorter sub-chains
        ! are complete before any longer one needs them.
        do l = 2, n
            do i = 1, n - l + 1
                j = i + l - 1
                m(i,j) = huge(q)

                do k = i, j - 1
                    ! Widen before multiplying so the product is evaluated in cost_kind.
                    q = m(i,k) + m(k+1,j) &
                        + int(p(i), cost_kind) * int(p(k+1), cost_kind) * int(p(j+1), cost_kind)

                    ! Strict comparison: on ties the smallest k is kept.
                    if (q < m(i,j)) then
                        m(i,j) = q
                        s(i,j) = k
                    end if
                end do
            end do
        end do
    end function matmul_chain_order

#:for k, t, s in I_KINDS_TYPES + R_KINDS_TYPES + C_KINDS_TYPES

    ! Product a*b*c, evaluated as (a*b)*c or a*(b*c), whichever needs fewer
    ! scalar multiplications. Stops with an error if the shapes do not chain.
    pure module function stdlib_matmul_${s}$_3 (a, b, c) result(d)
        ${t}$, intent(in) :: a(:,:), b(:,:), c(:,:)
        ${t}$, allocatable :: d(:,:)
        integer :: sa(2), sb(2), sc(2)
        integer(cost_kind) :: ra, ca, cb, cc, cost1, cost2

        sa = shape(a)
        sb = shape(b)
        sc = shape(c)

        if ((sa(2) /= sb(1)) .or. (sb(2) /= sc(1))) then
            error stop "stdlib_matmul: Incompatible array shapes"
        end if

        ! Extents widened to cost_kind so the products below cannot wrap
        ! around in default-integer arithmetic. After the check above,
        ! rows(b) == cols(a) and rows(c) == cols(b).
        ra = int(sa(1), cost_kind)
        ca = int(sa(2), cost_kind)
        cb = int(sb(2), cost_kind)
        cc = int(sc(2), cost_kind)

        ! computes the cost (number of scalar multiplications required)
        ! cost(A, B) = shape(A)(1) * shape(A)(2) * shape(B)(2)
        cost1 = ra * ca * cb + ra * cb * cc ! ((AB)C)
        cost2 = ca * cb * cc + ra * ca * cc ! (A(BC))

        if (cost1 < cost2) then
            d = matmul(matmul(a, b), c)
        else
            d = matmul(a, matmul(b, c))
        end if
    end function stdlib_matmul_${s}$_3

    ! Product a*b*c*d. The outermost split is taken from matmul_chain_order;
    ! three-matrix sub-chains are handed to the stdlib_matmul generic
    ! (presumably resolving to the three-argument variant above, which
    ! selects its own order).
    pure module function stdlib_matmul_${s}$_4 (a, b, c, d) result(e)
        ${t}$, intent(in) :: a(:,:), b(:,:), c(:,:), d(:,:)
        ${t}$, allocatable :: e(:,:)
        integer :: p(5)
        integer :: s(3,2:4)

        ! Same conformance check as the three-matrix variant; without it the
        ! (ab)(cd) branch would reach the intrinsic matmul unchecked.
        if ((size(a, 2) /= size(b, 1)) .or. (size(b, 2) /= size(c, 1)) &
            .or. (size(c, 2) /= size(d, 1))) then
            error stop "stdlib_matmul: Incompatible array shapes"
        end if

        ! Chain extents: matrix i is p(i) x p(i+1)
        p(1) = size(a, 1)
        p(2) = size(b, 1)
        p(3) = size(c, 1)
        p(4) = size(d, 1)
        p(5) = size(d, 2)

        s = matmul_chain_order(4, p)

        ! s(1,4) = k means the full chain is split as (1..k)(k+1..4)
        select case (s(1,4))
            case (1)
                e = matmul(a, stdlib_matmul(b, c, d))
            case (2)
                e = matmul(matmul(a, b), matmul(c, d))
            case (3)
                e = matmul(stdlib_matmul(a, b ,c), d)
            case default
                error stop "stdlib_matmul: unexpected error unexpected s(i,j)"
        end select
    end function stdlib_matmul_${s}$_4

    ! Product a*b*c*d*e. The outermost split is taken from matmul_chain_order;
    ! sub-chains of three or four matrices are handed to the stdlib_matmul
    ! generic.
    pure module function stdlib_matmul_${s}$_5 (a, b, c, d, e) result(f)
        ${t}$, intent(in) :: a(:,:), b(:,:), c(:,:), d(:,:), e(:,:)
        ${t}$, allocatable :: f(:,:)
        integer :: p(6)
        integer :: s(4,2:5)

        ! Same conformance check as the three-matrix variant.
        if ((size(a, 2) /= size(b, 1)) .or. (size(b, 2) /= size(c, 1)) &
            .or. (size(c, 2) /= size(d, 1)) .or. (size(d, 2) /= size(e, 1))) then
            error stop "stdlib_matmul: Incompatible array shapes"
        end if

        ! Chain extents: matrix i is p(i) x p(i+1)
        p(1) = size(a, 1)
        p(2) = size(b, 1)
        p(3) = size(c, 1)
        p(4) = size(d, 1)
        p(5) = size(e, 1)
        p(6) = size(e, 2)

        s = matmul_chain_order(5, p)

        ! s(1,5) = k means the full chain is split as (1..k)(k+1..5)
        select case (s(1,5))
            case (1)
                f = matmul(a, stdlib_matmul(b, c, d, e))
            case (2)
                f = matmul(matmul(a, b), stdlib_matmul(c, d, e))
            case (3)
                f = matmul(stdlib_matmul(a, b ,c), matmul(d, e))
            case (4)
                f = matmul(stdlib_matmul(a, b, c, d), e)
            case default
                error stop "stdlib_matmul: unexpected error unexpected s(i,j)"
        end select
    end function stdlib_matmul_${s}$_5

#:endfor
end submodule stdlib_intrinsics_matmul
0 commit comments