diff --git a/pytensor/tensor/rewriting/linalg.py b/pytensor/tensor/rewriting/linalg.py index a2418147cf..55a870d550 100644 --- a/pytensor/tensor/rewriting/linalg.py +++ b/pytensor/tensor/rewriting/linalg.py @@ -989,3 +989,20 @@ def jax_bilinaer_lyapunov_to_direct(fgraph: FunctionGraph, node: Apply): "jax", position=0.9, # Run before canonicalization ) + + +@register_canonicalize +@register_stabilize +@node_rewriter([Dot]) +def rewrite_dot_kron(fgraph, node): + potential_kron = node.inputs[0].owner + if not (potential_kron and isinstance(potential_kron.op, KroneckerProduct)): + return False + + c = node.inputs[1] + [a, b] = potential_kron.inputs + + m, n = a.type.shape + p, q = b.type.shape + out_clever = pt.expand_dims((b @ c.reshape(shape=(n, q)).T @ a.T).T.ravel(), 1) + return [out_clever] diff --git a/tests/tensor/rewriting/test_linalg.py b/tests/tensor/rewriting/test_linalg.py index 9dd2a247a8..8201763d61 100644 --- a/tests/tensor/rewriting/test_linalg.py +++ b/tests/tensor/rewriting/test_linalg.py @@ -906,3 +906,29 @@ def test_rewrite_cholesky_diag_to_sqrt_diag_not_applied(): f_rewritten = function([x], z_cholesky, mode="FAST_RUN") nodes = f_rewritten.maker.fgraph.apply_nodes assert any(isinstance(node.op, Cholesky) for node in nodes) + + +def test_dot_kron_rewrite(): + m, n, p, q = 3, 4, 6, 7 + a = pt.matrix("a", shape=(m, n)) + b = pt.matrix("b", shape=(p, q)) + c = pt.matrix("c", shape=(n * q, 1)) + out_direct = pt.linalg.kron(a, b) @ c + + # REWRITE TEST + f_direct_rewritten = function([a, b, c], out_direct, mode="FAST_RUN") + nodes = f_direct_rewritten.maker.fgraph.apply_nodes + assert not any(isinstance(node.op, KroneckerProduct) for node in nodes) + + # NUMERIC VALUE TEST + a_test = np.random.rand(m, n).astype(config.floatX) + b_test = np.random.rand(p, q).astype(config.floatX) + c_test = np.random.rand(n * q, 1).astype(config.floatX) + out_direct_val = np.kron(a_test, b_test) @ c_test + out_clever_val = f_direct_rewritten(a_test, b_test, c_test) + assert_allclose( + out_direct_val, + out_clever_val, + atol=1e-3 if config.floatX == "float32" else 1e-8, + rtol=1e-3 if config.floatX == "float32" else 1e-8, + )