Skip to content

Commit 0bbf1a4

Browse files
dehorsleyricardoV94
authored andcommitted
Add cholesky of L.LT rewrite
1 parent 3905106 commit 0bbf1a4

File tree

2 files changed

+119
-0
lines changed

2 files changed

+119
-0
lines changed

pytensor/tensor/rewriting/linalg.py

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -109,6 +109,50 @@ def psd_solve_with_chol(fgraph, node):
109109
return [x]
110110

111111

112+
@register_canonicalize
113+
@register_stabilize
114+
@node_rewriter([Cholesky])
115+
def cholesky_ldotlt(fgraph, node):
116+
"""
117+
rewrite cholesky(dot(L, L.T), lower=True) = L, where L is lower triangular,
118+
or cholesky(dot(U.T, U), upper=True) = U where U is upper triangular.
119+
120+
This utilizes a boolean `lower_triangular` or `upper_triangular` tag on matrices.
121+
"""
122+
if not isinstance(node.op, Cholesky):
123+
return
124+
125+
A = node.inputs[0]
126+
if not (A.owner and isinstance(A.owner.op, (Dot, Dot22))):
127+
return
128+
129+
l, r = A.owner.inputs
130+
131+
# cholesky(dot(L,L.T)) case
132+
if (
133+
getattr(l.tag, "lower_triangular", False)
134+
and r.owner
135+
and isinstance(r.owner.op, DimShuffle)
136+
and r.owner.op.new_order == (1, 0)
137+
and r.owner.inputs[0] == l
138+
):
139+
if node.op.lower:
140+
return [l]
141+
return [r]
142+
143+
# cholesky(dot(U.T,U)) case
144+
if (
145+
getattr(r.tag, "upper_triangular", False)
146+
and l.owner
147+
and isinstance(l.owner.op, DimShuffle)
148+
and l.owner.op.new_order == (1, 0)
149+
and l.owner.inputs[0] == r
150+
):
151+
if node.op.lower:
152+
return [l]
153+
return [r]
154+
155+
112156
@register_stabilize
113157
@register_specialize
114158
@node_rewriter([Det])

tests/tensor/rewriting/test_linalg.py

Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,12 @@
11
import numpy as np
22
import numpy.linalg
3+
import pytest
4+
import scipy.linalg
35

46
import pytensor
57
from pytensor import function
68
from pytensor import tensor as at
9+
from pytensor.compile import get_default_mode
710
from pytensor.configdefaults import config
811
from pytensor.tensor.elemwise import DimShuffle
912
from pytensor.tensor.math import _allclose
@@ -105,3 +108,75 @@ def test_matrix_inverse_solve():
105108
node = matrix_inverse(A).dot(b).owner
106109
[out] = inv_as_solve.transform(None, node)
107110
assert isinstance(out.owner.op, Solve)
111+
112+
113+
@pytest.mark.parametrize("tag", ("lower", "upper", None))
114+
@pytest.mark.parametrize("cholesky_form", ("lower", "upper"))
115+
@pytest.mark.parametrize("product", ("lower", "upper", None))
116+
def test_cholesky_ldotlt(tag, cholesky_form, product):
117+
cholesky = Cholesky(lower=(cholesky_form == "lower"))
118+
119+
transform_removes_chol = tag is not None and product == tag
120+
transform_transposes = transform_removes_chol and cholesky_form != tag
121+
122+
A = matrix("L")
123+
if tag:
124+
setattr(A.tag, tag + "_triangular", True)
125+
126+
if product == "lower":
127+
M = A.dot(A.T)
128+
elif product == "upper":
129+
M = A.T.dot(A)
130+
else:
131+
M = A
132+
133+
C = cholesky(M)
134+
f = pytensor.function([A], C, mode=get_default_mode().including("cholesky_ldotlt"))
135+
136+
print(f.maker.fgraph.apply_nodes)
137+
138+
no_cholesky_in_graph = not any(
139+
isinstance(node.op, Cholesky) for node in f.maker.fgraph.apply_nodes
140+
)
141+
142+
assert no_cholesky_in_graph == transform_removes_chol
143+
144+
if transform_transposes:
145+
assert any(
146+
isinstance(node.op, DimShuffle) and node.op.new_order == (1, 0)
147+
for node in f.maker.fgraph.apply_nodes
148+
)
149+
150+
# Test some concrete value through f
151+
# there must be lower triangular (f assumes they are)
152+
Avs = [
153+
np.eye(1, dtype=pytensor.config.floatX),
154+
np.eye(10, dtype=pytensor.config.floatX),
155+
np.array([[2, 0], [1, 4]], dtype=pytensor.config.floatX),
156+
]
157+
if not tag:
158+
# these must be positive def
159+
Avs.extend(
160+
[
161+
np.ones((4, 4), dtype=pytensor.config.floatX)
162+
+ np.eye(4, dtype=pytensor.config.floatX),
163+
]
164+
)
165+
166+
for Av in Avs:
167+
if tag == "upper":
168+
Av = Av.T
169+
170+
if product == "lower":
171+
Mv = Av.dot(Av.T)
172+
elif product == "upper":
173+
Mv = Av.T.dot(Av)
174+
else:
175+
Mv = Av
176+
177+
assert np.all(
178+
np.isclose(
179+
scipy.linalg.cholesky(Mv, lower=(cholesky_form == "lower")),
180+
f(Av),
181+
)
182+
)

0 commit comments

Comments
 (0)