Skip to content

Commit d14601d

Browse files
committed
add cholesky of L.LT rewrite
1 parent 82a1e95 commit d14601d

File tree

1 file changed

+41
-0
lines changed

1 file changed

+41
-0
lines changed

pytensor/sandbox/linalg/ops.py

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -109,6 +109,47 @@ def psd_solve_with_chol(fgraph, node):
109109
return [x]
110110

111111

112+
@register_canonicalize
113+
@register_stabilize
114+
@node_rewriter([Cholesky])
115+
def chol_of_dot_chol(fgraph, node):
116+
"""
117+
This utilizes a boolean `lower_triangular` or `upper_triangular` tag on matrices.
118+
"""
119+
if not isinstance(node.op, Cholesky):
120+
return
121+
122+
A = node.inputs[0]
123+
if not (A.owner and isinstance(A.owner.op, (Dot, Dot22))):
124+
return
125+
126+
l, r = A.owner.inputs
127+
128+
# cholesky(dot(L,L.T)) case
129+
if (
130+
getattr(l.tag, "lower_triangular", False)
131+
and r.owner
132+
and isinstance(r.owner.op, DimShuffle)
133+
and r.owner.op.new_order == (1, 0)
134+
and r.owner.inputs[0] == l
135+
):
136+
if node.op.lower:
137+
return [l]
138+
return [r]
139+
140+
# cholesky(dot(U.T,U)) case
141+
if (
142+
getattr(r.tag, "upper_triangular", False)
143+
and l.owner
144+
and isinstance(l.owner.op, DimShuffle)
145+
and l.owner.op.new_order == (1, 0)
146+
and l.owner.inputs[0] == r
147+
):
148+
if node.op.lower:
149+
return [l]
150+
return [r]
151+
152+
112153
@register_stabilize
113154
@register_specialize
114155
@node_rewriter([Det])

0 commit comments

Comments
 (0)