
Commit efd9f49

Deprecate BLAS batch helper functions
1 parent b3da2a4

5 files changed: +51 -199 lines

pytensor/tensor/blas.py

Lines changed: 36 additions & 25 deletions
@@ -79,10 +79,14 @@
 import logging
 import os
 import shlex
+import warnings
 from pathlib import Path
 
 import numpy as np
 
+from pytensor.graph import vectorize_graph
+from pytensor.npy_2_compat import normalize_axis_tuple
+
 
 try:
     import numpy.__config__
@@ -100,9 +104,9 @@
 from pytensor.printing import FunctionPrinter, pprint
 from pytensor.scalar import bool as bool_t
 from pytensor.tensor import basic as ptb
-from pytensor.tensor.basic import expand_dims
 from pytensor.tensor.blas_headers import blas_header_text, blas_header_version
-from pytensor.tensor.shape import shape_padright, specify_broadcastable
+from pytensor.tensor.math import dot, tensordot
+from pytensor.tensor.shape import specify_broadcastable
 from pytensor.tensor.type import DenseTensorType, tensor
 
 
@@ -1604,8 +1608,8 @@ def grad(self, inp, grads):
         x, y = inp
         (gz,) = grads
 
-        xgrad = batched_dot(gz, y.dimshuffle(0, 2, 1))
-        ygrad = batched_dot(x.dimshuffle(0, 2, 1), gz)
+        xgrad = _batched_dot(gz, y.dimshuffle(0, 2, 1))
+        ygrad = _batched_dot(x.dimshuffle(0, 2, 1), gz)
 
         # If x or y contain broadcastable dimensions but only one of
         # them know that a matching dimensions is broadcastable, the
@@ -1729,31 +1733,22 @@ def batched_dot(a, b):
     dot products in terms of batched matrix-matrix dot products, so
     it may be possible to further optimize for performance.
     """
+    warnings.warn(
+        "batched_dot is deprecated. "
+        "Use `dot` in conjunction with `tensor.vectorize` or `graph.replace.vectorize_graph`",
+        FutureWarning,
+    )
     a, b = ptb.as_tensor_variable(a), ptb.as_tensor_variable(b)
 
     if a.ndim == 0:
         raise TypeError("a must have at least one (batch) axis")
     elif b.ndim == 0:
         raise TypeError("b must have at least one (batch) axis")
-    elif a.ndim == 1:
-        return shape_padright(a, (b.ndim - 1)) * b
-    elif b.ndim == 1:
-        return a * shape_padright(b, (a.ndim - 1))
-    elif a.ndim > 3 or b.ndim > 3:
-        return batched_tensordot(a, b, [[a.ndim - 1], [np.maximum(1, b.ndim - 2)]])
-    else:
-        # If either a or b is a batched vector, expand dims and later squeeze them
-        expanded_axis = []
-        if a.ndim == 2:
-            a = expand_dims(a, axis=1)
-            expanded_axis.append(1)
-        if b.ndim == 2:
-            b = expand_dims(b, axis=2)
-            expanded_axis.append(2)
-        out = _batched_dot(a, b)
-        if expanded_axis:
-            out = out.squeeze(axis=expanded_axis)
-        return out
+
+    core_a = a[0].type()
+    core_b = b[0].type()
+    core_dot = dot(core_a, core_b)
+    return vectorize_graph(core_dot, replace={core_a: a, core_b: b})
 
 
 def batched_tensordot(x, y, axes=2):
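
The replacement body above builds the graph for a single core matrix product and then lets vectorize_graph lift it over the leading batch axis. A minimal sketch of that pattern outside the function (variable names and shapes here are illustrative, not part of the commit):

    import pytensor.tensor as pt
    from pytensor.graph import vectorize_graph

    a = pt.tensor3("a")  # (batch, m, k)
    b = pt.tensor3("b")  # (batch, k, n)

    # Dummy "core" inputs: one batch element each, with the batch axis dropped
    core_a = a[0].type()
    core_b = b[0].type()
    core_dot = pt.dot(core_a, core_b)

    # Vectorize the core graph over the batch axis of a and b -> (batch, m, n)
    batched_out = vectorize_graph(core_dot, replace={core_a: a, core_b: b})
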
@@ -1791,6 +1786,22 @@ def batched_tensordot(x, y, axes=2):
     reshapes to reduce the tensor dot product to a matrix or vector
     dot product. Finally, it calls batched_dot to compute the result.
     """
-    from pytensor.tensor.math import _tensordot_as_dot
+    warnings.warn(
+        "batched_tensordot is deprecated. "
+        "Use `tensordot` in conjunction with `tensor.vectorize` or `graph.replace.vectorize_graph`",
+        FutureWarning,
+    )
+
+    if isinstance(axes, int):
+        core_axes = axes
+    else:
+        # Convert batched axes to core axes
+        core_axes_a = [a - 1 for a in normalize_axis_tuple(axes[0], x.type.ndim)]
+        core_axes = [a - 1 for a in normalize_axis_tuple(axes[1], y.type.ndim)]
+        core_axes = [core_axes_a, core_axes]
+
+    core_x = x[0].type()
+    core_y = y[0].type()
+    core_tensordot = tensordot(core_x, core_y, axes=core_axes)
 
-    return _tensordot_as_dot(x, y, axes, dot=batched_dot, batched=True)
+    return vectorize_graph(core_tensordot, replace={core_x: x, core_y: y})
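
For callers of the deprecated helpers, the warning messages point at the same pattern. A hedged migration sketch for a call like batched_tensordot(x, y, axes=[[2], [1]]) (the names, shapes, and axes here are assumed for illustration):

    import pytensor.tensor as pt
    from pytensor.graph import vectorize_graph

    x = pt.tensor3("x")  # (batch, i, j)
    y = pt.tensor3("y")  # (batch, j, k)

    core_x = x[0].type()
    core_y = y[0].type()
    # Batched axes [[2], [1]] become core axes [[1], [0]] once the batch axis is
    # dropped, mirroring the "a - 1" shift done in the deprecation shim above.
    core_out = pt.tensordot(core_x, core_y, axes=[[1], [0]])

    out = vectorize_graph(core_out, replace={core_x: x, core_y: y})  # (batch, i, k)
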

pytensor/tensor/math.py

Lines changed: 1 addition & 128 deletions
@@ -50,7 +50,7 @@
     tensor,
     uint_dtypes,
 )
-from pytensor.tensor.utils import as_list, normalize_reduce_axis
+from pytensor.tensor.utils import normalize_reduce_axis
 from pytensor.tensor.variable import (
     TensorVariable,
     _tensor_py_operators,
@@ -3208,133 +3208,6 @@ def dense_dot(a, b):
     return _dot(a, b)
 
 
-def _tensordot_as_dot(a, b, axes, dot, batched):
-    """
-    Reduces a tensor dot product to a matrix or vector dot product. Based
-    on code from Tijmen Tieleman's gnumpy
-    (http://www.cs.toronto.edu/~tijmen/gnumpy.html).
-
-    Please see the documentation of tensordot for the meaning of the a, b
-    and axes arguments.
-
-    :param dot: a function that accepts two symbolic variables and computes
-                the appropriate dot product (e.g. dot, batched_dot)
-    :type dot: function
-
-    :param batched: whether to treat the first axis of a and b as a batch
-                    axis. If so, this axis will be preserved in the output,
-                    allowing this function to be used also for batched
-                    tensor dot products.
-    :type batched: boolean
-
-    :returns: a tensor with shape equal to the concatenation of a's shape
-              (less any dimensions that were summed over) and b's shape
-              (less the first dimension and any dimensions that were summed
-              over).
-    :rtype: symbolic tensor
-    """
-    a, b = as_tensor_variable(a), as_tensor_variable(b)
-
-    if not np.isscalar(axes) and len(axes) != 2:
-        raise ValueError(
-            "Axes should be an integer or a "
-            f"list/tuple of len 2 ({axes} was provided)"
-        )
-
-    # if 'axes' is a number of axes to multiply and sum over (trailing axes
-    # of a, leading axes of b), we can just reshape and use dot.
-    elif np.isscalar(axes):
-        axes = int(axes)
-
-        for operand_name, operand in (("a", a), ("b", b)):
-            if axes > operand.ndim:
-                raise ValueError(
-                    f"axes can not be larger than the dimension of {operand_name} "
-                    f"({operand_name}.ndim={operand.ndim}, axes={axes})"
-                )
-            if batched and axes == operand.ndim:
-                raise ValueError(
-                    "axes to sum over must not include the batch axis "
-                    f"of {operand_name} ({operand_name}.ndim={operand.ndim}, axes={axes})"
-                )
-
-        batch_axes = 1 if batched else 0
-        a_outaxes = slice(0, a.ndim - axes)
-        b_outaxes = slice(batch_axes + axes, b.ndim)
-        outshape = concatenate([a.shape[a_outaxes], b.shape[b_outaxes]])
-        outbcast = a.broadcastable[a_outaxes] + b.broadcastable[b_outaxes]
-        outndim = len(outbcast)
-
-        a_shape = [1] * 2
-        b_shape = [1] * 2
-
-        # compute total size of summed axes
-        for i in range(0, axes):
-            a_shape[1] *= a.shape[-(i + 1)]
-            b_shape[0] *= b.shape[batch_axes + i]
-        # compute total size of other axes
-        for i in range(0, a.ndim - axes - batch_axes):
-            a_shape[0] *= a.shape[batch_axes + i]
-        for i in range(0, b.ndim - axes - batch_axes):
-            b_shape[1] *= b.shape[-(i + 1)]
-
-        if batched:
-            a_shape.insert(0, a.shape[0])
-            b_shape.insert(0, b.shape[0])
-
-        a_reshaped = a.reshape(a_shape)
-        b_reshaped = b.reshape(b_shape)
-
-        out_reshaped = dot(a_reshaped, b_reshaped)
-        out = out_reshaped.reshape(outshape, ndim=outndim)
-        # Make sure the broadcastable pattern of the result is correct,
-        # since some shape information can be lost in the reshapes.
-        if out.type.broadcastable != outbcast:
-            out = specify_broadcastable(
-                out, *(ax for (ax, b) in enumerate(outbcast) if b)
-            )
-        return out
-
-    # if 'axes' is a list, transpose a and b such that the summed axes of a
-    # are last and the summed axes of b are first.
-    else:
-        axes = [as_list(axes_) for axes_ in axes]
-
-        if len(axes[0]) != len(axes[1]):
-            raise ValueError("Axes elements must have the same length.")
-
-        for i, (operand_name, operand) in enumerate((("a", a), ("b", b))):
-            if len(axes[i]) > operand.ndim:
-                raise ValueError(
-                    f"axes[{i}] should be array_like with length less than "
-                    f"the dimensions of {operand_name} ({operand_name}.ndim={operand.ndim}, len(axes[0])={len(axes[i])})."
-                )
-            if len(axes[i]) > 0 and np.max(axes[i]) >= operand.ndim:
-                raise ValueError(
-                    f"axes[{i}] contains dimensions greater than or equal "
-                    f"to {operand_name}.ndim ({operand_name}.ndim={operand.ndim}, max(axes[0])={np.max(np.array(axes[i]))})."
-                )
-            if batched and 0 in axes[i]:
-                raise ValueError(
-                    "axes to sum over must not contain the batch axis "
-                    f"(axes[{i}]={axes[i]})"
-                )
-
-        batch_axes = [0] if batched else []
-        other_axes = [
-            [x for x in range(operand.ndim) if x not in axes[i] and x not in batch_axes]
-            for i, operand in enumerate((a, b))
-        ]
-
-        a_shuffled = a.dimshuffle(batch_axes + other_axes[0] + axes[0])
-        b_shuffled = b.dimshuffle(batch_axes + axes[1] + other_axes[1])
-
-        # now that a and b are in the right order, recur with integer axes
-        return _tensordot_as_dot(
-            a_shuffled, b_shuffled, len(axes[0]), dot=dot, batched=batched
-        )
-
-
 def tensordot(
     a: TensorLike, b: TensorLike, axes: int | Sequence[Sequence[int]] = 2
 ) -> TensorVariable:
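
The removed _tensordot_as_dot helper implemented tensordot by flattening the summed axes and calling a plain matrix dot. An illustrative NumPy check of that equivalence for an integer axes argument (the arrays here are made up, not part of the commit):

    import numpy as np

    a = np.random.rand(2, 3, 4)
    b = np.random.rand(4, 5)

    ref = np.tensordot(a, b, axes=1)  # shape (2, 3, 5)
    # Flatten the non-summed axes of `a`, do a 2-D dot, then restore the shape
    via_dot = a.reshape(2 * 3, 4).dot(b).reshape(2, 3, 5)

    assert np.allclose(ref, via_dot)
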

pytensor/tensor/rewriting/blas.py

Lines changed: 2 additions & 2 deletions
@@ -84,9 +84,9 @@
 from pytensor.tensor import basic as ptb
 from pytensor.tensor.blas import (
     Dot22,
+    _batched_dot,
     _dot22,
     _dot22scalar,
-    batched_dot,
     gemm_inplace,
     gemm_no_inplace,
     gemv_inplace,
@@ -928,7 +928,7 @@ def specialize_matmul_to_batched_dot(fgraph, node):
         x = x.reshape((-1, x_shape[-2], x_shape[-1]))
         y = y.reshape((-1, y_shape[-2], y_shape[-1]))
 
-    new_out = batched_dot(x, y)
+    new_out = _batched_dot(x, y)
 
     if len(x_shape) > 3:
         # And then unravel it

pytensor/tensor/utils.py

Lines changed: 0 additions & 8 deletions
@@ -107,14 +107,6 @@ def shape_of_variables(
     return l
 
 
-def as_list(x):
-    """Convert x to a list if it is an iterable; otherwise, wrap it in a list."""
-    try:
-        return list(x)
-    except TypeError:
-        return [x]
-
-
 def import_func_from_string(func_string: str):  # -> Optional[Callable]:
     func = getattr(np, func_string, None)
     if func is not None:
