Skip to content

Commit 567b8d3

Browse files
Add rewrite to lift linear algebra through certain linalg ops
1 parent 14651fb commit 567b8d3

File tree

5 files changed

+150
-5
lines changed

5 files changed

+150
-5
lines changed

pytensor/compile/builders.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
from typing import cast
88

99
import pytensor.tensor as pt
10-
from pytensor import function
10+
from pytensor.compile.function import function
1111
from pytensor.compile.function.pfunc import rebuild_collect_shared
1212
from pytensor.compile.mode import optdb
1313
from pytensor.compile.sharedvalue import SharedVariable

pytensor/tensor/nlinalg.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
from numpy.core.numeric import normalize_axis_tuple # type: ignore
88

99
from pytensor import scalar as ps
10+
from pytensor.compile.builders import OpFromGraph
1011
from pytensor.gradient import DisconnectedType
1112
from pytensor.graph.basic import Apply
1213
from pytensor.graph.op import Op
@@ -614,7 +615,7 @@ def svd(a, full_matrices: bool = True, compute_uv: bool = True):
614615
615616
Returns
616617
-------
617-
U, V, D : matrices
618+
U, V, D : matrices
618619
619620
"""
620621
return Blockwise(SVD(full_matrices, compute_uv))(a)
@@ -1011,6 +1012,12 @@ def tensorsolve(a, b, axes=None):
10111012
return TensorSolve(axes)(a, b)
10121013

10131014

1015+
class KroneckerProduct(OpFromGraph):
1016+
"""
1017+
Wrapper Op for Kronecker graphs
1018+
"""
1019+
1020+
10141021
def kron(a, b):
10151022
"""Kronecker product.
10161023
@@ -1042,7 +1049,8 @@ def kron(a, b):
10421049
out_shape = tuple(a.shape * b.shape)
10431050
output_out_of_shape = a_reshaped * b_reshaped
10441051
output_reshaped = output_out_of_shape.reshape(out_shape)
1045-
return output_reshaped
1052+
1053+
return KroneckerProduct(inputs=[a, b], outputs=[output_reshaped])(a, b)
10461054

10471055

