1
1
import builtins
2
2
import warnings
3
- from typing import TYPE_CHECKING , Optional
3
+ from collections .abc import Sequence
4
+ from typing import TYPE_CHECKING , Optional , Union
4
5
5
6
import numpy as np
6
7
15
16
from pytensor .link .c .type import Generic
16
17
from pytensor .misc .safe_asarray import _asarray
17
18
from pytensor .printing import pprint
19
+ from pytensor .raise_op import Assert
18
20
from pytensor .scalar .basic import BinaryScalarOp
19
21
from pytensor .tensor .basic import (
20
22
alloc ,
47
49
)
48
50
from pytensor .tensor .type_other import NoneConst
49
51
from pytensor .tensor .utils import as_list
50
- from pytensor .tensor .variable import TensorConstant , _tensor_py_operators
52
+ from pytensor .tensor .variable import (
53
+ TensorConstant ,
54
+ TensorVariable ,
55
+ _tensor_py_operators ,
56
+ )
51
57
52
58
53
59
if TYPE_CHECKING :
@@ -2266,57 +2272,47 @@ def _tensordot_as_dot(a, b, axes, dot, batched):
2266
2272
)
2267
2273
2268
2274
2269
def tensordot(
    a: "ArrayLike", b: "ArrayLike", axes: Union[int, Sequence[Sequence[int]]] = 1
) -> TensorVariable:
    """
    Compute tensor dot product along specified axes.

    Implementation is mostly taken from numpy version 1.26.0.

    Given two tensors, `a` and `b`, and a sequence object containing
    two sequence objects, ``(a_axes, b_axes)``, sum the products of
    `a`'s and `b`'s elements (components) over the axes specified by
    ``a_axes`` and ``b_axes``. The third argument can be a single non-negative
    integer_like scalar, ``N``; if it is such, then the last ``N`` dimensions
    of `a` and the first ``N`` dimensions of `b` are summed over.

    Parameters
    ----------
    a, b : ArrayLike
        Tensors to "dot".

    axes : int or (2,) array_like
        * integer_like
          If an int N, sum over the last N axes of `a` and the first N axes
          of `b` in order. The sizes of the corresponding axes must match.
        * (2,) array_like
          Or, a list of axes to be summed over, first sequence applying to `a`,
          second to `b`. Both elements array_like must be of the same length.

    Returns
    -------
    output : TensorVariable
        The tensor dot product of the input.
        Its shape will be equal to the concatenation of `a` and `b` shapes
        (ignoring the dimensions that were summed over given in ``a_axes``
        and ``b_axes``).

    Raises
    ------
    ValueError
        If the two axis sequences have different lengths, an axis is out of
        bounds for its tensor, or the paired axes have statically
        incompatible broadcastable patterns / shapes.

    Examples
    --------
    It may be helpful to consider an example to see what tensordot does.
    PyTensor's implementation is identical to NumPy's. Here ``a`` has shape
    (2, 3, 4) and ``b`` has shape (5, 6, 4, 3). The axes to sum over are
    [[1, 2], [3, 2]] -- note that a.shape[1] == b.shape[3] and
    a.shape[2] == b.shape[2]; these axes are compatible. The resulting tensor
    will have shape (2, 5, 6) -- the dimensions that are not being summed.

    This specific implementation avoids a loop by transposing `a` and `b` such
    that the summed axes of ``a`` are last and the summed axes of ``b`` are
    first. The resulting arrays are reshaped to 2 dimensions and a matrix dot
    product is taken. The result is reshaped back to the required output
    dimensions.

    In an extreme case, no axes may be specified (``axes=0``). The resulting
    tensor will have shape equal to the concatenation of the shapes of `a`
    and `b`.

    See the documentation of numpy.tensordot for more examples.

    """
    # Normalize `axes` into two explicit per-tensor axis lists.
    try:
        iter(axes)
    except TypeError:
        # Scalar N: contract the last N axes of `a` with the first N of `b`.
        axes_a = list(range(-axes, 0))
        axes_b = list(range(0, axes))
    else:
        axes_a, axes_b = axes
    # Each element may itself be a scalar axis rather than a sequence.
    try:
        na = len(axes_a)
        axes_a = list(axes_a)
    except TypeError:
        axes_a = [axes_a]
        na = 1
    try:
        nb = len(axes_b)
        axes_b = list(axes_b)
    except TypeError:
        axes_b = [axes_b]
        nb = 1

    a = as_tensor_variable(a)
    b = as_tensor_variable(b)
    as_ = a.shape
    bra = a.broadcastable
    ats = a.type.shape
    nda = a.ndim
    bs = b.shape
    brb = b.broadcastable
    bts = b.type.shape
    ndb = b.ndim
    if na != nb:
        raise ValueError(
            "The number of axes supplied for tensordot must be equal for each tensor. "
            f"Got {na} and {nb} respectively."
        )
    for k in range(na):
        ax_a = axes_a[k]
        ax_b = axes_b[k]
        # Normalize negative axes and bounds-check against each tensor's ndim.
        if ax_a < 0:
            axes_a[k] += nda
        if axes_a[k] < 0 or axes_a[k] >= nda:
            raise ValueError(
                f"Supplied axes {ax_a} for first input of tensordot is out of bounds. "
                f"Input tensor has only ndim={nda}."
            )
        if ax_b < 0:
            axes_b[k] += ndb
        if axes_b[k] < 0 or axes_b[k] >= ndb:
            raise ValueError(
                f"Supplied axes {ax_b} for second input of tensordot is out of bounds. "
                f"Input tensor has only ndim={ndb}."
            )
        # Statically incompatible paired axes are an immediate error.
        if (bra[ax_a] != brb[ax_b]) or (
            ats[ax_a] is not None and bts[ax_b] is not None and ats[ax_a] != bts[ax_b]
        ):
            raise ValueError(
                "Input arrays have inconsistent broadcastable pattern or type shape along the axes "
                "must be multiplied and summed with tensordot."
            )
        elif ats[ax_a] is None or bts[ax_b] is None:
            # Static shapes cannot prove equality here; defer to a runtime check.
            a = Assert(
                "Input array shape along reduced axes of tensordot are not equal"
            )(a, eq(a.shape[ax_a], b.shape[ax_b]))

    # Move the axes to sum over to the end of "a"
    # and to the front of "b"
    notin = [k for k in range(nda) if k not in axes_a]
    newaxes_a = notin + axes_a
    N2 = 1
    for axis in axes_a:
        N2 *= as_[axis]
    newshape_a = (cast(prod([as_[ax] for ax in notin]), "int64"), N2)
    olda = [as_[axis] for axis in notin]

    notin = [k for k in range(ndb) if k not in axes_b]
    newaxes_b = axes_b + notin
    N2 = 1
    for axis in axes_b:
        N2 *= bs[axis]
    newshape_b = (N2, cast(prod([bs[ax] for ax in notin]), "int64"))
    oldb = [bs[axis] for axis in notin]

    # Reduce to a single 2-D matrix product, then restore the kept dims.
    at = a.transpose(newaxes_a).reshape(newshape_a)
    bt = b.transpose(newaxes_b).reshape(newshape_b)
    res = _dot(at, bt)
    return res.reshape(olda + oldb)
2370
2450
2371
2451
2372
2452
def outer (x , y ):
0 commit comments