Skip to content

Commit 39aa123

Browse files
committed
Extend cholesky of triangular dot rewrite to matmul Ops
Also restrict to 2D Dot cases
1 parent 00546b9 commit 39aa123

File tree

2 files changed: 42 additions (+42) and 15 deletions (-15).

pytensor/tensor/rewriting/linalg.py

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
from pytensor.tensor.blas import Dot22
77
from pytensor.tensor.blockwise import Blockwise
88
from pytensor.tensor.elemwise import DimShuffle
9-
from pytensor.tensor.math import Dot, Prod, log, prod
9+
from pytensor.tensor.math import Dot, Prod, _matrix_matrix_matmul, log, prod
1010
from pytensor.tensor.nlinalg import MatrixInverse, det
1111
from pytensor.tensor.rewriting.basic import (
1212
register_canonicalize,
@@ -168,13 +168,25 @@ def cholesky_ldotlt(fgraph, node):
168168
rewrite cholesky(dot(L, L.T), lower=True) = L, where L is lower triangular,
169169
or cholesky(dot(U.T, U), upper=True) = U where U is upper triangular.
170170
171+
Also works with matmul.
172+
171173
This utilizes a boolean `lower_triangular` or `upper_triangular` tag on matrices.
172174
"""
173175
if not isinstance(node.op.core_op, Cholesky):
174176
return
175177

176178
A = node.inputs[0]
177-
if not (A.owner and isinstance(A.owner.op, (Dot, Dot22))):
179+
if not (
180+
A.owner is not None
181+
and (
182+
(
183+
isinstance(A.owner.op, (Dot, Dot22))
184+
# This rewrite only applies to matrix Dot
185+
and A.owner.inputs[0].type.ndim == 2
186+
)
187+
or (A.owner.op == _matrix_matrix_matmul)
188+
)
189+
):
178190
return
179191

180192
l, r = A.owner.inputs

tests/tensor/rewriting/test_linalg.py

Lines changed: 28 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
from functools import partial
2+
13
import numpy as np
24
import numpy.linalg
35
import pytest
@@ -9,13 +11,14 @@
911
from pytensor import tensor as at
1012
from pytensor.compile import get_default_mode
1113
from pytensor.configdefaults import config
14+
from pytensor.tensor import swapaxes
1215
from pytensor.tensor.blockwise import Blockwise
1316
from pytensor.tensor.elemwise import DimShuffle
14-
from pytensor.tensor.math import _allclose
17+
from pytensor.tensor.math import _allclose, dot, matmul
1518
from pytensor.tensor.nlinalg import Det, MatrixInverse, matrix_inverse
1619
from pytensor.tensor.rewriting.linalg import inv_as_solve
1720
from pytensor.tensor.slinalg import Cholesky, Solve, SolveTriangular, cholesky, solve
18-
from pytensor.tensor.type import dmatrix, matrix, vector
21+
from pytensor.tensor.type import dmatrix, matrix, tensor, vector
1922
from tests import unittest_tools as utt
2023
from tests.test_rop import break_op
2124

@@ -137,33 +140,38 @@ def test_matrix_inverse_solve():
137140
@pytest.mark.parametrize("tag", ("lower", "upper", None))
138141
@pytest.mark.parametrize("cholesky_form", ("lower", "upper"))
139142
@pytest.mark.parametrize("product", ("lower", "upper", None))
140-
def test_cholesky_ldotlt(tag, cholesky_form, product):
143+
@pytest.mark.parametrize("op", (dot, matmul))
144+
def test_cholesky_ldotlt(tag, cholesky_form, product, op):
141145
transform_removes_chol = tag is not None and product == tag
142146
transform_transposes = transform_removes_chol and cholesky_form != tag
143147

144-
A = matrix("L")
148+
ndim = 2 if op == dot else 3
149+
A = tensor("L", shape=(None,) * ndim)
145150
if tag:
146151
setattr(A.tag, tag + "_triangular", True)
147152

148153
if product == "lower":
149-
M = A.dot(A.T)
154+
M = op(A, swapaxes(A, -1, -2))
150155
elif product == "upper":
151-
M = A.T.dot(A)
156+
M = op(swapaxes(A, -1, -2), A)
152157
else:
153158
M = A
154159

155160
C = cholesky(M, lower=(cholesky_form == "lower"))
156161
f = pytensor.function([A], C, mode=get_default_mode().including("cholesky_ldotlt"))
157162

158163
no_cholesky_in_graph = not any(
159-
isinstance(node.op, Cholesky) for node in f.maker.fgraph.apply_nodes
164+
isinstance(node.op, Cholesky)
165+
or (isinstance(node.op, Blockwise) and isinstance(node.op.core_op, Cholesky))
166+
for node in f.maker.fgraph.apply_nodes
160167
)
161168

162169
assert no_cholesky_in_graph == transform_removes_chol
163170

164171
if transform_transposes:
172+
expected_order = (1, 0) if ndim == 2 else (0, 2, 1)
165173
assert any(
166-
isinstance(node.op, DimShuffle) and node.op.new_order == (1, 0)
174+
isinstance(node.op, DimShuffle) and node.op.new_order == expected_order
167175
for node in f.maker.fgraph.apply_nodes
168176
)
169177

@@ -183,6 +191,11 @@ def test_cholesky_ldotlt(tag, cholesky_form, product):
183191
]
184192
)
185193

194+
cholesky_vect_fn = np.vectorize(
195+
partial(scipy.linalg.cholesky, lower=(cholesky_form == "lower")),
196+
signature="(a, a)->(a, a)",
197+
)
198+
186199
for Av in Avs:
187200
if tag == "upper":
188201
Av = Av.T
@@ -194,11 +207,13 @@ def test_cholesky_ldotlt(tag, cholesky_form, product):
194207
else:
195208
Mv = Av
196209

197-
assert np.all(
198-
np.isclose(
199-
scipy.linalg.cholesky(Mv, lower=(cholesky_form == "lower")),
200-
f(Av),
201-
)
210+
if ndim == 3:
211+
Av = np.broadcast_to(Av, (5, *Av.shape))
212+
Mv = np.broadcast_to(Mv, (5, *Mv.shape))
213+
214+
np.testing.assert_allclose(
215+
cholesky_vect_fn(Mv),
216+
f(Av),
202217
)
203218

204219

Comments (0): there are no comments on this commit.