from pytensor.printing import FunctionPrinter, pprint
from pytensor.scalar import bool as bool_t
from pytensor.tensor import basic as at
+from pytensor.tensor.basic import expand_dims
from pytensor.tensor.blas_headers import blas_header_text, blas_header_version
from pytensor.tensor.elemwise import DimShuffle
from pytensor.tensor.math import add, mul, neg, sub
-from pytensor.tensor.shape import specify_broadcastable
+from pytensor.tensor.shape import shape_padright, specify_broadcastable
from pytensor.tensor.type import DenseTensorType, TensorType, integer_dtypes, tensor
from pytensor.utils import memoize

@@ -1637,48 +1638,53 @@ def c_code_cache_version(self):

class BatchedDot(COp):
    """
-    Computes the batched dot product of two variables:
+    Computes a batch matrix multiply with tensor3 variables:

-        batched_dot(a, b)[i] = dot(a[i], b[i])
+        batched_dot(a, b)[i] = matmul(a[i], b[i])

    """

    __props__ = ()
+    gufunc_signature = "(b,m,k),(b,k,n)->(b,m,n)"

-    def make_node(self, *inputs):
-        inputs = list(map(at.as_tensor_variable, inputs))
+    def make_node(self, x, y):
+        x = at.as_tensor_variable(x)
+        y = at.as_tensor_variable(y)

-        if any(not isinstance(i.type, DenseTensorType) for i in inputs):
+        if not (
+            isinstance(x.type, DenseTensorType) and isinstance(y.type, DenseTensorType)
+        ):
            raise NotImplementedError("Only dense tensor types are supported")

-        if len(inputs) != 2:
-            raise TypeError(f"Two arguments required, but {len(inputs)} given.")
-        if inputs[0].ndim not in (2, 3):
-            raise TypeError(
-                "Input 0 (0-indexed)"
-                f" must have ndim of 2 or 3, {int(inputs[0].ndim)} given. Consider"
-                " calling batched_dot instead."
-            )
-        if inputs[1].ndim not in (2, 3):
+        if not (x.type.ndim == 3 and y.type.ndim == 3):
            raise TypeError(
-                "Input 1 (0-indexed)"
-                f" must have ndim of 2 or 3, {int(inputs[1].ndim)} given. Consider"
-                " calling batched_dot instead."
+                f"Inputs must have 3 ndim, but got {x.type.ndim} and {y.type.ndim}. "
+                "Consider calling batched_dot instead."
            )

-        dtype = pytensor.scalar.upcast(*[input.type.dtype for input in inputs])
-        # upcast inputs to common dtype if needed
-        upcasted_inputs = [at.cast(input, dtype) for input in inputs]
-        out_shape = (
-            (
-                1
-                if inputs[0].type.shape[0] == 1 or inputs[1].type.shape[0] == 1
-                else None,
-            )
-            + inputs[0].type.shape[1:-1]
-            + inputs[1].type.shape[2:]
-        )
-        out_shape = tuple(1 if s == 1 else None for s in out_shape)
-        return Apply(self, upcasted_inputs, [tensor(dtype=dtype, shape=out_shape)])
+        def extract_static_dim(dim_x, dim_y):
+            dims = {dim_x, dim_y} - {None}
+            if len(dims) > 1:
+                # BatchedDot doesn't allow broadcasting
+                raise ValueError(
+                    f"Static dimensions of BatchedDot don't match, got {x.type.shape} and {y.type.shape}"
+                )
+            elif not dims:
+                return None
+            else:
+                return dims.pop()
+
+        x_batch_dim, x_row_dim, x_sum_dim = x.type.shape
+        y_batch_dim, y_sum_dim, y_col_dim = y.type.shape
+        batch_dim = extract_static_dim(x_batch_dim, y_batch_dim)
+        # Raise if static sum dimensions do not match
+        _ = extract_static_dim(x_sum_dim, y_sum_dim)
+        out_shape = (batch_dim, x_row_dim, y_col_dim)
+
+        # Change dtype if needed
+        dtype = pytensor.scalar.upcast(x.type.dtype, y.type.dtype)
+        x, y = at.cast(x, dtype), at.cast(y, dtype)
+        out = tensor(dtype=dtype, shape=out_shape)
+        return Apply(self, [x, y], [out])

    def perform(self, node, inp, out):
        x, y = inp
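
How the new make_node combines static shapes, sketched in plain Python (illustration only, not part of the patch; the helper name and shapes below are hypothetical):

    # Static dims are ints or None (unknown); BatchedDot refuses to broadcast mismatches.
    def merge_static(dim_x, dim_y):
        dims = {dim_x, dim_y} - {None}
        if len(dims) > 1:
            raise ValueError("static dimensions don't match")
        return dims.pop() if dims else None

    x_shape = (7, None, 5)   # (batch, rows, summed)
    y_shape = (None, 5, 3)   # (batch, summed, cols)
    batch = merge_static(x_shape[0], y_shape[0])   # 7
    merge_static(x_shape[2], y_shape[1])           # 5; would raise for e.g. 4 vs 5
    out_shape = (batch, x_shape[1], y_shape[2])    # (7, None, 3)
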
@@ -1690,11 +1696,7 @@ def perform(self, node, inp, out):
                f" same size in axis 0, but have sizes [{', '.join([str(i.shape[0]) for i in inp])}]."
            )

-        shape = self.infer_shape(None, node, [i.shape for i in inp])[0]
-        dtype = node.outputs[0].dtype
-        z0 = z[0] = np.empty(shape, dtype=dtype)
-        for i in range(z0.shape[0]):
-            z0[i] = np.dot(x[i], y[i])
+        z[0] = np.matmul(x, y)

    def c_support_code(self, **kwargs):
        batch_gemm_defn = """
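
The Python-mode perform now delegates the whole batched product to one np.matmul call instead of looping np.dot over the batch axis; for 3-D inputs the two are equivalent. A quick NumPy sanity check (illustration only, not part of the patch):

    import numpy as np

    rng = np.random.default_rng(0)
    x = rng.normal(size=(4, 3, 5))
    y = rng.normal(size=(4, 5, 2))

    looped = np.empty((4, 3, 2))
    for i in range(4):
        looped[i] = np.dot(x[i], y[i])

    assert np.allclose(looped, np.matmul(x, y))  # one vectorized call, same result
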
@@ -1792,14 +1794,6 @@ def c_lib_dirs(self, **kwargs):
    def c_header_dirs(self, **kwargs):
        return ldflags(libs=False, include_dir=True)

-    def c_code_cleanup(self, node, name, inputs, outputs, sub):
-        return """
-        // clean up views
-        Py_XDECREF(xs); xs = 0;
-        Py_XDECREF(ys); ys = 0;
-        Py_XDECREF(zs); zs = 0;
-        """
-
    def c_code(self, node, name, inp, out, sub):
        _x, _y = inp
        (_z,) = out
@@ -1832,12 +1826,11 @@ def contiguous(var, ndim):
        )

        # generate code to allocate output based on runtime input shapes
-        z_dims = [f"PyArray_DIMS({_x})[0]"]
-        if x_ndim == 3:
-            z_dims.append(f"PyArray_DIMS({_x})[1]")
-        if y_ndim == 3:
-            z_dims.append(f"PyArray_DIMS({_y})[2]")
-        assert len(z_dims) == z_ndim
+        z_dims = [
+            f"PyArray_DIMS({_x})[0]",
+            f"PyArray_DIMS({_x})[1]",
+            f"PyArray_DIMS({_y})[2]",
+        ]

        z_shape_correct = " && ".join(
            "PyArray_DIMS(%s)[%i] == %s" % (_z, i, dim) for i, dim in enumerate(z_dims)
@@ -1880,76 +1873,26 @@ def contiguous(var, ndim):
        )
        contiguate = "\n".join(contiguate)

-        def c_dimshuffle(newname, oldname, shape):
-            _fail = fail
-            _shape = ", ".join(
-                "1" if axis is None else "PyArray_DIMS(%s)[%i]" % (oldname, axis)
-                for axis in shape
-            )
-            return (
-                """{
-                npy_intp dims[3] = {%(_shape)s};
-                PyArray_Dims newshape = {dims, 3};
-                %(newname)s = (PyArrayObject*)PyArray_Newshape(%(oldname)s, &newshape, NPY_ANYORDER);
-                if (!%(newname)s)
-                    %(_fail)s
-                // make sure we didn't accidentally copy
-                assert(PyArray_DATA(%(oldname)s) == PyArray_DATA(%(newname)s));
-                }"""
-                % locals()
-            )
-
-        # create tensor3 views for any of x, y, z that are not tensor3, so that
-        # we only need to implement the tensor3-tensor3 batched dot product.
-        # xs, ys and zs will point to these views, or to the original array if
-        # it was already tensor3.
-        # in the latter case, we artificially increase the reference count of
-        # the original array so that the c_code_cleanup method can decref them
-        # all indiscriminately.
-        upcast = []
-        if x_ndim == 3:
-            upcast.append("xs = %(_x)s; Py_XINCREF(xs);")
-        elif x_ndim == 2:
-            upcast.append(c_dimshuffle("xs", _x, (0, None, 1)))
-        if y_ndim == 3:
-            upcast.append("ys = %(_y)s; Py_XINCREF(ys);")
-        elif y_ndim == 2:
-            upcast.append(c_dimshuffle("ys", _y, (0, 1, None)))
-        if z_ndim == 3:
-            upcast.append("zs = %(_z)s; Py_XINCREF(zs);")
-        else:
-            upcast.append(
-                c_dimshuffle(
-                    "zs",
-                    _z,
-                    (0, None if x_ndim == 2 else 1, None if y_ndim == 2 else 1),
-                )
-            )
-        upcast = "\n".join(upcast) % locals()
-
        return (
            """
        int type_num = PyArray_DESCR(%(_x)s)->type_num;
        int type_size = PyArray_DESCR(%(_x)s)->elsize; // in bytes

-        // xs, ys, zs will point to views onto %(_x)s, %(_y)s, %(_z)s
-        PyArrayObject *xs = 0, *ys = 0, *zs = 0;
-
-        if (PyArray_NDIM(%(_x)s) != %(x_ndim)s) {
+        if (PyArray_NDIM(%(_x)s) != 3) {
            PyErr_Format(PyExc_NotImplementedError,
-                         "rank(x) != %(x_ndim)s. rank(x) is %%d.",
+                         "rank(x) != 3. rank(x) is %%d.",
                         PyArray_NDIM(%(_x)s));
            %(fail)s;
        }
-        if (PyArray_NDIM(%(_y)s) != %(y_ndim)s) {
+        if (PyArray_NDIM(%(_y)s) != 3) {
            PyErr_Format(PyExc_NotImplementedError,
-                         "rank(y) != %(y_ndim)s. rank(y) is %%d.",
+                         "rank(y) != 3. rank(y) is %%d.",
                         PyArray_NDIM(%(_y)s));
            %(fail)s;
        }
-        if (%(_z)s && PyArray_NDIM(%(_z)s) != %(z_ndim)s) {
+        if (%(_z)s && PyArray_NDIM(%(_z)s) != 3) {
            PyErr_Format(PyExc_NotImplementedError,
-                         "rank(z) != %(z_ndim)s. rank(z) is %%d.",
+                         "rank(z) != 3. rank(z) is %%d.",
                         PyArray_NDIM(%(_z)s));
            %(fail)s;
        }
@@ -1958,36 +1901,32 @@ def c_dimshuffle(newname, oldname, shape):
        %(allocate)s
        // reallocate any noncontiguous arrays or arrays with invalid strides
        %(contiguate)s
-        // add dims to make sure everything is tensor3
-        %(upcast)s
-        // from here on, use xs, ys and zs as they are tensor3 and share memory
-        // with the original %(_x)s, %(_y)s and %(_z)s arrays.

-        if ((PyArray_DESCR(xs)->type_num != NPY_DOUBLE)
-            && (PyArray_DESCR(xs)->type_num != NPY_FLOAT))
+        if ((PyArray_DESCR(%(_x)s)->type_num != NPY_DOUBLE)
+            && (PyArray_DESCR(%(_x)s)->type_num != NPY_FLOAT))
        {PyErr_SetString(PyExc_NotImplementedError, "type(x) is not double or float"); %(fail)s;}

-        if ((PyArray_DESCR(ys)->type_num != NPY_DOUBLE)
-            && (PyArray_DESCR(ys)->type_num != NPY_FLOAT))
+        if ((PyArray_DESCR(%(_y)s)->type_num != NPY_DOUBLE)
+            && (PyArray_DESCR(%(_y)s)->type_num != NPY_FLOAT))
        {PyErr_SetString(PyExc_NotImplementedError, "type(y) is not double or float"); %(fail)s;}

-        if ((PyArray_DESCR(zs)->type_num != NPY_DOUBLE)
-            && (PyArray_DESCR(zs)->type_num != NPY_FLOAT))
+        if ((PyArray_DESCR(%(_z)s)->type_num != NPY_DOUBLE)
+            && (PyArray_DESCR(%(_z)s)->type_num != NPY_FLOAT))
        {PyErr_SetString(PyExc_NotImplementedError, "type(z) is not double or float"); %(fail)s;}

-        if ((PyArray_DESCR(xs)->type_num != PyArray_DESCR(ys)->type_num)
-            ||(PyArray_DESCR(xs)->type_num != PyArray_DESCR(zs)->type_num))
+        if ((PyArray_DESCR(%(_x)s)->type_num != PyArray_DESCR(%(_y)s)->type_num)
+            ||(PyArray_DESCR(%(_x)s)->type_num != PyArray_DESCR(%(_z)s)->type_num))
        { PyErr_SetString(PyExc_NotImplementedError, "type(x), type(y), type(z) are not all the same"); %(fail)s; }

        switch (type_num)
        {
            case NPY_FLOAT:
-                if (batch_gemm<float>(sgemm_, type_size, xs, ys, zs)) {
+                if (batch_gemm<float>(sgemm_, type_size, %(_x)s, %(_y)s, %(_z)s)) {
                    %(fail)s;
                }
                break;
            case NPY_DOUBLE:
-                if (batch_gemm<double>(dgemm_, type_size, xs, ys, zs)) {
+                if (batch_gemm<double>(dgemm_, type_size, %(_x)s, %(_y)s, %(_z)s)) {
                    %(fail)s;
                }
                break;
@@ -1999,32 +1938,14 @@ def c_dimshuffle(newname, oldname, shape):
    def c_code_cache_version(self):
        from pytensor.tensor.blas_headers import blas_header_version

-        return (4, blas_header_version())
+        return (5, blas_header_version())

    def grad(self, inp, grads):
        x, y = inp
        (gz,) = grads
-        xdim, ydim, gdim = x.type.ndim, y.type.ndim, gz.type.ndim
-
-        # grad is a vector, so x is a matrix and y is a matrix
-        if gdim == 1:
-            xgrad = gz.dimshuffle(0, "x") * y
-            ygrad = gz.dimshuffle(0, "x") * x
-
-        # x is a matrix, y is a tensor3, grad is a matrix
-        elif xdim == 2 and ydim == 3:
-            xgrad = batched_dot(gz, y.dimshuffle(0, 2, 1))
-            ygrad = x.dimshuffle(0, 1, "x") * gz.dimshuffle(0, "x", 1)

-        # x is a tensor3, y is a matrix, grad is a matrix
-        elif xdim == 3 and ydim == 2:
-            xgrad = gz.dimshuffle(0, 1, "x") * y.dimshuffle(0, "x", 1)
-            ygrad = batched_dot(x.dimshuffle(0, 2, 1), gz)
-
-        # x is a tensor3, y is a tensor3, grad is a tensor3
-        elif xdim == ydim == 3:
-            xgrad = batched_dot(gz, y.dimshuffle(0, 2, 1))
-            ygrad = batched_dot(x.dimshuffle(0, 2, 1), gz)
+        xgrad = batched_dot(gz, y.dimshuffle(0, 2, 1))
+        ygrad = batched_dot(x.dimshuffle(0, 2, 1), gz)

        # If x or y contain broadcastable dimensions but only one of
        # them know that a matching dimensions is broadcastable, the
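
With both inputs guaranteed to be tensor3, the gradient collapses to the single batched-matrix case: for z[i] = x[i] @ y[i], the input gradients are gz[i] @ y[i].T and x[i].T @ gz[i], which is what the two remaining batched_dot calls compute. A NumPy check of those formulas (illustration only, not part of the patch):

    import numpy as np

    rng = np.random.default_rng(1)
    x = rng.normal(size=(4, 3, 5))
    y = rng.normal(size=(4, 5, 2))
    gz = rng.normal(size=(4, 3, 2))  # upstream gradient for z = batched x @ y

    xgrad = np.matmul(gz, y.transpose(0, 2, 1))  # counterpart of batched_dot(gz, y.dimshuffle(0, 2, 1))
    ygrad = np.matmul(x.transpose(0, 2, 1), gz)  # counterpart of batched_dot(x.dimshuffle(0, 2, 1), gz)

    assert xgrad.shape == x.shape and ygrad.shape == y.shape
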
@@ -2105,6 +2026,7 @@ def R_op(self, inputs, eval_points):
                        + " to BatchedDot.R_op should have the same shape, but "
                        f"their shapes are {input_values[i].shape} and {eval_point_values[i].shape}, respectively"
                    )
+
        if eval_points[0]:
            t1 = self(eval_points[0], inputs[1])
        if eval_points[1]:
@@ -2118,9 +2040,6 @@ def R_op(self, inputs, eval_points):
            return [t2]

    def infer_shape(self, fgraph, node, shapes):
-        for shape_ in shapes:
-            if len(shape_) not in (2, 3):
-                raise NotImplementedError()
        xshp, yshp = shapes
        return [xshp[:-1] + yshp[2:]]
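
The 2-/3-D guard in infer_shape is no longer reachable because make_node only builds nodes with two tensor3 inputs; the remaining slice arithmetic composes the batch and outer dimensions. Roughly, with stand-in shape entries (illustration only):

    xshp = ("b", "m", "k")
    yshp = ("b", "k", "n")
    assert xshp[:-1] + yshp[2:] == ("b", "m", "n")  # (batch, rows of x, cols of y)
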
@@ -2157,14 +2076,24 @@ def batched_dot(a, b):
    elif b.ndim == 0:
        raise TypeError("b must have at least one (batch) axis")
    elif a.ndim == 1:
-        return a.dimshuffle(*([0] + ["x"] * (b.ndim - 1))) * b
+        return shape_padright(a, (b.ndim - 1)) * b
    elif b.ndim == 1:
-        return a * b.dimshuffle(*([0] + ["x"] * (a.ndim - 1)))
+        return a * shape_padright(b, (a.ndim - 1))
    elif a.ndim > 3 or b.ndim > 3:
        return batched_tensordot(a, b, [[a.ndim - 1], [np.maximum(1, b.ndim - 2)]])
    else:
-        # avoid circular import
-        return _batched_dot(a, b)
+        # If either a or b is a batched vector, expand dims and later squeeze them
+        expanded_axis = []
+        if a.ndim == 2:
+            a = expand_dims(a, axis=1)
+            expanded_axis.append(1)
+        if b.ndim == 2:
+            b = expand_dims(b, axis=2)
+            expanded_axis.append(2)
+        out = _batched_dot(a, b)
+        if expanded_axis:
+            out = out.squeeze(axis=expanded_axis)
+        return out


def batched_tensordot(x, y, axes=2):
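
The batched_dot helper now folds batched matrix-vector and vector-matrix cases into the tensor3-only op by inserting a length-1 axis before the product and squeezing it afterwards. The same trick in plain NumPy (illustration only, not the pytensor API):

    import numpy as np

    rng = np.random.default_rng(2)
    a = rng.normal(size=(4, 3, 5))  # batched matrices
    b = rng.normal(size=(4, 5))     # batched vectors

    b3 = np.expand_dims(b, axis=2)            # (4, 5, 1): each vector becomes a column matrix
    out = np.matmul(a, b3).squeeze(axis=2)    # (4, 3): drop the inserted axis again

    expected = np.stack([a[i] @ b[i] for i in range(4)])
    assert np.allclose(out, expected)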