Skip to content

Commit 1e2e7ae

Browse files
committed
Fix tensordot implementation
1 parent f799219 commit 1e2e7ae

File tree

2 files changed

+189
-47
lines changed

2 files changed

+189
-47
lines changed

pytensor/tensor/math.py

Lines changed: 126 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import builtins
22
import warnings
3-
from typing import TYPE_CHECKING, Optional
3+
from collections.abc import Sequence
4+
from typing import TYPE_CHECKING, Optional, Union
45

56
import numpy as np
67

@@ -15,6 +16,7 @@
1516
from pytensor.link.c.type import Generic
1617
from pytensor.misc.safe_asarray import _asarray
1718
from pytensor.printing import pprint
19+
from pytensor.raise_op import Assert
1820
from pytensor.scalar.basic import BinaryScalarOp
1921
from pytensor.tensor.basic import (
2022
alloc,
@@ -47,7 +49,11 @@
4749
)
4850
from pytensor.tensor.type_other import NoneConst
4951
from pytensor.tensor.utils import as_list
50-
from pytensor.tensor.variable import TensorConstant, _tensor_py_operators
52+
from pytensor.tensor.variable import (
53+
TensorConstant,
54+
TensorVariable,
55+
_tensor_py_operators,
56+
)
5157

5258

5359
if TYPE_CHECKING:
@@ -2266,57 +2272,47 @@ def _tensordot_as_dot(a, b, axes, dot, batched):
22662272
)
22672273

22682274

