
Commit 8d4054b

Remove Matmul Operator in favor of Blockwise Dot

1 parent: bfebb84

5 files changed: +62 additions, -184 deletions

pytensor/tensor/math.py
Lines changed: 19 additions & 89 deletions

@@ -25,11 +25,11 @@
     stack,
     switch,
 )
+from pytensor.tensor.blockwise import Blockwise
 from pytensor.tensor.elemwise import CAReduce, DimShuffle, Elemwise, scalar_elemwise
 from pytensor.tensor.shape import shape, specify_broadcastable
 from pytensor.tensor.type import (
     DenseTensorType,
-    TensorType,
     complex_dtypes,
     continuous_dtypes,
     discrete_dtypes,
@@ -2868,93 +2868,7 @@ def logsumexp(x, axis=None, keepdims=False):
     return log(sum(exp(x), axis=axis, keepdims=keepdims))


-class MatMul(Op):
-    __props__ = ("dtype",)
-
-    def __init__(self, dtype=None):
-        self.dtype = dtype
-
-    @classmethod
-    def _get_output_shape(cls, x1, x2, shapes, validate=False):
-        x1_shape, x2_shape = shapes
-
-        if x1.ndim == 1 and x2.ndim == 1:
-            if validate and x1_shape[0] != x2_shape[0]:
-                raise ValueError("1d inputs must have the same length.")
-            return ()
-        elif x1.ndim == 1 and x2.ndim > 1:
-            if validate and x1_shape[0] != x2_shape[-2]:
-                raise ValueError(
-                    "length of input 1 must be equal the length "
-                    "of the 2nd-last dimension of input 2"
-                )
-            return x2_shape[:-2] + x2_shape[-1:]
-        elif x1.ndim > 1 and x2.ndim == 1:
-            if validate and x1_shape[-1] != x2_shape[0]:
-                raise ValueError(
-                    "length of input 2 must be equal the length "
-                    "of the last dimension of input 1"
-                )
-            return x1_shape[:-1]
-        elif x1.ndim == 2 and x2.ndim == 2:
-            if validate and x1_shape[-1] != x2_shape[0]:
-                raise ValueError(
-                    "number of columns of input 1 must be equal to "
-                    "the number of rows of input 2"
-                )
-            return x1_shape[:-1] + x2_shape[-1:]
-        elif x1.ndim > 2 and x2.ndim == 2:
-            if validate and x1_shape[-1] != x2_shape[0]:
-                raise ValueError(
-                    "number of rows of input 2 must be equal to "
-                    "the length of the last dimension of input 1"
-                )
-            return x1_shape[:-2] + x1_shape[-2:-1] + x2_shape[-1:]
-        elif x1.ndim == 2 and x2.ndim > 2:
-            if validate and x1_shape[-1] != x2_shape[-2]:
-                raise ValueError(
-                    "number of columns of input 1 must be equal "
-                    "the length of the 2nd-last dimension of input 2"
-                )
-            return x2_shape[:-2] + x1_shape[-2:-1] + x2_shape[-1:]
-        else:
-            if validate:
-                from pytensor.tensor.random.basic import broadcast_shapes
-
-                bshape = broadcast_shapes(x1_shape[:-2], x2_shape[:-2])
-                if x1_shape[-1] != x2_shape[-2]:
-                    raise ValueError(
-                        "length of the last dimension of input 1 must be equal "
-                        "to the length of the 2nd-last dimension of input 2"
-                    )
-            else:
-                from pytensor.tensor.extra_ops import broadcast_shape
-
-                bshape = broadcast_shape(
-                    x1_shape[:-2], x2_shape[:-2], arrays_are_shapes=True
-                )
-            return bshape + x1_shape[-2:-1] + x2_shape[-1:]
-
-    def make_node(self, a, b):
-        a = as_tensor_variable(a)
-        b = as_tensor_variable(b)
-
-        if 0 in {a.ndim, b.ndim}:
-            raise ValueError("inputs to `matmul` cannot be scalar.")
-
-        out_shape = self._get_output_shape(
-            a, b, (a.type.shape, b.type.shape), validate=True
-        )
-        out = TensorType(dtype=self.dtype, shape=out_shape)()
-        return Apply(self, [a, b], [out])
-
-    def perform(self, node, inputs, outputs):
-        x1, x2 = inputs
-        outputs[0][0] = np.matmul(x1, x2, dtype=self.dtype)
-
-    def infer_shape(self, fgraph, node, shapes):
-        x1, x2 = node.inputs
-        return [self._get_output_shape(x1, x2, shapes)]
+_matrix_matrix_matmul = Blockwise(_dot, signature="(n,k),(k,m)->(n,m)")


 def matmul(x1: "ArrayLike", x2: "ArrayLike", dtype: Optional["DTypeLike"] = None):
@@ -2999,7 +2913,23 @@ def matmul(x1: "ArrayLike", x2: "ArrayLike", dtype: Optional["DTypeLike"] = None
     - Stacks of matrices are broadcast together as if the matrices were elements,
       respecting the signature ``(n, k), (k, m) -> (n, m)``:
     """
-    return MatMul(dtype=dtype)(x1, x2)
+    x1 = as_tensor_variable(x1)
+    x2 = as_tensor_variable(x2)
+    if x1.type.ndim == 0 or x2.type.ndim == 0:
+        raise ValueError("matmul operand cannot be scalar")
+    if x1.type.ndim == 1 and x2.type.ndim == 1:
+        out = _dot(x1, x2)
+    elif x1.type.ndim == 1:
+        out = _matrix_matrix_matmul(x1[None], x2).squeeze(-2)
+    elif x2.type.ndim == 1:
+        out = _matrix_matrix_matmul(x1, x2[:, None]).squeeze(-1)
+    else:
+        out = _matrix_matrix_matmul(x1, x2)
+
+    if dtype is not None:
+        out = out.astype(dtype)
+
+    return out


 __all__ = [
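The replacement reduces every `matmul` case to the single core signature `(n,k),(k,m)->(n,m)` of `Blockwise(_dot)`: 1d operands get a dummy axis that is squeezed away again, and any extra leading dimensions broadcast as batch dimensions, as in `np.matmul`. A minimal sketch of the resulting behaviour (an illustrative check, not part of the commit):

    import numpy as np
    from pytensor.tensor.math import matmul
    from pytensor.tensor.type import dmatrix, dtensor3

    x = dtensor3("x")  # runtime shape (batch, n, k)
    y = dmatrix("y")   # runtime shape (k, m)

    # Blockwise maps the core (n, k), (k, m) -> (n, m) dot over
    # the leading batch dimension of x.
    z = matmul(x, y)

    rng = np.random.default_rng(0)
    x_val = rng.normal(size=(5, 2, 3))
    y_val = rng.normal(size=(3, 4))
    out = z.eval({x: x_val, y: y_val})
    assert out.shape == (5, 2, 4)
    np.testing.assert_allclose(out, x_val @ y_val)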

pytensor/tensor/rewriting/blockwise.py
Lines changed: 9 additions & 0 deletions

@@ -3,6 +3,8 @@
 from pytensor.graph.replace import vectorize_node
 from pytensor.graph.rewriting.basic import copy_stack_trace, out2in
 from pytensor.tensor.blockwise import Blockwise
+from pytensor.tensor.math import _matrix_matrix_matmul
+from pytensor.tensor.rewriting.basic import register_canonicalize


 @node_rewriter([Blockwise])
@@ -40,3 +42,10 @@ def local_useless_unbatched_blockwise(fgraph, node):
     "blockwise",
     position=49,
 )
+
+
+# Avoid redundant cases early on for Ops whose default form is not Blockwised
+@register_canonicalize
+@node_rewriter(tracks=[_matrix_matrix_matmul])
+def local_eager_useless_unbatched_blockwise(fgraph, node):
+    return local_useless_unbatched_blockwise.fn(fgraph, node)
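Since `matmul` now always builds the Blockwise form, this eager canonicalization unwraps `Blockwise(_dot)` whenever neither operand actually has batch dimensions, letting an unbatched product reach the usual `Dot`/BLAS rewrites. A rough sketch of the intended effect, assuming the default rewrite pipeline (illustrative, not from the commit):

    import pytensor
    from pytensor.tensor.blockwise import Blockwise
    from pytensor.tensor.math import matmul
    from pytensor.tensor.type import dmatrix

    x = dmatrix("x")
    y = dmatrix("y")

    # With purely 2d inputs the Blockwise wrapper is useless and is
    # expected to be rewritten away during canonicalization.
    fn = pytensor.function([x, y], matmul(x, y))
    assert not any(
        isinstance(node.op, Blockwise) for node in fn.maker.fgraph.apply_nodes
    )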

pytensor/tensor/variable.py
Lines changed: 8 additions & 4 deletions

@@ -647,8 +647,12 @@ def __rdot__(right, left):
         return at.math.dense_dot(left, right)

     dot = __dot__
-    __matmul__ = __dot__
-    __rmatmul__ = __rdot__
+
+    def __matmul__(left, right):
+        return at.math.matmul(left, right)
+
+    def __rmatmul__(right, left):
+        return at.math.matmul(right, left)

     def sum(self, axis=None, dtype=None, keepdims=False, acc_dtype=None):
         """See :func:`pytensor.tensor.math.sum`."""
@@ -797,15 +801,15 @@ def choose(self, choices, mode="raise"):
         """
         return at.basic.choose(self, choices, mode="raise")

-    def squeeze(self):
+    def squeeze(self, axis=None):
         """
         Remove broadcastable dimensions from the shape of an array.

         It returns the input array, but with the broadcastable dimensions
         removed. This is always `x` itself or a view into `x`.

         """
-        return at.extra_ops.squeeze(self)
+        return at.extra_ops.squeeze(self, axis=axis)

     def compress(self, a, axis=None):
         """Return selected slices only."""

tests/tensor/test_math.py
Lines changed: 7 additions & 86 deletions

@@ -30,11 +30,11 @@
     get_underlying_scalar_constant_value,
     switch,
 )
+from pytensor.tensor.blas import Dot22
 from pytensor.tensor.elemwise import CAReduce, Elemwise
 from pytensor.tensor.math import (
     Argmax,
     Dot,
-    MatMul,
     MaxAndArgmax,
     Mean,
     Prod,
@@ -3412,12 +3412,10 @@ def test_log1mexp_grad_lim():
     assert grad_x_fn(-1e-308) != -np.inf


-class TestMatMul(utt.InferShapeTester):
+class TestMatMul:
     def setup_method(self):
-        super().setup_method()
         self.rng = np.random.default_rng(utt.fetch_seed())
         self.op = matmul
-        self.op_class = MatMul

     def _validate_output(self, a, b):
         pytensor_sol = self.op(a, b).eval()
@@ -3467,85 +3465,8 @@ def test_dtype_param(self, dtype):
         sol = self.op([1, 2, 3], [3, 2, 1], dtype=dtype)
         assert sol.eval().dtype == dtype

-    @pytest.mark.parametrize(
-        "x1_shape,x2_shape,exp_res,error_regex",
-        [
-            ((1,), (3,), None, "inputs must have the same length"),
-            ((2,), (3, 1), None, "length of input 1.*2nd-last dimension of input 2"),
-            ((2, 5), (3,), None, "length of input 2.*of the last dimension of input 1"),
-            (
-                (2, 5),
-                (3, 4),
-                None,
-                "number of columns of input 1 .* number of rows of input 2",
-            ),
-            (
-                (2, 1, 3),
-                (5, 4),
-                None,
-                "number of rows of input 2 .* last dimension of input 1",
-            ),
-            (
-                (2, 5),
-                (2, 4, 3),
-                None,
-                "number of columns of input 1 .* 2nd-last dimension of input 2",
-            ),
-            (
-                (3, 2, 4, 5),
-                (1, 6, 7),
-                None,
-                "length of the last dimension of input 1 .* 2nd-last dimension of input 2",
-            ),
-            (
-                (4, 5, 4),
-                (3, 2, 2),
-                None,
-                "cannot be broadcast to a single shape",
-            ),
-            (
-                (4, None, 2),
-                (4, 2, None),
-                (4, None, None),
-                None,
-            ),
-        ],
-    )
-    def test_get_output_shape(self, x1_shape, x2_shape, exp_res, error_regex):
-        x1 = tensor(dtype=np.float64, shape=x1_shape)
-        x2 = tensor(dtype=np.float64, shape=x2_shape)
-
-        if error_regex is not None:
-            with pytest.raises(ValueError, match=error_regex):
-                self.op_class._get_output_shape(
-                    x1, x2, (x1_shape, x2_shape), validate=True
-                )
-        else:
-            assert (
-                self.op_class._get_output_shape(
-                    x1, x2, (x1_shape, x2_shape), validate=True
-                )
-                == exp_res
-            )
-
-    def test_infer_shape(self):
-        for shape_x1, shape_x2 in [
-            ((5,), (5,)),
-            ((5,), (2, 5, 3)),
-            ((2, 5, 3), (3,)),
-            ((2, 5), (5, 4)),
-            ((2, 5), (2, 5, 3)),
-            ((2, 1, 3), (3, 4)),
-            ((3, 2, 4, 5), (1, 5, 7)),
-        ]:
-            a = tensor(dtype=config.floatX, shape=shape_x1)
-            b = tensor(dtype=config.floatX, shape=shape_x2)
-            x1 = self.rng.random(shape_x1).astype(config.floatX)
-            x2 = self.rng.random(shape_x2).astype(config.floatX)
-
-            self._compile_and_check(
-                [a, b],
-                [self.op(a, b)],
-                [x1, x2],
-                self.op_class,
-            )
+    def test_dot22_opt(self):
+        x, y = matrices("xy")
+        fn = function([x, y], x @ y, mode="FAST_RUN")
+        [node] = fn.maker.fgraph.apply_nodes
+        assert isinstance(node.op, Dot22)

tests/tensor/test_variable.py
Lines changed: 19 additions & 5 deletions

@@ -10,9 +10,9 @@
 from pytensor.compile.mode import get_default_mode
 from pytensor.graph.basic import Constant, equal_computations
 from pytensor.tensor import get_vector_length
-from pytensor.tensor.basic import constant
+from pytensor.tensor.basic import as_tensor, constant
 from pytensor.tensor.elemwise import DimShuffle
-from pytensor.tensor.math import dot, eq
+from pytensor.tensor.math import dot, eq, matmul
 from pytensor.tensor.shape import Shape
 from pytensor.tensor.subtensor import AdvancedSubtensor, Subtensor
 from pytensor.tensor.type import (
@@ -79,16 +79,30 @@ def test_infix_dot_method():
     X = dmatrix("X")
     y = dvector("y")

-    res = X @ y
-    exp_res = X.dot(y)
+    res = X.dot(y)
+    exp_res = dot(X, y)
     assert equal_computations([res], [exp_res])

     X_val = np.arange(2 * 3).reshape((2, 3))
-    res = X_val @ y
+    res = as_tensor(X_val).dot(y)
     exp_res = dot(X_val, y)
     assert equal_computations([res], [exp_res])


+def test_infix_matmul_method():
+    X = dmatrix("X")
+    y = dvector("y")
+
+    res = X @ y
+    exp_res = matmul(X, y)
+    assert equal_computations([res], [exp_res])
+
+    X_val = np.arange(2 * 3).reshape((2, 3))
+    res = as_tensor(X_val) @ y
+    exp_res = matmul(X_val, y)
+    assert equal_computations([res], [exp_res])
+
+
 def test_empty_list_indexing():
     ynp = np.zeros((2, 2))[:, []]
     znp = np.zeros((2, 2))[:, ()]
