138
138
except ImportError :
139
139
pass
140
140
141
- from functools import reduce
142
141
from typing import Tuple , Union
143
142
144
143
import aesara .scalar
@@ -630,8 +629,10 @@ def c_header_dirs(self, **kwargs):
630
629
{PyErr_SetString(PyExc_NotImplementedError, "type(b) is not double or float"); %(fail)s;}
631
630
"""
632
631
632
+ # broadcast_xy = None
633
+
633
634
check_dims = """
634
- if (Nx[0] != Nz[0])
635
+ if (Nx[0] !=1 && Nz[0] != 1 && Nx[0] != Nz[0])
635
636
{
636
637
PyErr_Format(PyExc_ValueError,
637
638
"Shape mismatch: x has %%ld rows but z has %%ld rows",
@@ -645,7 +646,7 @@ def c_header_dirs(self, **kwargs):
645
646
(long int)Nx[1], (long int)Nx[0], (long int)Ny[0], (long int)Ny[1]);
646
647
%(fail)s;
647
648
}
648
- if (Ny[1] != Nz[1])
649
+ if (Ny[1] != 1 && Nz[1]!= 1 && Ny[1] != Nz[1])
649
650
{
650
651
PyErr_Format(PyExc_ValueError,
651
652
"Shape mismatch: y has %%ld cols but z has %%ld cols",
@@ -822,14 +823,14 @@ def build_gemm_call(self):
822
823
else :
823
824
setup_z_Nz_Sz = self .setup_z_Nz_Sz
824
825
825
- return reduce (
826
- str .__add__ ,
826
+ return "" .join (
827
827
(
828
828
self .declare_NS ,
829
829
self .check_xyz_rank2 ,
830
830
setup_z_Nz_Sz ,
831
831
self .check_xyz_double_or_float ,
832
832
self .check_ab_double_or_float ,
833
+ self .broadcast_xy ,
833
834
self .check_dims ,
834
835
self .check_strides ,
835
836
self .encode_strides_in_unit ,
@@ -842,8 +843,7 @@ def build_gemm_call(self):
842
843
self .case_double_ab_constants ,
843
844
self .case_double_gemm ,
844
845
self .end_switch_typenum ,
845
- ),
846
- "" ,
846
+ )
847
847
)
848
848
849
849
def build_gemm_version (self ):
@@ -973,6 +973,11 @@ def perform(self, node, inp, out, params):
973
973
z .itemset (z * a + b * np .dot (x , y ))
974
974
zout [0 ] = z
975
975
else :
976
+ # Broadcast Z if needed
977
+ if (x .shape [0 ] > z .shape [0 ]) or (y .shape [1 ] > z .shape [1 ]):
978
+ z = np .broadcast_to (
979
+ z , (max (x .shape [0 ], z .shape [0 ]), max (y .shape [1 ], z .shape [1 ]))
980
+ ).copy ()
976
981
if b == 0.0 :
977
982
if a == 1.0 :
978
983
z [:] = np .dot (x , y )
@@ -993,88 +998,135 @@ def perform(self, node, inp, out, params):
993
998
zout [0 ] = z
994
999
995
1000
def infer_shape (self , fgraph , node , input_shapes ):
996
- return [input_shapes [0 ]]
1001
+ z_shape , _ , x_shape , y_shape , _ = input_shapes
1002
+ return [
1003
+ (
1004
+ aesara .scalar .scalar_maximum (z_shape [0 ], x_shape [0 ]),
1005
+ aesara .scalar .scalar_maximum (z_shape [1 ], y_shape [1 ]),
1006
+ )
1007
+ ]
997
1008
998
1009
setup_z_Nz_Sz_inplace = """
999
- if (%(_zout)s != %(_z)s)
1000
- {
1001
- if (%(_zout)s)
1010
+ // Needs broadcasting
1011
+ if (PyArray_DIMS(%(_z)s)[0] < Nx[0] || PyArray_DIMS(%(_z)s)[1] < Ny[1]){
1012
+
1013
+ npy_intp dims[2];
1014
+ dims[0] = (PyArray_DIMS(%(_z)s)[0] >= Nx[0]) ? PyArray_DIMS(%(_z)s)[0] : Nx[0];
1015
+ dims[1] = (PyArray_DIMS(%(_z)s)[1] >= Ny[1]) ? PyArray_DIMS(%(_z)s)[1] : Ny[1];
1016
+
1017
+ // Check if we need to allocate new array
1018
+ if((NULL == %(_zout)s)
1019
+ || (PyArray_DIMS(%(_zout)s)[0] != dims[0])
1020
+ || (PyArray_DIMS(%(_zout)s)[1] != dims[1]))
1002
1021
{
1003
- Py_DECREF(%(_zout)s);
1022
+ // fprintf(stderr, "Gemm Allocating z output array with shape (%%i %%i)\\ n", dims[0], dims[1]);
1023
+ Py_XDECREF(%(_zout)s);
1024
+ %(_zout)s = (PyArrayObject*)PyArray_SimpleNew(2, dims, PyArray_TYPE(%(_z)s));
1025
+ }
1026
+
1027
+ // fprintf(stderr, "Gemm Broadcasting Z into shape (%%i %%i)\\ n", dims[0], dims[1]);
1028
+ if(PyArray_CopyInto(%(_zout)s, %(_z)s) == -1)
1029
+ {
1030
+ %(fail)s;
1031
+ }
1032
+
1033
+ } else {
1034
+ if (%(_zout)s != %(_z)s)
1035
+ {
1036
+ Py_XDECREF(%(_zout)s);
1037
+ %(_zout)s = %(_z)s;
1038
+ Py_INCREF(%(_zout)s);
1004
1039
}
1005
- %(_zout)s = %(_z)s;
1006
- Py_INCREF(%(_zout)s);
1007
1040
}
1008
- Nz = PyArray_DIMS(%(_z)s);
1009
- Sz = PyArray_STRIDES(%(_z)s);
1041
+
1042
+ Nz = PyArray_DIMS(%(_zout)s);
1043
+ Sz = PyArray_STRIDES(%(_zout)s);
1010
1044
"""
1011
1045
1012
1046
setup_z_Nz_Sz_outplace = """
1047
+ npy_intp dims[2];
1048
+ dims[0] = (PyArray_DIMS(%(_z)s)[0] >= Nx[0]) ? PyArray_DIMS(%(_z)s)[0] : Nx[0];
1049
+ dims[1] = (PyArray_DIMS(%(_z)s)[1] >= Ny[1]) ? PyArray_DIMS(%(_z)s)[1] : Ny[1];
1050
+
1051
+ // Check if we need to allocate new array
1013
1052
if ((NULL == %(_zout)s)
1014
- || (PyArray_DIMS(%(_zout)s)[0] != PyArray_DIMS(%(_z)s)[0])
1015
- || (PyArray_DIMS(%(_zout)s)[1] != PyArray_DIMS(%(_z)s)[1])
1016
- || (PyArray_STRIDES(%(_zout)s)[0] <= 0)
1017
- || (PyArray_STRIDES(%(_zout)s)[1] <= 0)
1018
- || (PyArray_STRIDES(%(_zout)s)[0] MOD type_size)
1019
- || (PyArray_STRIDES(%(_zout)s)[1] MOD type_size)
1020
- || ((PyArray_STRIDES(%(_zout)s)[0] != type_size)
1021
- && (PyArray_STRIDES(%(_zout)s)[1] != type_size)))
1053
+ || (PyArray_DIMS(%(_zout)s)[0] != dims[0])
1054
+ || (PyArray_DIMS(%(_zout)s)[1] != dims[1]))
1022
1055
{
1023
1056
Py_XDECREF(%(_zout)s);
1024
- npy_intp dims[2];
1025
- dims[0] = PyArray_DIMS(%(_z)s)[0];
1026
- dims[1] = PyArray_DIMS(%(_z)s)[1];
1027
- %(_zout)s = (PyArrayObject*)PyArray_SimpleNew(2, dims,
1028
- PyArray_TYPE(%(_z)s));
1029
- //fprintf(stderr, "Gemm Allocating %%i %%i\\ n", dims[0], dims[1]);
1057
+ %(_zout)s = (PyArrayObject*)PyArray_SimpleNew(2, dims, PyArray_TYPE(%(_z)s));
1058
+ // fprintf(stderr, "Gemm Allocating z output array with shape (%%i %%i)\\ n", dims[0], dims[1]);
1030
1059
if(!%(_zout)s) {
1031
1060
PyErr_SetString(PyExc_MemoryError,
1032
1061
"failed to alloc gemm_no_inplace output");
1033
1062
%(fail)s
1034
1063
}
1035
1064
}
1065
+
1066
+ // fprintf(stderr, "Gemm Broadcasting Z into shape (%%i %%i)\\ n", dims[0], dims[1]);
1067
+ if(PyArray_CopyInto(%(_zout)s, %(_z)s) == -1)
1068
+ {
1069
+ %(fail)s
1070
+ }
1071
+
1036
1072
Nz = PyArray_DIMS(%(_zout)s);
1037
1073
Sz = PyArray_STRIDES(%(_zout)s);
1074
+ """
1038
1075
1039
- if (PyArray_DESCR(%(_zout)s)->type_num == NPY_FLOAT)
1076
+ broadcast_xy = """
1077
+ // Broadcast X if needed
1078
+ if (Nz[0] > Nx[0])
1040
1079
{
1041
- float * zoutdata = (float*)PyArray_DATA(%(_zout)s);
1042
- int zoi = Sz[0] / sizeof(float);
1043
- int zoj = Sz[1] / sizeof(float);
1044
- const float * zdata = (float*)PyArray_DATA(%(_z)s);
1045
- int zi = PyArray_STRIDES(%(_z)s)[0]/sizeof(float);
1046
- int zj = PyArray_STRIDES(%(_z)s)[1]/sizeof(float);
1047
- for (int i = 0; i < Nz[0]; ++i)
1080
+ npy_intp dims[2];
1081
+ dims[0] = Nz[0];
1082
+ dims[1] = Nx[1];
1083
+ // fprintf(stderr, "Gemm Broadcasting X into shape (%%i %%i)\\ n", dims[0], dims[1]);
1084
+ PyArrayObject *x_new = (PyArrayObject*)PyArray_SimpleNew(2, dims, PyArray_TYPE(%(_x)s));
1085
+ if(!x_new) {
1086
+ PyErr_SetString(PyExc_MemoryError,
1087
+ "failed to alloc gemm_inplace input");
1088
+ %(fail)s
1089
+ }
1090
+
1091
+ if(PyArray_MoveInto(x_new, %(_x)s) == -1)
1048
1092
{
1049
- for (int j = 0; j < Nz[1]; ++j)
1050
- {
1051
- zoutdata[zoi*i + zoj*j] = zdata[zi*i + zj*j];
1052
- }
1093
+ %(fail)s
1053
1094
}
1095
+
1096
+ Py_DECREF(%(_x)s);
1097
+ %(_x)s = x_new;
1098
+
1099
+ Nx = PyArray_DIMS(%(_x)s);
1100
+ Sx = PyArray_STRIDES(%(_x)s);
1054
1101
}
1055
- else if (PyArray_DESCR(%(_zout)s)->type_num == NPY_DOUBLE)
1102
+
1103
+ // Broadcast Y if needed
1104
+ if (Nz[1] > Ny[1])
1056
1105
{
1057
- double * zoutdata = (double*) PyArray_DATA(%(_zout)s);
1058
- int zoi = Sz[0] / sizeof(double);
1059
- int zoj = Sz[1] / sizeof(double);
1060
- const double * zdata = (double*)PyArray_DATA(%(_z)s);
1061
- int zi = PyArray_STRIDES(%(_z)s)[0]/sizeof(double);
1062
- int zj = PyArray_STRIDES(%(_z)s)[1]/sizeof(double);
1063
- for (int i = 0; i < Nz[0]; ++i)
1106
+ npy_intp dims[2];
1107
+ dims[0] = Ny[0];
1108
+ dims[1] = Nz[1];
1109
+ // fprintf(stderr, "Gemm Broadcasting Y into shape (%%i %%i)\\ n", dims[0], dims[1]);
1110
+ PyArrayObject *y_new = (PyArrayObject*)PyArray_SimpleNew(2, dims, PyArray_TYPE(%(_x)s));
1111
+ if(!y_new) {
1112
+ PyErr_SetString(PyExc_MemoryError,
1113
+ "failed to alloc gemm_inplace input");
1114
+ %(fail)s
1115
+ }
1116
+
1117
+ if(PyArray_MoveInto(y_new, %(_y)s) == -1)
1064
1118
{
1065
- for (int j = 0; j < Nz[1]; ++j)
1066
- {
1067
- zoutdata[zoi*i + zoj*j] = zdata[zi*i + zj*j];
1068
- }
1119
+ %(fail)s
1069
1120
}
1121
+
1122
+ Py_DECREF(%(_y)s);
1123
+ %(_y)s = y_new;
1124
+
1125
+ Ny = PyArray_DIMS(%(_y)s);
1126
+ Sy = PyArray_STRIDES(%(_y)s);
1070
1127
}
1071
- else
1072
- {
1073
- PyErr_SetString(PyExc_AssertionError,
1074
- "neither float nor double dtype");
1075
- %(fail)s
1076
- }
1077
- """
1128
+
1129
+ """
1078
1130
1079
1131
case_float_ab_constants = """
1080
1132
#define REAL float
@@ -1108,7 +1160,7 @@ def c_code(self, node, name, inp, out, sub):
1108
1160
def c_code_cache_version (self ):
1109
1161
gv = self .build_gemm_version ()
1110
1162
if gv :
1111
- return (6 ,) + gv
1163
+ return (7 ,) + gv
1112
1164
else :
1113
1165
return gv
1114
1166
@@ -1182,7 +1234,6 @@ def _beta_L_plus_alpha_M(fgraph, beta, L, alpha, M, recurse_flip=True):
1182
1234
if M .owner and M .owner .op == _dot22 :
1183
1235
Ml , Mr = M .owner .inputs
1184
1236
rval = [gemm_no_inplace (L , alpha , Ml , Mr , beta )]
1185
- # print 'GEMM 0', rval, beta, L, alpha, M
1186
1237
return rval , M
1187
1238
1188
1239
# it also might be the case that there is a dimshuffle between the +
@@ -1650,6 +1701,7 @@ def infer_shape(self, fgraph, node, input_shapes):
1650
1701
Sz = PyArray_STRIDES(%(_zout)s);
1651
1702
1652
1703
"""
1704
+ broadcast_xy = ""
1653
1705
check_ab_double_or_float = ""
1654
1706
case_float_ab_constants = """
1655
1707
float a = 1.0;
@@ -1933,6 +1985,7 @@ def infer_shape(self, fgraph, node, input_shapes):
1933
1985
return [[input_shapes [0 ][0 ], input_shapes [1 ][1 ]]]
1934
1986
1935
1987
setup_z_Nz_Sz = Dot22 .setup_z_Nz_Sz
1988
+ broadcast_xy = ""
1936
1989
1937
1990
check_ab_double_or_float = """
1938
1991
if ((PyArray_DESCR(%(_a)s)->type_num != NPY_DOUBLE)
0 commit comments