Skip to content

Commit d469a61

Browse files
committed
Specialize matmul to batched dot
1 parent c4ff171 commit d469a61

File tree

3 files changed

+88
-9
lines changed

3 files changed

+88
-9
lines changed

pytensor/tensor/rewriting/blas.py

Lines changed: 33 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,8 @@
5959

6060
import numpy as np
6161

62+
from pytensor.tensor.rewriting.basic import register_specialize
63+
6264

6365
try:
6466
import numpy.__config__ # noqa
@@ -79,12 +81,12 @@
7981
)
8082
from pytensor.graph.rewriting.db import SequenceDB
8183
from pytensor.graph.utils import InconsistencyError
82-
from pytensor.printing import debugprint
8384
from pytensor.tensor import basic as at
8485
from pytensor.tensor.blas import (
8586
Dot22,
8687
_dot22,
8788
_dot22scalar,
89+
batched_dot,
8890
gemm_inplace,
8991
gemm_no_inplace,
9092
gemv_inplace,
@@ -94,7 +96,7 @@
9496
)
9597
from pytensor.tensor.elemwise import DimShuffle, Elemwise
9698
from pytensor.tensor.exceptions import NotScalarConstantError
97-
from pytensor.tensor.math import Dot, add, mul, neg, sub
99+
from pytensor.tensor.math import Dot, _matrix_matrix_matmul, add, mul, neg, sub
98100
from pytensor.tensor.rewriting.elemwise import local_dimshuffle_lift
99101
from pytensor.tensor.type import (
100102
DenseTensorType,
@@ -899,9 +901,32 @@ def local_dot22_to_dot22scalar(fgraph, node):
899901
)
900902

901903

902-
# from opt import register_specialize, register_canonicalize
903-
# @register_specialize
904-
@node_rewriter([sub, add])
905-
def local_print_as_we_go_along(fgraph, node):
906-
if node.op in (sub, add):
907-
debugprint(node)
904+
@register_specialize
905+
@node_rewriter([_matrix_matrix_matmul])
906+
def specialize_matmul_to_batched_dot(fgraph, node):
907+
"""Rewrite Matmul (Blockwise matrix-matrix) without implicit broadcasted batched dimension as BatchedDot.
908+
909+
TODO: Do the same for Blockwise BatchedDot
910+
"""
911+
x, y = node.inputs
912+
913+
# BatchedDot does not allow implicit broadcasting of the batch dimensions
914+
# We do not want to explicitly broadcast as it may result in huge arrays
915+
if x.type.broadcastable[:-2] != y.type.broadcastable[:-2]:
916+
return None
917+
918+
x_shape = tuple(x.shape)
919+
y_shape = tuple(y.shape)
920+
if len(x_shape) > 3:
921+
# If we have more than one batch dim, ravel it
922+
x = x.reshape((-1, x_shape[-2], x_shape[-1]))
923+
y = y.reshape((-1, y_shape[-2], y_shape[-1]))
924+
925+
new_out = batched_dot(x, y)
926+
927+
if len(x_shape) > 3:
928+
# And then unravel it
929+
new_out = new_out.reshape((*x_shape[:-2], x_shape[-2], y_shape[-1]))
930+
931+
copy_stack_trace(node.outputs, [new_out])
932+
return [new_out]

tests/tensor/rewriting/test_blas.py

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
import numpy as np
2+
import pytest
3+
4+
from pytensor import function
5+
from pytensor.compile import get_default_mode
6+
from pytensor.tensor import matmul, tensor, vectorize
7+
from pytensor.tensor.blas import BatchedDot
8+
from pytensor.tensor.blockwise import Blockwise
9+
from pytensor.tensor.rewriting.blas import specialize_matmul_to_batched_dot
10+
11+
12+
@pytest.mark.parametrize("valid_case", (True, False))
13+
def test_specialize_matmul_to_batched_dot(valid_case):
14+
signature = BatchedDot.gufunc_signature
15+
rewrite = specialize_matmul_to_batched_dot.__name__
16+
17+
def core_pt(x, y):
18+
return matmul(x, y)
19+
20+
def core_np(x, y):
21+
return np.matmul(x, y)
22+
23+
x = tensor(shape=(7, 5, 3, 3))
24+
if valid_case:
25+
y = tensor(shape=(7, 5, 3, 3))
26+
else:
27+
y = tensor(shape=(5, 3, 3))
28+
29+
vectorize_pt = function(
30+
[x, y],
31+
vectorize(core_pt, signature=signature)(x, y),
32+
mode=get_default_mode().including(rewrite),
33+
)
34+
blocwkise_node = any(
35+
isinstance(node.op, Blockwise) for node in vectorize_pt.maker.fgraph.apply_nodes
36+
)
37+
if valid_case:
38+
assert not blocwkise_node
39+
else:
40+
assert blocwkise_node
41+
42+
x_test = np.random.normal(size=x.type.shape).astype(x.type.dtype)
43+
y_test = np.random.normal(size=y.type.shape).astype(y.type.dtype)
44+
vectorize_np = np.vectorize(core_np, signature=signature)
45+
np.testing.assert_allclose(
46+
vectorize_pt(x_test, y_test),
47+
vectorize_np(x_test, y_test),
48+
)

tests/tensor/test_blockwise.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,13 +6,15 @@
66

77
import pytensor
88
from pytensor import config, function
9+
from pytensor.compile import get_mode
910
from pytensor.gradient import grad
1011
from pytensor.graph import Apply, Op
1112
from pytensor.graph.replace import vectorize_node
1213
from pytensor.raise_op import assert_op
1314
from pytensor.tensor import diagonal, log, tensor
1415
from pytensor.tensor.blockwise import Blockwise
1516
from pytensor.tensor.nlinalg import MatrixInverse
17+
from pytensor.tensor.rewriting.blas import specialize_matmul_to_batched_dot
1618
from pytensor.tensor.slinalg import Cholesky, Solve, cholesky, solve_triangular
1719
from pytensor.tensor.utils import _parse_gufunc_signature
1820

@@ -45,7 +47,11 @@ def check_blockwise_runtime_broadcasting(mode):
4547
b = tensor("b", shape=(None, 5, 3))
4648

4749
out = a @ b
48-
fn = function([a, b], out, mode=mode)
50+
fn = function(
51+
[a, b],
52+
out,
53+
mode=get_mode(mode).excluding(specialize_matmul_to_batched_dot.__name__),
54+
)
4955
assert isinstance(fn.maker.fgraph.outputs[0].owner.op, Blockwise)
5056

5157
for valid_test_values in [

0 commit comments

Comments
 (0)