Commit 820928f

Simplify BatchedDot implementation
The Op now always expects rank 3 inputs, and any dimshuffles are added explicitly by the helper function
1 parent 1e687ad commit 820928f
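
To make the new split concrete: the BatchedDot Op itself now only accepts tensor3 inputs, and the batched_dot helper pads any batched-vector (2D) argument up to rank 3 and squeezes the padded axis afterwards. A minimal NumPy sketch of that convention (illustration only, not the pytensor code; the function name here is made up):

    import numpy as np

    def batched_dot_reference(a, b):
        # Pad a batched matrix (rank-2) argument up to rank 3, run the
        # rank-3 batched product, then squeeze the axes that were added.
        expanded_axis = []
        if a.ndim == 2:                      # batch of row vectors
            a = np.expand_dims(a, axis=1)
            expanded_axis.append(1)
        if b.ndim == 2:                      # batch of column vectors
            b = np.expand_dims(b, axis=2)
            expanded_axis.append(2)
        out = np.matmul(a, b)                # stands in for the rank-3 Op
        if expanded_axis:
            out = out.squeeze(axis=tuple(expanded_axis))
        return out

    a = np.random.rand(5, 3, 4)
    b = np.random.rand(5, 4)
    assert batched_dot_reference(a, b).shape == (5, 3)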

5 files changed: +88, -168 lines changed

pytensor/link/jax/dispatch/nlinalg.py

Lines changed: 1 addition & 3 deletions

@@ -99,9 +99,7 @@ def jax_funcify_BatchedDot(op, **kwargs):
     def batched_dot(a, b):
         if a.shape[0] != b.shape[0]:
             raise TypeError("Shapes must match in the 0-th dimension")
-        if a.ndim == 2 or b.ndim == 2:
-            return jnp.einsum("n...j,nj...->n...", a, b)
-        return jnp.einsum("nij,njk->nik", a, b)
+        return jnp.matmul(a, b)
 
     return batched_dot
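
As a sanity check on this simplification (assuming JAX is installed; the snippet is illustrative and not part of the commit), jnp.matmul on rank-3 inputs reproduces the per-slice dot that BatchedDot defines, so the einsum special cases are no longer needed once inputs are always rank 3:

    import numpy as np
    import jax.numpy as jnp

    a = jnp.asarray(np.random.rand(5, 2, 3))
    b = jnp.asarray(np.random.rand(5, 3, 4))
    batched = jnp.matmul(a, b)                            # shape (5, 2, 4)
    looped = jnp.stack([a[i] @ b[i] for i in range(5)])   # dot(a[i], b[i])
    assert jnp.allclose(batched, looped)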

pytensor/link/numba/dispatch/basic.py

Lines changed: 2 additions & 0 deletions

@@ -895,6 +895,8 @@ def numba_funcify_BatchedDot(op, node, **kwargs):
 
     @numba_njit
     def batched_dot(x, y):
+        # Numba does not support 3D matmul
+        # https://github.com/numba/numba/issues/3804
         shape = x.shape[:-1] + y.shape[2:]
         z0 = np.empty(shape, dtype=dtype)
         for i in range(z0.shape[0]):
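
The two added comment lines record why this backend keeps an explicit batch loop rather than calling matmul on rank-3 arrays. A self-contained NumPy sketch of that loop for reference (the dtype argument is a stand-in for the output dtype computed by the surrounding dispatch code, which this diff does not show):

    import numpy as np

    def batched_dot_loop(x, y, dtype=np.float64):
        # x: (batch, m, k), y: (batch, k, n) -> (batch, m, n)
        shape = x.shape[:-1] + y.shape[2:]
        z0 = np.empty(shape, dtype=dtype)
        for i in range(z0.shape[0]):
            z0[i] = np.dot(x[i], y[i])  # 2D dot per batch entry
        return z0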

pytensor/tensor/blas.py

Lines changed: 77 additions & 148 deletions

@@ -98,10 +98,11 @@
 from pytensor.printing import FunctionPrinter, pprint
 from pytensor.scalar import bool as bool_t
 from pytensor.tensor import basic as at
+from pytensor.tensor.basic import expand_dims
 from pytensor.tensor.blas_headers import blas_header_text, blas_header_version
 from pytensor.tensor.elemwise import DimShuffle
 from pytensor.tensor.math import add, mul, neg, sub
-from pytensor.tensor.shape import specify_broadcastable
+from pytensor.tensor.shape import shape_padright, specify_broadcastable
 from pytensor.tensor.type import DenseTensorType, TensorType, integer_dtypes, tensor
 from pytensor.utils import memoize
 
@@ -1637,48 +1638,53 @@ def c_code_cache_version(self):
 
 class BatchedDot(COp):
     """
-    Computes the batched dot product of two variables:
+    Computes a batch matrix-matrix dot with tensor3 variables
 
         batched_dot(a, b)[i] = dot(a[i], b[i])
     """
 
     __props__ = ()
+    gufunc_signature = "(b,m,k),(b,k,n)->(b,m,n)"
 
-    def make_node(self, *inputs):
-        inputs = list(map(at.as_tensor_variable, inputs))
+    def make_node(self, x, y):
+        x = at.as_tensor_variable(x)
+        y = at.as_tensor_variable(y)
 
-        if any(not isinstance(i.type, DenseTensorType) for i in inputs):
+        if not (
+            isinstance(x.type, DenseTensorType) and isinstance(y.type, DenseTensorType)
+        ):
             raise NotImplementedError("Only dense tensor types are supported")
 
-        if len(inputs) != 2:
-            raise TypeError(f"Two arguments required, but {len(inputs)} given.")
-        if inputs[0].ndim not in (2, 3):
+        if not (x.type.ndim == 3 and y.type.ndim == 3):
             raise TypeError(
-                "Input 0 (0-indexed)"
-                f" must have ndim of 2 or 3, {int(inputs[0].ndim)} given. Consider"
-                " calling batched_dot instead."
-            )
-        if inputs[1].ndim not in (2, 3):
-            raise TypeError(
-                "Input 1 (0-indexed)"
-                f" must have ndim of 2 or 3, {int(inputs[1].ndim)} given. Consider"
-                " calling batched_dot instead."
+                f"Inputs must have 3 ndim, but got {x.type.ndim} and {y.type.ndim}. "
+                "Consider calling batched_dot instead."
             )
 
-        dtype = pytensor.scalar.upcast(*[input.type.dtype for input in inputs])
-        # upcast inputs to common dtype if needed
-        upcasted_inputs = [at.cast(input, dtype) for input in inputs]
-        out_shape = (
-            (
-                1
-                if inputs[0].type.shape[0] == 1 or inputs[1].type.shape[0] == 1
-                else None,
-            )
-            + inputs[0].type.shape[1:-1]
-            + inputs[1].type.shape[2:]
-        )
-        out_shape = tuple(1 if s == 1 else None for s in out_shape)
-        return Apply(self, upcasted_inputs, [tensor(dtype=dtype, shape=out_shape)])
+        def extract_static_dim(dim_x, dim_y):
+            dims = {dim_x, dim_y} - {None}
+            if len(dims) > 1:
+                # BatchedDot doesn't allow broadcasting
+                raise ValueError(
+                    f"Static dimensions of BatchedDot don't match, got {x.type.shape} and {y.type.shape}"
+                )
+            elif not dims:
+                return None
+            else:
+                return dims.pop()
+
+        x_batch_dim, x_row_dim, x_sum_dim = x.type.shape
+        y_batch_dim, y_sum_dim, y_col_dim = y.type.shape
+        batch_dim = extract_static_dim(x_batch_dim, y_batch_dim)
+        # Raise if static sum dimensions do not match
+        _ = extract_static_dim(x_sum_dim, y_sum_dim)
+        out_shape = (batch_dim, x_row_dim, y_col_dim)
+
+        # Change dtype if needed
+        dtype = pytensor.scalar.upcast(x.type.dtype, y.type.dtype)
+        x, y = at.cast(x, dtype), at.cast(y, dtype)
+        out = tensor(dtype=dtype, shape=out_shape)
+        return Apply(self, [x, y], [out])
 
     def perform(self, node, inp, out):
         x, y = inp
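
A standalone illustration of the static-shape merging used by the new make_node above (plain-Python sketch mirroring extract_static_dim from the hunk; the example values are made up):

    def extract_static_dim(dim_x, dim_y):
        # Combine two static dims: a known value wins over None, while two
        # different known values are an error since BatchedDot does not broadcast.
        dims = {dim_x, dim_y} - {None}
        if len(dims) > 1:
            raise ValueError(f"Static dimensions don't match, got {dim_x} and {dim_y}")
        elif not dims:
            return None
        return dims.pop()

    assert extract_static_dim(5, None) == 5
    assert extract_static_dim(None, None) is None
    # extract_static_dim(5, 6) raises ValueError
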
@@ -1690,11 +1696,7 @@ def perform(self, node, inp, out):
                 f" same size in axis 0, but have sizes [{', '.join([str(i.shape[0]) for i in inp])}]."
             )
 
-        shape = self.infer_shape(None, node, [i.shape for i in inp])[0]
-        dtype = node.outputs[0].dtype
-        z0 = z[0] = np.empty(shape, dtype=dtype)
-        for i in range(z0.shape[0]):
-            z0[i] = np.dot(x[i], y[i])
+        z[0] = np.matmul(x, y)
 
     def c_support_code(self, **kwargs):
         batch_gemm_defn = """
@@ -1792,14 +1794,6 @@ def c_lib_dirs(self, **kwargs):
     def c_header_dirs(self, **kwargs):
         return ldflags(libs=False, include_dir=True)
 
-    def c_code_cleanup(self, node, name, inputs, outputs, sub):
-        return """
-        // clean up views
-        Py_XDECREF(xs); xs = 0;
-        Py_XDECREF(ys); ys = 0;
-        Py_XDECREF(zs); zs = 0;
-        """
-
     def c_code(self, node, name, inp, out, sub):
         _x, _y = inp
         (_z,) = out
@@ -1832,12 +1826,11 @@ def contiguous(var, ndim):
         )
 
         # generate code to allocate output based on runtime input shapes
-        z_dims = [f"PyArray_DIMS({_x})[0]"]
-        if x_ndim == 3:
-            z_dims.append(f"PyArray_DIMS({_x})[1]")
-        if y_ndim == 3:
-            z_dims.append(f"PyArray_DIMS({_y})[2]")
-        assert len(z_dims) == z_ndim
+        z_dims = [
+            f"PyArray_DIMS({_x})[0]",
+            f"PyArray_DIMS({_x})[1]",
+            f"PyArray_DIMS({_y})[2]",
+        ]
 
         z_shape_correct = " && ".join(
             "PyArray_DIMS(%s)[%i] == %s" % (_z, i, dim) for i, dim in enumerate(z_dims)
@@ -1880,76 +1873,26 @@ def contiguous(var, ndim):
         )
         contiguate = "\n".join(contiguate)
 
-        def c_dimshuffle(newname, oldname, shape):
-            _fail = fail
-            _shape = ", ".join(
-                "1" if axis is None else "PyArray_DIMS(%s)[%i]" % (oldname, axis)
-                for axis in shape
-            )
-            return (
-                """{
-                npy_intp dims[3] = {%(_shape)s};
-                PyArray_Dims newshape = {dims, 3};
-                %(newname)s = (PyArrayObject*)PyArray_Newshape(%(oldname)s, &newshape, NPY_ANYORDER);
-                if (!%(newname)s)
-                    %(_fail)s
-                // make sure we didn't accidentally copy
-                assert(PyArray_DATA(%(oldname)s) == PyArray_DATA(%(newname)s));
-                }"""
-                % locals()
-            )
-
-        # create tensor3 views for any of x, y, z that are not tensor3, so that
-        # we only need to implement the tensor3-tensor3 batched dot product.
-        # xs, ys and zs will point to these views, or to the original array if
-        # it was already tensor3.
-        # in the latter case, we artificially increase the reference count of
-        # the original array so that the c_code_cleanup method can decref them
-        # all indiscriminately.
-        upcast = []
-        if x_ndim == 3:
-            upcast.append("xs = %(_x)s; Py_XINCREF(xs);")
-        elif x_ndim == 2:
-            upcast.append(c_dimshuffle("xs", _x, (0, None, 1)))
-        if y_ndim == 3:
-            upcast.append("ys = %(_y)s; Py_XINCREF(ys);")
-        elif y_ndim == 2:
-            upcast.append(c_dimshuffle("ys", _y, (0, 1, None)))
-        if z_ndim == 3:
-            upcast.append("zs = %(_z)s; Py_XINCREF(zs);")
-        else:
-            upcast.append(
-                c_dimshuffle(
-                    "zs",
-                    _z,
-                    (0, None if x_ndim == 2 else 1, None if y_ndim == 2 else 1),
-                )
-            )
-        upcast = "\n".join(upcast) % locals()
-
         return (
             """
         int type_num = PyArray_DESCR(%(_x)s)->type_num;
         int type_size = PyArray_DESCR(%(_x)s)->elsize; // in bytes
 
-        // xs, ys, zs will point to views onto %(_x)s, %(_y)s, %(_z)s
-        PyArrayObject *xs = 0, *ys = 0, *zs = 0;
-
-        if (PyArray_NDIM(%(_x)s) != %(x_ndim)s) {
+        if (PyArray_NDIM(%(_x)s) != 3) {
             PyErr_Format(PyExc_NotImplementedError,
-                         "rank(x) != %(x_ndim)s. rank(x) is %%d.",
+                         "rank(x) != 3. rank(x) is %%d.",
                          PyArray_NDIM(%(_x)s));
             %(fail)s;
         }
-        if (PyArray_NDIM(%(_y)s) != %(y_ndim)s) {
+        if (PyArray_NDIM(%(_y)s) != 3) {
             PyErr_Format(PyExc_NotImplementedError,
-                         "rank(y) != %(y_ndim)s. rank(y) is %%d.",
+                         "rank(y) != 3. rank(y) is %%d.",
                          PyArray_NDIM(%(_y)s));
             %(fail)s;
         }
-        if (%(_z)s && PyArray_NDIM(%(_z)s) != %(z_ndim)s) {
+        if (%(_z)s && PyArray_NDIM(%(_z)s) != 3) {
             PyErr_Format(PyExc_NotImplementedError,
-                         "rank(z) != %(z_ndim)s. rank(z) is %%d.",
+                         "rank(z) != 3. rank(z) is %%d.",
                          PyArray_NDIM(%(_z)s));
             %(fail)s;
         }
@@ -1958,36 +1901,32 @@ def c_dimshuffle(newname, oldname, shape):
         %(allocate)s
         // reallocate any noncontiguous arrays or arrays with invalid strides
         %(contiguate)s
-        // add dims to make sure everything is tensor3
-        %(upcast)s
-        // from here on, use xs, ys and zs as they are tensor3 and share memory
-        // with the original %(_x)s, %(_y)s and %(_z)s arrays.
 
-        if ((PyArray_DESCR(xs)->type_num != NPY_DOUBLE)
-            && (PyArray_DESCR(xs)->type_num != NPY_FLOAT))
+        if ((PyArray_DESCR(%(_x)s)->type_num != NPY_DOUBLE)
+            && (PyArray_DESCR(%(_x)s)->type_num != NPY_FLOAT))
         {PyErr_SetString(PyExc_NotImplementedError, "type(x) is not double or float"); %(fail)s;}
 
-        if ((PyArray_DESCR(ys)->type_num != NPY_DOUBLE)
-            && (PyArray_DESCR(ys)->type_num != NPY_FLOAT))
+        if ((PyArray_DESCR(%(_y)s)->type_num != NPY_DOUBLE)
+            && (PyArray_DESCR(%(_y)s)->type_num != NPY_FLOAT))
         {PyErr_SetString(PyExc_NotImplementedError, "type(y) is not double or float"); %(fail)s;}
 
-        if ((PyArray_DESCR(zs)->type_num != NPY_DOUBLE)
-            && (PyArray_DESCR(zs)->type_num != NPY_FLOAT))
+        if ((PyArray_DESCR(%(_z)s)->type_num != NPY_DOUBLE)
+            && (PyArray_DESCR(%(_z)s)->type_num != NPY_FLOAT))
         {PyErr_SetString(PyExc_NotImplementedError, "type(z) is not double or float"); %(fail)s;}
 
-        if ((PyArray_DESCR(xs)->type_num != PyArray_DESCR(ys)->type_num)
-            ||(PyArray_DESCR(xs)->type_num != PyArray_DESCR(zs)->type_num))
+        if ((PyArray_DESCR(%(_x)s)->type_num != PyArray_DESCR(%(_y)s)->type_num)
+            ||(PyArray_DESCR(%(_x)s)->type_num != PyArray_DESCR(%(_z)s)->type_num))
         { PyErr_SetString(PyExc_NotImplementedError, "type(x), type(y), type(z) are not all the same"); %(fail)s; }
 
         switch (type_num)
         {
             case NPY_FLOAT:
-                if (batch_gemm<float>(sgemm_, type_size, xs, ys, zs)) {
+                if (batch_gemm<float>(sgemm_, type_size, %(_x)s, %(_y)s, %(_z)s)) {
                     %(fail)s;
                 }
                 break;
             case NPY_DOUBLE:
-                if (batch_gemm<double>(dgemm_, type_size, xs, ys, zs)) {
+                if (batch_gemm<double>(dgemm_, type_size, %(_x)s, %(_y)s, %(_z)s)) {
                     %(fail)s;
                 }
                 break;
@@ -1999,32 +1938,14 @@ def c_dimshuffle(newname, oldname, shape):
     def c_code_cache_version(self):
         from pytensor.tensor.blas_headers import blas_header_version
 
-        return (4, blas_header_version())
+        return (5, blas_header_version())
 
     def grad(self, inp, grads):
         x, y = inp
         (gz,) = grads
-        xdim, ydim, gdim = x.type.ndim, y.type.ndim, gz.type.ndim
 
-        # grad is a vector, so x is a matrix and y is a matrix
-        if gdim == 1:
-            xgrad = gz.dimshuffle(0, "x") * y
-            ygrad = gz.dimshuffle(0, "x") * x
-
-        # x is a matrix, y is a tensor3, grad is a matrix
-        elif xdim == 2 and ydim == 3:
-            xgrad = batched_dot(gz, y.dimshuffle(0, 2, 1))
-            ygrad = x.dimshuffle(0, 1, "x") * gz.dimshuffle(0, "x", 1)
-
-        # x is a tensor3, y is a matrix, grad is a matrix
-        elif xdim == 3 and ydim == 2:
-            xgrad = gz.dimshuffle(0, 1, "x") * y.dimshuffle(0, "x", 1)
-            ygrad = batched_dot(x.dimshuffle(0, 2, 1), gz)
-
-        # x is a tensor3, y is a tensor3, grad is a tensor3
-        elif xdim == ydim == 3:
-            xgrad = batched_dot(gz, y.dimshuffle(0, 2, 1))
-            ygrad = batched_dot(x.dimshuffle(0, 2, 1), gz)
+        xgrad = batched_dot(gz, y.dimshuffle(0, 2, 1))
+        ygrad = batched_dot(x.dimshuffle(0, 2, 1), gz)
 
         # If x or y contain broadcastable dimensions but only one of
         # them know that a matching dimensions is broadcastable, the
@@ -2105,6 +2026,7 @@ def R_op(self, inputs, eval_points):
                     + " to BatchedDot.R_op should have the same shape, but "
                     f"their shapes are {input_values[i].shape} and {eval_point_values[i].shape}, respectively"
                 )
+
         if eval_points[0]:
             t1 = self(eval_points[0], inputs[1])
         if eval_points[1]:
@@ -2118,9 +2040,6 @@ def R_op(self, inputs, eval_points):
            return [t2]
 
     def infer_shape(self, fgraph, node, shapes):
-        for shape_ in shapes:
-            if len(shape_) not in (2, 3):
-                raise NotImplementedError()
         xshp, yshp = shapes
         return [xshp[:-1] + yshp[2:]]
 
@@ -2157,14 +2076,24 @@ def batched_dot(a, b):
     elif b.ndim == 0:
         raise TypeError("b must have at least one (batch) axis")
     elif a.ndim == 1:
-        return a.dimshuffle(*([0] + ["x"] * (b.ndim - 1))) * b
+        return shape_padright(a, (b.ndim - 1)) * b
     elif b.ndim == 1:
-        return a * b.dimshuffle(*([0] + ["x"] * (a.ndim - 1)))
+        return a * shape_padright(b, (a.ndim - 1))
     elif a.ndim > 3 or b.ndim > 3:
         return batched_tensordot(a, b, [[a.ndim - 1], [np.maximum(1, b.ndim - 2)]])
     else:
-        # avoid circular import
-        return _batched_dot(a, b)
+        # If either a or b is a batched vector, expand dims and later squeeze them
+        expanded_axis = []
+        if a.ndim == 2:
+            a = expand_dims(a, axis=1)
+            expanded_axis.append(1)
+        if b.ndim == 2:
+            b = expand_dims(b, axis=2)
+            expanded_axis.append(2)
+        out = _batched_dot(a, b)
+        if expanded_axis:
+            out = out.squeeze(axis=expanded_axis)
+        return out
 
 
 def batched_tensordot(x, y, axes=2):
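
With both inputs guaranteed to be tensor3, the gradient in the grad hunk above reduces to two batched products. A NumPy shape check of that rule (illustrative only; dimensions chosen arbitrarily):

    import numpy as np

    b, m, k, n = 5, 2, 3, 4
    x = np.random.rand(b, m, k)
    y = np.random.rand(b, k, n)
    gz = np.random.rand(b, m, n)                  # gradient w.r.t. the (b, m, n) output

    xgrad = np.matmul(gz, y.transpose(0, 2, 1))   # (b, m, k), matches x
    ygrad = np.matmul(x.transpose(0, 2, 1), gz)   # (b, k, n), matches y
    assert xgrad.shape == x.shape and ygrad.shape == y.shape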

tests/link/jax/test_nlinalg.py

Lines changed: 0 additions & 9 deletions

@@ -43,15 +43,6 @@ def test_jax_BatchedDot():
     with pytest.raises(TypeError):
         pytensor_jax_fn(*inputs)
 
-    # matrix . matrix
-    a = matrix("a")
-    a.tag.test_value = np.linspace(-1, 1, 5 * 3).astype(config.floatX).reshape((5, 3))
-    b = matrix("b")
-    b.tag.test_value = np.linspace(1, -1, 5 * 3).astype(config.floatX).reshape((5, 3))
-    out = at_blas.BatchedDot()(a, b)
-    fgraph = FunctionGraph([a, b], [out])
-    compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs])
-
 
 def test_jax_basic_multiout():
     rng = np.random.default_rng(213234)