10481056
__all__ = [

pytensor/tensor/rewriting/linalg.py

Lines changed: 73 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,22 +1,35 @@
11
import logging
2+
from collections.abc import Callable
23
from typing import cast
34

5+
from pytensor import Variable
6+
from pytensor.graph import Apply, FunctionGraph
47
from pytensor.graph.rewriting.basic import copy_stack_trace, node_rewriter
58
from pytensor.tensor.basic import TensorVariable, diagonal
69
from pytensor.tensor.blas import Dot22
710
from pytensor.tensor.blockwise import Blockwise
811
from pytensor.tensor.elemwise import DimShuffle
912
from pytensor.tensor.math import Dot, Prod, _matrix_matrix_matmul, log, prod
10-
from pytensor.tensor.nlinalg import MatrixInverse, det
13+
from pytensor.tensor.nlinalg import (
14+
KroneckerProduct,
15+
MatrixInverse,
16+
MatrixPinv,
17+
det,
18+
inv,
19+
kron,
20+
pinv,
21+
)
1122
from pytensor.tensor.rewriting.basic import (
1223
register_canonicalize,
1324
register_specialize,
1425
register_stabilize,
1526
)
1627
from pytensor.tensor.slinalg import (
28+
BlockDiagonal,
1729
Cholesky,
1830
Solve,
1931
SolveBase,
32+
block_diag,
2033
cholesky,
2134
solve,
2235
solve_triangular,
@@ -305,3 +318,62 @@ def local_log_prod_sqr(fgraph, node):
305318

306319
# TODO: have a reduction like prod and sum that simply
307320
# returns the sign of the prod multiplication.
321+
322+
323+
@register_specialize
324+
@node_rewriter([Blockwise])
325+
def local_lift_through_linalg(
326+
fgraph: FunctionGraph, node: Apply
327+
) -> list[Variable] | None:
328+
"""
329+
Rewrite compositions of linear algebra operations by lifting expensive operations (Cholesky, Inverse) through Ops
330+
that join matrices (KroneckerProduct, BlockDiagonal).
331+
332+
This rewrite takes advantage of commutation between certain linear algebra operations to do several smaller matrix
333+
operations on component matrices instead of one large one. For example, when taking the inverse of Kronecker
334+
product, we can take the inverse of each component matrix and then take the Kronecker product of the inverses. This
335+
reduces the cost of the inverse from O((n*m)^3) to O(n^3 + m^3) where n and m are the dimensions of the component
336+
matrices.
337+
338+
Parameters
339+
----------
340+
fgraph: FunctionGraph
341+
Function graph being optimized
342+
node: Apply
343+
Node of the function graph to be optimized
344+
345+
Returns
346+
-------
347+
list of Variable, optional
348+
List of optimized variables, or None if no optimization was performed
349+
"""
350+
351+
# TODO: Simplify this if we end up Blockwising KroneckerProduct
352+
if isinstance(node.op.core_op, MatrixInverse | Cholesky | MatrixPinv):
353+
y = node.inputs[0]
354+
outer_op = node.op
355+
356+
if y.owner and (
357+
isinstance(y.owner.op, Blockwise)
358+
and isinstance(y.owner.op.core_op, BlockDiagonal)
359+
or isinstance(y.owner.op, KroneckerProduct)
360+
):
361+
input_matrices = y.owner.inputs
362+
363+
if isinstance(outer_op.core_op, MatrixInverse):
364+
outer_f = cast(Callable, inv)
365+
elif isinstance(outer_op.core_op, Cholesky):
366+
outer_f = cast(Callable, cholesky)
367+
elif isinstance(outer_op.core_op, MatrixPinv):
368+
outer_f = cast(Callable, pinv)
369+
else:
370+
raise NotImplementedError # pragma: no cover
371+
372+
inner_matrices = [cast(TensorVariable, outer_f(m)) for m in input_matrices]
373+
374+
if isinstance(y.owner.op, KroneckerProduct):
375+
return [kron(*inner_matrices)]
376+
elif isinstance(y.owner.op.core_op, BlockDiagonal):
377+
return [block_diag(*inner_matrices)]
378+
else:
379+
raise NotImplementedError # pragma: no cover

tests/tensor/rewriting/test_linalg.py

Lines changed: 58 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,9 +14,16 @@
1414
from pytensor.tensor.blockwise import Blockwise
1515
from pytensor.tensor.elemwise import DimShuffle
1616
from pytensor.tensor.math import _allclose, dot, matmul
17-
from pytensor.tensor.nlinalg import Det, MatrixInverse, matrix_inverse
17+
from pytensor.tensor.nlinalg import (
18+
Det,
19+
KroneckerProduct,
20+
MatrixInverse,
21+
MatrixPinv,
22+
matrix_inverse,
23+
)
1824
from pytensor.tensor.rewriting.linalg import inv_as_solve
1925
from pytensor.tensor.slinalg import (
26+
BlockDiagonal,
2027
Cholesky,
2128
Solve,
2229
SolveBase,
@@ -333,3 +340,53 @@ def test_invalid_batched_a(self):
333340
ref_fn(test_a, test_b),
334341
rtol=1e-7 if config.floatX == "float64" else 1e-5,
335342
)
343+
344+
345+
@pytest.mark.parametrize(
346+
"constructor", [pt.dmatrix, pt.tensor3], ids=["not_batched", "batched"]
347+
)
348+
@pytest.mark.parametrize(
349+
"f_op, f",
350+
[
351+
(MatrixInverse, pt.linalg.inv),
352+
(Cholesky, pt.linalg.cholesky),
353+
(MatrixPinv, pt.linalg.pinv),
354+
],
355+
ids=["inv", "cholesky", "pinv"],
356+
)
357+
@pytest.mark.parametrize(
358+
"g_op, g",
359+
[(BlockDiagonal, pt.linalg.block_diag), (KroneckerProduct, pt.linalg.kron)],
360+
ids=["block_diag", "kron"],
361+
)
362+
def test_local_lift_through_linalg(constructor, f_op, f, g_op, g):
363+
if pytensor.config.floatX.endswith("32"):
364+
pytest.skip("Test is flaky at half precision")
365+
366+
A, B = list(map(constructor, "ab"))
367+
X = f(g(A, B))
368+
369+
f1 = pytensor.function(
370+
[A, B], X, mode=get_default_mode().including("local_lift_through_linalg")
371+
)
372+
f2 = pytensor.function(
373+
[A, B], X, mode=get_default_mode().excluding("local_lift_through_linalg")
374+
)
375+
376+
all_apply_nodes = f1.maker.fgraph.apply_nodes
377+
f_ops = [
378+
x for x in all_apply_nodes if isinstance(getattr(x.op, "core_op", x.op), f_op)
379+
]
380+
g_ops = [
381+
x for x in all_apply_nodes if isinstance(getattr(x.op, "core_op", x.op), g_op)
382+
]
383+
384+
assert len(f_ops) == 2
385+
assert len(g_ops) == 1
386+
387+
test_vals = [
388+
np.random.normal(size=(3,) * A.ndim).astype(config.floatX) for _ in range(2)
389+
]
390+
test_vals = [x @ np.swapaxes(x, -1, -2) for x in test_vals]
391+
392+
np.testing.assert_allclose(f1(*test_vals), f2(*test_vals), atol=1e-8)

tests/tensor/test_nlinalg.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -590,6 +590,14 @@ def setup_method(self):
590590
self.op = kron
591591
super().setup_method()
592592

593+
def test_vec_vec_kron_raises(self):
594+
x = vector()
595+
y = vector()
596+
with pytest.raises(
597+
TypeError, match="kron: inputs dimensions must sum to 3 or more"
598+
):
599+
kron(x, y)
600+
593601
@pytest.mark.parametrize("shp0", [(2,), (2, 3), (2, 3, 4), (2, 3, 4, 5)])
594602
@pytest.mark.parametrize("shp1", [(6,), (6, 7), (6, 7, 8), (6, 7, 8, 9)])
595603
def test_perform(self, shp0, shp1):

0 commit comments

Comments
 (0)