Skip to content

Commit 28c8b7c

Browse files
Add rewrite to lift linear algebra through certain linalg ops
1 parent 197069d commit 28c8b7c

File tree

5 files changed

+136
-6
lines changed

5 files changed

+136
-6
lines changed

pytensor/compile/builders.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
from typing import Optional, cast
77

88
import pytensor.tensor as pt
9-
from pytensor import function
9+
from pytensor.compile import function
1010
from pytensor.compile.function.pfunc import rebuild_collect_shared
1111
from pytensor.compile.mode import optdb
1212
from pytensor.compile.sharedvalue import SharedVariable

pytensor/tensor/nlinalg.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -600,7 +600,7 @@ def svd(a, full_matrices: bool = True, compute_uv: bool = True):
600600
601601
Returns
602602
-------
603-
U, V, D : matrices
603+
U, V, D : matrices
604604
605605
"""
606606
return SVD(full_matrices, compute_uv)(a)

pytensor/tensor/rewriting/linalg.py

Lines changed: 67 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,17 +7,21 @@
77
from pytensor.tensor.blockwise import Blockwise
88
from pytensor.tensor.elemwise import DimShuffle
99
from pytensor.tensor.math import Dot, Prod, _matrix_matrix_matmul, log, prod
10-
from pytensor.tensor.nlinalg import MatrixInverse, det
10+
from pytensor.tensor.nlinalg import MatrixInverse, MatrixPinv, det, inv, pinv
1111
from pytensor.tensor.rewriting.basic import (
1212
register_canonicalize,
1313
register_specialize,
1414
register_stabilize,
1515
)
1616
from pytensor.tensor.slinalg import (
17+
BlockDiagonal,
1718
Cholesky,
19+
KroneckerProduct,
1820
Solve,
1921
SolveBase,
22+
block_diag,
2023
cholesky,
24+
kron,
2125
solve,
2226
solve_triangular,
2327
)
@@ -310,3 +314,65 @@ def local_log_prod_sqr(fgraph, node):
310314

311315
# TODO: have a reduction like prod and sum that simply
312316
# returns the sign of the prod multiplication.
317+
318+
319+
def local_inv_kron_to_kron_inv(fgraph, node):
    # Placeholder only: no rewriter decorator is applied and nothing is
    # rewritten yet. Sketch of the intended logic:
    #   1. detect that the node is a kron
    #   2. detect that its parent node is an inv
    #   3. if both hold, substitute kron(inv, inv)
    return None
325+
326+
327+
def local_chol_kron_to_kron_chol(fgraph, node):
    # Placeholder only: no rewriter decorator is applied and nothing is
    # rewritten yet. Sketch of the intended logic:
    #   1. detect that the node is a kron
    #   2. detect that its parent node is a cholesky
    #   3. if both hold, substitute kron(cholesky, cholesky)
    return None
333+
334+
335+
@register_specialize
@node_rewriter([Blockwise])
def local_lift_through_linalg(fgraph, node):
    """Lift an inverse, pseudo-inverse or Cholesky through ``block_diag``/``kron``.

    Rewrite a graph like Inv(BlockDiag([A, B, C])) to
    BlockDiag([Inv(A), Inv(B), Inv(C)]), and Inv(Kron(A, B)) to
    Kron(Inv(A), Inv(B)). The same lifting is applied when the outer
    operation is ``MatrixPinv`` or ``Cholesky``.

    Parameters
    ----------
    fgraph
        The graph being rewritten. Not used by this rewrite.
    node
        The ``Blockwise`` apply node being inspected. Its ``core_op`` must be
        a ``MatrixInverse``, ``MatrixPinv`` or ``Cholesky`` and its first
        input must be produced by a ``Blockwise`` ``BlockDiagonal`` or by a
        ``KroneckerProduct`` for the rewrite to apply.

    Returns
    -------
    list or None
        A one-element list holding the replacement output when the pattern
        matches, ``None`` otherwise.
    """
    # TODO: Simplify this if we end up Blockwising KroneckerProduct
    outer_op = node.op
    if not isinstance(outer_op.core_op, (MatrixInverse, Cholesky, MatrixPinv)):
        return None

    y = node.inputs[0]
    if not y.owner:
        return None

    inner_op = y.owner.op
    # KroneckerProduct is not wrapped in Blockwise, BlockDiagonal is.
    is_kron = isinstance(inner_op, KroneckerProduct)
    is_block_diag = isinstance(inner_op, Blockwise) and isinstance(
        inner_op.core_op, BlockDiagonal
    )
    if not (is_kron or is_block_diag):
        return None

    # Re-apply the matched op itself instead of the default inv / pinv /
    # cholesky helpers: the helpers would rebuild the op with default
    # settings and silently drop any non-default configuration carried by
    # ``outer_op.core_op``, changing the result of the rewritten graph.
    lifted_matrices = [outer_op(m) for m in y.owner.inputs]

    if is_kron:
        return [kron(*lifted_matrices)]
    return [block_diag(*lifted_matrices)]

pytensor/tensor/slinalg.py

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99

1010
import pytensor
1111
import pytensor.tensor as pt
12+
from pytensor.compile.builders import OpFromGraph
1213
from pytensor.graph.basic import Apply
1314
from pytensor.graph.op import Op
1415
from pytensor.tensor import as_tensor_variable
@@ -559,6 +560,14 @@ def eigvalsh(a, b, lower=True):
559560
return Eigvalsh(lower)(a, b)
560561

561562

563+
class KroneckerProduct(OpFromGraph):
    """Wrapper Op for Kronecker graphs.

    Adds no behaviour of its own on top of ``OpFromGraph``; the distinct
    class lets Kronecker-product graphs be recognised with ``isinstance``.
    """
569+
570+
562571
def kron(a, b):
563572
"""Kronecker product.
564573
@@ -578,10 +587,13 @@ def kron(a, b):
578587
numpy.kron(a, b) != scipy.linalg.kron(a, b)!
579588
They don't have the same shape and order when
580589
a.ndim != b.ndim != 2.
581-
582590
"""
591+
592+
# TODO: Revisit this implementation?
593+
583594
a = as_tensor_variable(a)
584595
b = as_tensor_variable(b)
596+
585597
if a.ndim + b.ndim <= 2:
586598
raise TypeError(
587599
"kron: inputs dimensions must sum to 3 or more. "
@@ -598,7 +610,8 @@ def kron(a, b):
598610
(o.shape[0] * o.shape[2], o.shape[1] * o.shape[3])
599611
+ tuple(o.shape[i] for i in range(4, o.ndim))
600612
)
601-
return o
613+
614+
return KroneckerProduct(inputs=[a, b], outputs=[o])(a, b)
602615

603616

604617
class Expm(Op):

tests/tensor/rewriting/test_linalg.py

Lines changed: 52 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,10 +14,12 @@
1414
from pytensor.tensor.blockwise import Blockwise
1515
from pytensor.tensor.elemwise import DimShuffle
1616
from pytensor.tensor.math import _allclose, dot, matmul
17-
from pytensor.tensor.nlinalg import Det, MatrixInverse, matrix_inverse
17+
from pytensor.tensor.nlinalg import Det, MatrixInverse, MatrixPinv, matrix_inverse
1818
from pytensor.tensor.rewriting.linalg import inv_as_solve
1919
from pytensor.tensor.slinalg import (
20+
BlockDiagonal,
2021
Cholesky,
22+
KroneckerProduct,
2123
Solve,
2224
SolveBase,
2325
SolveTriangular,
@@ -333,3 +335,52 @@ def test_invalid_batched_a(self):
333335
ref_fn(test_a, test_b),
334336
rtol=1e-7 if config.floatX == "float64" else 1e-5,
335337
)
338+
339+
340+
@pytest.mark.parametrize(
    "constructor", [pt.dmatrix, pt.tensor3], ids=["not_batched", "batched"]
)
@pytest.mark.parametrize(
    "f_op, f",
    [
        (MatrixInverse, pt.linalg.inv),
        (Cholesky, pt.linalg.cholesky),
        (MatrixPinv, pt.linalg.pinv),
    ],
    ids=["inv", "cholesky", "pinv"],
)
@pytest.mark.parametrize(
    "g_op, g",
    [(BlockDiagonal, pt.linalg.block_diag), (KroneckerProduct, pt.linalg.kron)],
    ids=["block_diag", "kron"],
)
def test_local_lift_through_linalg(constructor, f_op, f, g_op, g):
    """Check that ``f(g(A, B))`` is rewritten to ``g(f(A), f(B))``.

    The graph compiled with ``local_lift_through_linalg`` included must
    contain two ``f_op`` nodes and one ``g_op`` node, and must agree
    numerically with the graph compiled with the rewrite excluded.
    """
    A, B = (constructor(name) for name in "ab")
    X = f(g(A, B))

    f1 = pytensor.function(
        [A, B], X, mode=get_default_mode().including("local_lift_through_linalg")
    )
    f2 = pytensor.function(
        [A, B], X, mode=get_default_mode().excluding("local_lift_through_linalg")
    )

    def nodes_with_core_op(op_type):
        # Ops may or may not be wrapped in Blockwise; compare on the core op
        # when there is one, on the op itself otherwise.
        return [
            apply_node
            for apply_node in f1.maker.fgraph.apply_nodes
            if isinstance(getattr(apply_node.op, "core_op", apply_node.op), op_type)
        ]

    # After lifting, f is applied once per input and g only once.
    assert len(nodes_with_core_op(f_op)) == 2
    assert len(nodes_with_core_op(g_op)) == 1

    # Seeded generator so a failure can be reproduced.
    rng = np.random.default_rng(20240217)
    test_vals = [
        rng.normal(size=(3,) * A.ndim).astype(config.floatX) for _ in range(2)
    ]
    # x @ x.T is symmetric, which the Cholesky cases need.
    test_vals = [x @ np.swapaxes(x, -1, -2) for x in test_vals]

    np.testing.assert_allclose(
        f1(*test_vals),
        f2(*test_vals),
        rtol=1e-7 if config.floatX == "float64" else 1e-5,
    )

0 commit comments

Comments
 (0)