2269-
def tensordot(a, b, axes=2):
2275+
def tensordot(
2276+
a: "ArrayLike", b: "ArrayLike", axes: Union[int, Sequence[Sequence[int]]] = 1
2277+
) -> TensorVariable:
22702278
"""
2271-
Compute a generalized dot product over provided axes.
2279+
Compute tensor dot product along specified axes.
2280+
2281+
Implementation is mostly taken from numpy version 1.26.0
22722282
2273-
Given two tensors a and b, tensordot computes a generalized dot product over
2274-
the provided axes. PyTensor's implementation reduces all expressions to
2275-
matrix or vector dot products and is based on code from Tijmen Tieleman's
2276-
gnumpy (http://www.cs.toronto.edu/~tijmen/gnumpy.html).
2283+
Given two tensors, `a` and `b`, and a sequence object containing
2284+
two sequence objects, ``(a_axes, b_axes)``, sum the products of
2285+
`a`'s and `b`'s elements (components) over the axes specified by
2286+
``a_axes`` and ``b_axes``. The third argument can be a single non-negative
2287+
integer_like scalar, ``N``; if it is such, then the last ``N`` dimensions
2288+
of `a` and the first ``N`` dimensions of `b` are summed over.
22772289
22782290
Parameters
22792291
----------
2280-
a: symbolic tensor
2281-
The first tensor variable.
2282-
b: symbolic tensor
2283-
The second tensor variable
2284-
axes: int or array-like of length 2
2285-
If an integer, the number of axes to sum over.
2286-
If an array, it must have two array elements containing the axes
2287-
to sum over in each tensor.
2288-
2289-
Note that the default value of 2 is not guaranteed to work
2290-
for all values of a and b, and an error will be raised if
2291-
that is the case. The reason for keeping the default is to
2292-
maintain the same signature as numpy's tensordot function
2293-
(and np.tensordot raises analogous errors for non-compatible
2294-
inputs).
2295-
2296-
If an integer i, it is converted to an array containing
2297-
the last i dimensions of the first tensor and the first
2298-
i dimensions of the second tensor:
2299-
axes = [list(range(a.ndim - i, b.ndim)), list(range(i))]
2300-
2301-
If an array, its two elements must contain compatible axes
2302-
of the two tensors. For example, [[1, 2], [2, 0]] means sum
2303-
over the 2nd and 3rd axes of a and the 3rd and 1st axes of b.
2304-
(Remember axes are zero-indexed!) The 2nd axis of a and the
2305-
3rd axis of b must have the same shape; the same is true for
2306-
the 3rd axis of a and the 1st axis of b.
2292+
a, b : ArrayLike
2293+
Tensors to "dot".
2294+
2295+
axes : int or (2,) array_like
2296+
* integer_like
2297+
If an int N, sum over the last N axes of `a` and the first N axes
2298+
of `b` in order. The sizes of the corresponding axes must match.
2299+
* (2,) array_like
2300+
Or, a list of axes to be summed over, first sequence applying to `a`,
2301+
second to `b`. Both sequences must be of the same length.
23072302
23082303
Returns
23092304
-------
2310-
symbolic tensor
2311-
A tensor with shape equal to the concatenation of a's shape
2312-
(less any dimensions that were summed over) and b's shape
2313-
(less any dimensions that were summed over).
2305+
output : TensorVariable
2306+
The tensor dot product of the input.
2307+
Its shape will be equal to the concatenation of `a` and `b` shapes
2308+
(ignoring the dimensions that were summed over given in ``a_axes``
2309+
and ``b_axes``)
23142310
23152311
Examples
23162312
--------
23172313
It may be helpful to consider an example to see what tensordot does.
2318-
PyTensor's implementation is identical to NumPy's. Here a has shape (2, 3, 4)
2319-
and b has shape (5, 6, 4, 3). The axes to sum over are [[1, 2], [3, 2]] --
2314+
PyTensor's implementation is identical to NumPy's. Here ``a`` has shape (2, 3, 4)
2315+
and ``b`` has shape (5, 6, 4, 3). The axes to sum over are [[1, 2], [3, 2]] --
23202316
note that a.shape[1] == b.shape[3] and a.shape[2] == b.shape[2]; these axes
23212317
are compatible. The resulting tensor will have shape (2, 5, 6) -- the
23222318
dimensions that are not being summed:
@@ -2347,10 +2343,9 @@ def tensordot(a, b, axes=2):
23472343
true
23482344
23492345
This specific implementation avoids a loop by transposing a and b such that
2350-
the summed axes of a are last and the summed axes of b are first. The
2351-
resulting arrays are reshaped to 2 dimensions (or left as vectors, if
2352-
appropriate) and a matrix or vector dot product is taken. The result is
2353-
reshaped back to the required output dimensions.
2346+
the summed axes of ``a`` are last and the summed axes of ``b`` are first. The
2347+
resulting arrays are reshaped to 2 dimensions and a matrix dot product is taken.
2348+
The result is reshaped back to the required output dimensions.
23542349
23552350
In an extreme case, no axes may be specified. The resulting tensor
23562351
will have shape equal to the concatenation of the shapes of a and b:
@@ -2366,7 +2361,92 @@ def tensordot(a, b, axes=2):
23662361
See the documentation of numpy.tensordot for more examples.
23672362
23682363
"""
2369-
return _tensordot_as_dot(a, b, axes, dot=dot, batched=False)
2364+
try:
2365+
iter(axes)
2366+
except Exception:
2367+
axes_a = list(range(-axes, 0))
2368+
axes_b = list(range(0, axes))
2369+
else:
2370+
axes_a, axes_b = axes
2371+
try:
2372+
na = len(axes_a)
2373+
axes_a = list(axes_a)
2374+
except TypeError:
2375+
axes_a = [axes_a]
2376+
na = 1
2377+
try:
2378+
nb = len(axes_b)
2379+
axes_b = list(axes_b)
2380+
except TypeError:
2381+
axes_b = [axes_b]
2382+
nb = 1
2383+
2384+
a = as_tensor_variable(a)
2385+
b = as_tensor_variable(b)
2386+
as_ = a.shape
2387+
bra = a.broadcastable
2388+
ats = a.type.shape
2389+
nda = a.ndim
2390+
bs = b.shape
2391+
brb = b.broadcastable
2392+
bts = b.type.shape
2393+
ndb = b.ndim
2394+
if na != nb:
2395+
raise ValueError(
2396+
"The number of axes supplied for tensordot must be equal for each tensor. "
2397+
f"Got {na} and {nb} respectively."
2398+
)
2399+
for k in range(na):
2400+
ax_a = axes_a[k]
2401+
ax_b = axes_b[k]
2402+
if ax_a < 0:
2403+
axes_a[k] += nda
2404+
if axes_a[k] < 0 or axes_a[k] >= nda:
2405+
raise ValueError(
2406+
f"Supplied axes {ax_a} for first input of tensordot is out of bounds. "
2407+
f"Input tensor has only ndim={nda}."
2408+
)
2409+
if ax_b < 0:
2410+
axes_b[k] += ndb
2411+
if axes_b[k] < 0 or axes_b[k] >= ndb:
2412+
raise ValueError(
2413+
f"Supplied axes {ax_b} for second input of tensordot is out of bounds. "
2414+
f"Input tensor has only ndim={ndb}."
2415+
)
2416+
if (bra[ax_a] != brb[ax_b]) or (
2417+
ats[ax_a] is not None and bts[ax_b] is not None and ats[ax_a] != bts[ax_b]
2418+
):
2419+
raise ValueError(
2420+
"Input arrays have inconsistent broadcastable pattern or type shape along the axes "
2421+
"that must be multiplied and summed with tensordot."
2422+
)
2423+
elif ats[ax_a] is None or bts[ax_b] is None:
2424+
a = Assert(
2425+
"Input array shape along reduced axes of tensordot are not equal"
2426+
)(a, eq(a.shape[ax_a], b.shape[ax_b]))
2427+
2428+
# Move the axes to sum over to the end of "a"
2429+
# and to the front of "b"
2430+
notin = [k for k in range(nda) if k not in axes_a]
2431+
newaxes_a = notin + axes_a
2432+
N2 = 1
2433+
for axis in axes_a:
2434+
N2 *= as_[axis]
2435+
newshape_a = (cast(prod([as_[ax] for ax in notin]), "int64"), N2)
2436+
olda = [as_[axis] for axis in notin]
2437+
2438+
notin = [k for k in range(ndb) if k not in axes_b]
2439+
newaxes_b = axes_b + notin
2440+
N2 = 1
2441+
for axis in axes_b:
2442+
N2 *= bs[axis]
2443+
newshape_b = (N2, cast(prod([bs[ax] for ax in notin]), "int64"))
2444+
oldb = [bs[axis] for axis in notin]
2445+
2446+
at = a.transpose(newaxes_a).reshape(newshape_a)
2447+
bt = b.transpose(newaxes_b).reshape(newshape_b)
2448+
res = _dot(at, bt)
2449+
return res.reshape(olda + oldb)
23702450

