Skip to content

Commit a3dc0a7

Browse files
Commit message: Broadcast input matrices in Gemm
1 parent: 89353bd · commit: a3dc0a7

File tree

2 files changed: 187 additions and 63 deletions.

aesara/tensor/blas.py

Lines changed: 116 additions & 63 deletions
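Columns: original file line number | diff line number | diff line change
(Each code row below is preceded by its gutter entry. The old and new line numbers are run together, e.g. `138138` = old line 138 / new line 138 and `142141` = old line 142 / new line 141. A trailing `-`, as in `141-`, marks a removed line; a trailing `+`, as in `632+`, marks an added line.)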
@@ -138,7 +138,6 @@
138138
except ImportError:
139139
pass
140140

141-
from functools import reduce
142141
from typing import Tuple, Union
143142

144143
import aesara.scalar
@@ -630,8 +629,10 @@ def c_header_dirs(self, **kwargs):
630629
{PyErr_SetString(PyExc_NotImplementedError, "type(b) is not double or float"); %(fail)s;}
631630
"""
632631

632+
# broadcast_xy = None
633+
633634
check_dims = """
634-
if (Nx[0] != Nz[0])
635+
if (Nx[0] !=1 && Nz[0] != 1 && Nx[0] != Nz[0])
635636
{
636637
PyErr_Format(PyExc_ValueError,
637638
"Shape mismatch: x has %%ld rows but z has %%ld rows",
@@ -645,7 +646,7 @@ def c_header_dirs(self, **kwargs):
645646
(long int)Nx[1], (long int)Nx[0], (long int)Ny[0], (long int)Ny[1]);
646647
%(fail)s;
647648
}
648-
if (Ny[1] != Nz[1])
649+
if (Ny[1] != 1 && Nz[1]!= 1 && Ny[1] != Nz[1])
649650
{
650651
PyErr_Format(PyExc_ValueError,
651652
"Shape mismatch: y has %%ld cols but z has %%ld cols",
@@ -822,14 +823,14 @@ def build_gemm_call(self):
822823
else:
823824
setup_z_Nz_Sz = self.setup_z_Nz_Sz
824825

825-
return reduce(
826-
str.__add__,
826+
return "".join(
827827
(
828828
self.declare_NS,
829829
self.check_xyz_rank2,
830830
setup_z_Nz_Sz,
831831
self.check_xyz_double_or_float,
832832
self.check_ab_double_or_float,
833+
self.broadcast_xy,
833834
self.check_dims,
834835
self.check_strides,
835836
self.encode_strides_in_unit,
@@ -842,8 +843,7 @@ def build_gemm_call(self):
842843
self.case_double_ab_constants,
843844
self.case_double_gemm,
844845
self.end_switch_typenum,
845-
),
846-
"",
846+
)
847847
)
848848

849849
def build_gemm_version(self):
@@ -973,6 +973,11 @@ def perform(self, node, inp, out, params):
973973
z.itemset(z * a + b * np.dot(x, y))
974974
zout[0] = z
975975
else:
976+
# Broadcast Z if needed
977+
if (x.shape[0] > z.shape[0]) or (y.shape[1] > z.shape[1]):
978+
z = np.broadcast_to(
979+
z, (max(x.shape[0], z.shape[0]), max(y.shape[1], z.shape[1]))
980+
).copy()
976981
if b == 0.0:
977982
if a == 1.0:
978983
z[:] = np.dot(x, y)
@@ -993,88 +998,135 @@ def perform(self, node, inp, out, params):
993998
zout[0] = z
994999

9951000
def infer_shape(self, fgraph, node, input_shapes):
996-
return [input_shapes[0]]
1001+
z_shape, _, x_shape, y_shape, _ = input_shapes
1002+
return [
1003+
(
1004+
aesara.scalar.scalar_maximum(z_shape[0], x_shape[0]),
1005+
aesara.scalar.scalar_maximum(z_shape[1], y_shape[1]),
1006+
)
1007+
]
9971008

9981009
setup_z_Nz_Sz_inplace = """
999-
if (%(_zout)s != %(_z)s)
1000-
{
1001-
if (%(_zout)s)
1010+
// Needs broadcasting
1011+
if (PyArray_DIMS(%(_z)s)[0] < Nx[0] || PyArray_DIMS(%(_z)s)[1] < Ny[1]){
1012+
1013+
npy_intp dims[2];
1014+
dims[0] = (PyArray_DIMS(%(_z)s)[0] >= Nx[0]) ? PyArray_DIMS(%(_z)s)[0] : Nx[0];
1015+
dims[1] = (PyArray_DIMS(%(_z)s)[1] >= Ny[1]) ? PyArray_DIMS(%(_z)s)[1] : Ny[1];
1016+
1017+
// Check if we need to allocate new array
1018+
if((NULL == %(_zout)s)
1019+
|| (PyArray_DIMS(%(_zout)s)[0] != dims[0])
1020+
|| (PyArray_DIMS(%(_zout)s)[1] != dims[1]))
10021021
{
1003-
Py_DECREF(%(_zout)s);
1022+
// fprintf(stderr, "Gemm Allocating z output array with shape (%%i %%i)\\n", dims[0], dims[1]);
1023+
Py_XDECREF(%(_zout)s);
1024+
%(_zout)s = (PyArrayObject*)PyArray_SimpleNew(2, dims, PyArray_TYPE(%(_z)s));
1025+
}
1026+
1027+
// fprintf(stderr, "Gemm Broadcasting Z into shape (%%i %%i)\\n", dims[0], dims[1]);
1028+
if(PyArray_CopyInto(%(_zout)s, %(_z)s) == -1)
1029+
{
1030+
%(fail)s;
1031+
}
1032+
1033+
} else {
1034+
if (%(_zout)s != %(_z)s)
1035+
{
1036+
Py_XDECREF(%(_zout)s);
1037+
%(_zout)s = %(_z)s;
1038+
Py_INCREF(%(_zout)s);
10041039
}
1005-
%(_zout)s = %(_z)s;
1006-
Py_INCREF(%(_zout)s);
10071040
}
1008-
Nz = PyArray_DIMS(%(_z)s);
1009-
Sz = PyArray_STRIDES(%(_z)s);
1041+
1042+
Nz = PyArray_DIMS(%(_zout)s);
1043+
Sz = PyArray_STRIDES(%(_zout)s);
10101044
"""
10111045

10121046
setup_z_Nz_Sz_outplace = """
1047+
npy_intp dims[2];
1048+
dims[0] = (PyArray_DIMS(%(_z)s)[0] >= Nx[0]) ? PyArray_DIMS(%(_z)s)[0] : Nx[0];
1049+
dims[1] = (PyArray_DIMS(%(_z)s)[1] >= Ny[1]) ? PyArray_DIMS(%(_z)s)[1] : Ny[1];
1050+
1051+
// Check if we need to allocate new array
10131052
if ((NULL == %(_zout)s)
1014-
|| (PyArray_DIMS(%(_zout)s)[0] != PyArray_DIMS(%(_z)s)[0])
1015-
|| (PyArray_DIMS(%(_zout)s)[1] != PyArray_DIMS(%(_z)s)[1])
1016-
|| (PyArray_STRIDES(%(_zout)s)[0] <= 0)
1017-
|| (PyArray_STRIDES(%(_zout)s)[1] <= 0)
1018-
|| (PyArray_STRIDES(%(_zout)s)[0] MOD type_size)
1019-
|| (PyArray_STRIDES(%(_zout)s)[1] MOD type_size)
1020-
|| ((PyArray_STRIDES(%(_zout)s)[0] != type_size)
1021-
&& (PyArray_STRIDES(%(_zout)s)[1] != type_size)))
1053+
|| (PyArray_DIMS(%(_zout)s)[0] != dims[0])
1054+
|| (PyArray_DIMS(%(_zout)s)[1] != dims[1]))
10221055
{
10231056
Py_XDECREF(%(_zout)s);
1024-
npy_intp dims[2];
1025-
dims[0] = PyArray_DIMS(%(_z)s)[0];
1026-
dims[1] = PyArray_DIMS(%(_z)s)[1];
1027-
%(_zout)s = (PyArrayObject*)PyArray_SimpleNew(2, dims,
1028-
PyArray_TYPE(%(_z)s));
1029-
//fprintf(stderr, "Gemm Allocating %%i %%i\\n", dims[0], dims[1]);
1057+
%(_zout)s = (PyArrayObject*)PyArray_SimpleNew(2, dims, PyArray_TYPE(%(_z)s));
1058+
// fprintf(stderr, "Gemm Allocating z output array with shape (%%i %%i)\\n", dims[0], dims[1]);
10301059
if(!%(_zout)s) {
10311060
PyErr_SetString(PyExc_MemoryError,
10321061
"failed to alloc gemm_no_inplace output");
10331062
%(fail)s
10341063
}
10351064
}
1065+
1066+
// fprintf(stderr, "Gemm Broadcasting Z into shape (%%i %%i)\\n", dims[0], dims[1]);
1067+
if(PyArray_CopyInto(%(_zout)s, %(_z)s) == -1)
1068+
{
1069+
%(fail)s
1070+
}
1071+
10361072
Nz = PyArray_DIMS(%(_zout)s);
10371073
Sz = PyArray_STRIDES(%(_zout)s);
1074+
"""
10381075

1039-
if (PyArray_DESCR(%(_zout)s)->type_num == NPY_FLOAT)
1076+
broadcast_xy = """
1077+
// Broadcast X if needed
1078+
if (Nz[0] > Nx[0])
10401079
{
1041-
float * zoutdata = (float*)PyArray_DATA(%(_zout)s);
1042-
int zoi = Sz[0] / sizeof(float);
1043-
int zoj = Sz[1] / sizeof(float);
1044-
const float * zdata = (float*)PyArray_DATA(%(_z)s);
1045-
int zi = PyArray_STRIDES(%(_z)s)[0]/sizeof(float);
1046-
int zj = PyArray_STRIDES(%(_z)s)[1]/sizeof(float);
1047-
for (int i = 0; i < Nz[0]; ++i)
1080+
npy_intp dims[2];
1081+
dims[0] = Nz[0];
1082+
dims[1] = Nx[1];
1083+
// fprintf(stderr, "Gemm Broadcasting X into shape (%%i %%i)\\n", dims[0], dims[1]);
1084+
PyArrayObject *x_new = (PyArrayObject*)PyArray_SimpleNew(2, dims, PyArray_TYPE(%(_x)s));
1085+
if(!x_new) {
1086+
PyErr_SetString(PyExc_MemoryError,
1087+
"failed to alloc gemm_inplace input");
1088+
%(fail)s
1089+
}
1090+
1091+
if(PyArray_MoveInto(x_new, %(_x)s) == -1)
10481092
{
1049-
for (int j = 0; j < Nz[1]; ++j)
1050-
{
1051-
zoutdata[zoi*i + zoj*j] = zdata[zi*i + zj*j];
1052-
}
1093+
%(fail)s
10531094
}
1095+
1096+
Py_DECREF(%(_x)s);
1097+
%(_x)s = x_new;
1098+
1099+
Nx = PyArray_DIMS(%(_x)s);
1100+
Sx = PyArray_STRIDES(%(_x)s);
10541101
}
1055-
else if (PyArray_DESCR(%(_zout)s)->type_num == NPY_DOUBLE)
1102+
1103+
// Broadcast Y if needed
1104+
if (Nz[1] > Ny[1])
10561105
{
1057-
double * zoutdata = (double*) PyArray_DATA(%(_zout)s);
1058-
int zoi = Sz[0] / sizeof(double);
1059-
int zoj = Sz[1] / sizeof(double);
1060-
const double * zdata = (double*)PyArray_DATA(%(_z)s);
1061-
int zi = PyArray_STRIDES(%(_z)s)[0]/sizeof(double);
1062-
int zj = PyArray_STRIDES(%(_z)s)[1]/sizeof(double);
1063-
for (int i = 0; i < Nz[0]; ++i)
1106+
npy_intp dims[2];
1107+
dims[0] = Ny[0];
1108+
dims[1] = Nz[1];
1109+
// fprintf(stderr, "Gemm Broadcasting Y into shape (%%i %%i)\\n", dims[0], dims[1]);
1110+
PyArrayObject *y_new = (PyArrayObject*)PyArray_SimpleNew(2, dims, PyArray_TYPE(%(_x)s));
1111+
if(!y_new) {
1112+
PyErr_SetString(PyExc_MemoryError,
1113+
"failed to alloc gemm_inplace input");
1114+
%(fail)s
1115+
}
1116+
1117+
if(PyArray_MoveInto(y_new, %(_y)s) == -1)
10641118
{
1065-
for (int j = 0; j < Nz[1]; ++j)
1066-
{
1067-
zoutdata[zoi*i + zoj*j] = zdata[zi*i + zj*j];
1068-
}
1119+
%(fail)s
10691120
}
1121+
1122+
Py_DECREF(%(_y)s);
1123+
%(_y)s = y_new;
1124+
1125+
Ny = PyArray_DIMS(%(_y)s);
1126+
Sy = PyArray_STRIDES(%(_y)s);
10701127
}
1071-
else
1072-
{
1073-
PyErr_SetString(PyExc_AssertionError,
1074-
"neither float nor double dtype");
1075-
%(fail)s
1076-
}
1077-
"""
1128+
1129+
"""
10781130

10791131
case_float_ab_constants = """
10801132
#define REAL float
@@ -1108,7 +1160,7 @@ def c_code(self, node, name, inp, out, sub):
11081160
def c_code_cache_version(self):
11091161
gv = self.build_gemm_version()
11101162
if gv:
1111-
return (6,) + gv
1163+
return (7,) + gv
11121164
else:
11131165
return gv
11141166

@@ -1182,7 +1234,6 @@ def _beta_L_plus_alpha_M(fgraph, beta, L, alpha, M, recurse_flip=True):
11821234
if M.owner and M.owner.op == _dot22:
11831235
Ml, Mr = M.owner.inputs
11841236
rval = [gemm_no_inplace(L, alpha, Ml, Mr, beta)]
1185-
# print 'GEMM 0', rval, beta, L, alpha, M
11861237
return rval, M
11871238

11881239
# it also might be the case that there is a dimshuffle between the +
@@ -1650,6 +1701,7 @@ def infer_shape(self, fgraph, node, input_shapes):
16501701
Sz = PyArray_STRIDES(%(_zout)s);
16511702
16521703
"""
1704+
broadcast_xy = ""
16531705
check_ab_double_or_float = ""
16541706
case_float_ab_constants = """
16551707
float a = 1.0;
@@ -1933,6 +1985,7 @@ def infer_shape(self, fgraph, node, input_shapes):
19331985
return [[input_shapes[0][0], input_shapes[1][1]]]
19341986

19351987
setup_z_Nz_Sz = Dot22.setup_z_Nz_Sz
1988+
broadcast_xy = ""
19361989

19371990
check_ab_double_or_float = """
19381991
if ((PyArray_DESCR(%(_a)s)->type_num != NPY_DOUBLE)

tests/tensor/test_blas.py

Lines changed: 71 additions & 0 deletions
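Columns: original file line number | diff line number | diff line change
(Each code row below is preceded by its gutter entry. The old and new line numbers are run together, e.g. `11` = old line 1 / new line 1 and `34` = old line 3 / new line 4. A trailing `+`, as in `3+`, marks an added line.)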
@@ -1,5 +1,6 @@
11
from copy import copy
22
from itertools import product
3+
from random import shuffle
34

45
import numpy as np
56
import pytest
@@ -67,6 +68,7 @@
6768
matrix,
6869
row,
6970
scalar,
71+
scalars,
7072
tensor,
7173
tensor3,
7274
tensor4,
@@ -1042,6 +1044,41 @@ def test_inplace1():
10421044
assert [n.op for n in f.maker.fgraph.apply_nodes] == [gemm_no_inplace]
10431045

10441046

1047+
@pytest.mark.parametrize("linker", ("py", "cvm"))
1048+
@pytest.mark.parametrize("inplace", (False, True))
1049+
def test_gemm_broadcasting(inplace, linker):
1050+
a, b = scalars("a", "b")
1051+
z, x, y = matrices("z", "x", "y")
1052+
1053+
mode = Mode(linker=linker)
1054+
if inplace:
1055+
out = gemm_inplace(z, a, x, y, b)
1056+
f = aesara.function([z, x, y, a, b], out, accept_inplace=True, mode=mode)
1057+
assert [node.op for node in f.maker.fgraph.toposort()] == [gemm_inplace]
1058+
else:
1059+
out = gemm_no_inplace(z, a, x, y, b)
1060+
f = aesara.function([z, x, y, a, b], out, mode=mode)
1061+
assert [node.op for node in f.maker.fgraph.toposort()] == [gemm_no_inplace]
1062+
1063+
shapes_z = [(5, 3), (1, 3), (5, 1), (1, 1)]
1064+
shapes_x = [(5, 4), (1, 4)]
1065+
shapes_y = [(4, 3), (4, 1)]
1066+
1067+
rng = np.random.default_rng()
1068+
shuffle(shapes_z)
1069+
shuffle(shapes_x)
1070+
shuffle(shapes_y)
1071+
for shape_z, shape_x, shape_y in product(shapes_z, shapes_x, shapes_y):
1072+
z_v = rng.random(size=shape_z).astype(config.floatX)
1073+
x_v = rng.random(size=shape_x).astype(config.floatX)
1074+
y_v = rng.random(size=shape_y).astype(config.floatX)
1075+
# We have to copy for the inplace case
1076+
z_v_np = z_v.copy()
1077+
np.testing.assert_allclose(
1078+
f(z_v, x_v, y_v, 1, 1), z_v_np + np.dot(x_v, y_v), atol=2e-6
1079+
)
1080+
1081+
10451082
def test_dot22():
10461083
for dtype1 in ["float32", "float64", "complex64", "complex128"]:
10471084
a = matrix(dtype=dtype1)
@@ -2476,6 +2513,40 @@ def test_gemm(self):
24762513
Gemm,
24772514
)
24782515

2516+
def test_gemm_broadcast(self):
2517+
rng = np.random.default_rng(unittest_tools.fetch_seed())
2518+
x, y, z = matrices("xyz")
2519+
a = scalar("a")
2520+
b = scalar("b")
2521+
2522+
# Broadcast Z
2523+
self._compile_and_check(
2524+
[x, y, a, z, b],
2525+
[gemm(z, a, x, y, b)],
2526+
[
2527+
rng.random((2, 3)).astype(config.floatX),
2528+
rng.random((3, 4)).astype(config.floatX),
2529+
np.asarray(0.5, dtype=config.floatX),
2530+
rng.random((1, 4)).astype(config.floatX),
2531+
np.asarray(0.5, dtype=config.floatX),
2532+
],
2533+
Gemm,
2534+
)
2535+
2536+
# Broadcast dot(X, Y)
2537+
self._compile_and_check(
2538+
[x, y, a, z, b],
2539+
[gemm(z, a, x, y, b)],
2540+
[
2541+
rng.random((1, 3)).astype(config.floatX),
2542+
rng.random((3, 4)).astype(config.floatX),
2543+
np.asarray(0.5, dtype=config.floatX),
2544+
rng.random((5, 4)).astype(config.floatX),
2545+
np.asarray(1, dtype=config.floatX),
2546+
],
2547+
Gemm,
2548+
)
2549+
24792550
def test_gemv(self):
24802551
rng = np.random.default_rng(unittest_tools.fetch_seed())
24812552
A = matrix("A")

0 commit comments

Comments (0)