23712451

23722452
def outer(x, y):

tests/tensor/test_math.py

Lines changed: 63 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,12 +24,14 @@
2424
from pytensor.link.c.basic import DualLinker
2525
from pytensor.misc.safe_asarray import _asarray
2626
from pytensor.printing import pprint
27+
from pytensor.raise_op import Assert
2728
from pytensor.tensor import blas, blas_c
2829
from pytensor.tensor.basic import (
2930
as_tensor_variable,
3031
constant,
3132
eye,
3233
get_underlying_scalar_constant_value,
34+
ones,
3335
switch,
3436
)
3537
from pytensor.tensor.blas import Dot22
@@ -2187,8 +2189,9 @@ def test_broadcastable1(self):
21872189
rng = np.random.default_rng(seed=utt.fetch_seed())
21882190
x = TensorType(dtype=config.floatX, shape=(1, None, None))("x")
21892191
y = tensor3("y")
2190-
z = tensordot(x, y)
2192+
z = tensordot(x, y, axes=2)
21912193
assert z.broadcastable == (True, False)
2194+
assert z.type.shape == (1, None)
21922195
f = inplace_func([x, y], z)
21932196
xv = random(1, 3, 4, rng=rng)
21942197
yv = random(3, 4, 5, rng=rng)
@@ -2202,12 +2205,71 @@ def test_broadcastable2(self):
22022205
axes = [[2, 1], [0, 1]]
22032206
z = tensordot(x, y, axes=axes)
22042207
assert z.broadcastable == (True, False)
2208+
assert z.type.shape == (1, None)
22052209
f = inplace_func([x, y], z)
22062210
xv = random(1, 3, 4, rng=rng)
22072211
yv = random(4, 3, 5, rng=rng)
22082212
zv = f(xv, yv)
22092213
assert np.allclose(np.tensordot(xv, yv, axes=axes), zv)
22102214

2215+
def test_type_shape(self):
2216+
x = ones(shape=(7, 3, 2))
2217+
y = ones(
2218+
shape=(
2219+
10,
2220+
2,
2221+
)
2222+
)
2223+
xv = x.eval()
2224+
yv = y.eval()
2225+
sy = tensor("sy", shape=(None, 2))
2226+
axes = [[-1], [-1]]
2227+
z = tensordot(x, y, axes=axes)
2228+
sz = tensordot(x, sy, axes=axes)
2229+
2230+
fg = FunctionGraph([x, y], [z])
2231+
assert not any(isinstance(n, Assert) for n in fg.toposort())
2232+
assert z.type.shape == (7, 3, 10)
2233+
assert z.broadcastable == (False, False, False)
2234+
assert np.allclose(np.tensordot(xv, yv, axes=axes), z.eval())
2235+
2236+
fg = FunctionGraph([x, sy], [sz])
2237+
assert not any(isinstance(n, Assert) for n in fg.toposort())
2238+
assert sz.type.shape == (7, 3, None)
2239+
assert z.broadcastable == (False, False, False)
2240+
assert np.allclose(np.tensordot(xv, yv, axes=axes), sz.eval({sy: yv}))
2241+
2242+
@pytest.mark.parametrize(
2243+
["axes", "has_assert", "values", "expected_fail"],
2244+
[
2245+
([[1], [2]], False, (np.ones((7, 3, 2)), np.ones((7, 2, 3))), False),
2246+
([[1, 2], [2, 1]], True, (np.ones((7, 3, 2)), np.ones((7, 2, 3))), False),
2247+
([[1, 2], [2, 1]], True, (np.ones((7, 3, 2)), np.ones((7, 5, 3))), True),
2248+
],
2249+
)
2250+
def test_shape_assert(self, axes, has_assert, values, expected_fail):
2251+
x = tensor(shape=(7, 3, None))
2252+
y = tensor(shape=(None, None, 3))
2253+
2254+
xv, yv = values
2255+
2256+
# No assert should be present
2257+
z = tensordot(x, y, axes=axes)
2258+
fg = FunctionGraph([x, y], [z])
2259+
found_asserts = any(isinstance(n.op, Assert) for n in fg.toposort())
2260+
if has_assert:
2261+
assert found_asserts
2262+
else:
2263+
assert not found_asserts
2264+
if expected_fail:
2265+
with pytest.raises(
2266+
AssertionError,
2267+
match="Input array shape along reduced axes of tensordot are not equal",
2268+
):
2269+
z.eval({x: xv, y: yv})
2270+
else:
2271+
assert np.allclose(np.tensordot(xv, yv, axes=axes), z.eval({x: xv, y: yv}))
2272+
22112273

22122274
def test_smallest():
22132275
x = dvector()

0 commit comments

Comments
 (0)