Skip to content

Commit 862c416

Browse files
Reuse cholesky decomposition with cho_solve in graphs with multiple pt.solve when assume_a = "pos" (#1467)
* Extend decomp+solve rewrite machinery to `assume_a="pos"` * Update rewrite name in test * Refactor tests to be nicer * Respect core op `lower` flag when rewriting to ChoSolve
1 parent cca20eb commit 862c416

File tree

3 files changed

+166
-105
lines changed

3 files changed

+166
-105
lines changed

pytensor/tensor/_linalg/solve/rewriting.py

Lines changed: 49 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -15,24 +15,29 @@
1515
from pytensor.tensor.elemwise import DimShuffle
1616
from pytensor.tensor.rewriting.basic import register_specialize
1717
from pytensor.tensor.rewriting.linalg import is_matrix_transpose
18-
from pytensor.tensor.slinalg import Solve, lu_factor, lu_solve
18+
from pytensor.tensor.slinalg import Solve, cho_solve, cholesky, lu_factor, lu_solve
1919
from pytensor.tensor.variable import TensorVariable
2020

2121

22-
def decompose_A(A, assume_a, check_finite):
22+
def decompose_A(A, assume_a, check_finite, lower):
2323
if assume_a == "gen":
2424
return lu_factor(A, check_finite=check_finite)
2525
elif assume_a == "tridiagonal":
2626
# We didn't implement check_finite for tridiagonal LU factorization
2727
return tridiagonal_lu_factor(A)
28+
elif assume_a == "pos":
29+
return cholesky(A, lower=lower, check_finite=check_finite)
2830
else:
2931
raise NotImplementedError
3032

3133

32-
def solve_lu_decomposed_system(A_decomp, b, transposed=False, *, core_solve_op: Solve):
34+
def solve_decomposed_system(
35+
A_decomp, b, transposed=False, lower=False, *, core_solve_op: Solve
36+
):
3337
b_ndim = core_solve_op.b_ndim
3438
check_finite = core_solve_op.check_finite
3539
assume_a = core_solve_op.assume_a
40+
3641
if assume_a == "gen":
3742
return lu_solve(
3843
A_decomp,
@@ -49,11 +54,19 @@ def solve_lu_decomposed_system(A_decomp, b, transposed=False, *, core_solve_op:
4954
b_ndim=b_ndim,
5055
transposed=transposed,
5156
)
57+
elif assume_a == "pos":
58+
# We can ignore the transposed argument here because A is symmetric by assumption
59+
return cho_solve(
60+
(A_decomp, lower),
61+
b,
62+
b_ndim=b_ndim,
63+
check_finite=check_finite,
64+
)
5265
else:
5366
raise NotImplementedError
5467

5568

56-
def _split_lu_solve_steps(
69+
def _split_decomp_and_solve_steps(
5770
fgraph, node, *, eager: bool, allowed_assume_a: Container[str]
5871
):
5972
if not isinstance(node.op.core_op, Solve):
@@ -133,13 +146,21 @@ def find_solve_clients(var, assume_a):
133146
if client.op.core_op.check_finite:
134147
check_finite_decomp = True
135148
break
136-
A_decomp = decompose_A(A, assume_a=assume_a, check_finite=check_finite_decomp)
149+
150+
lower = node.op.core_op.lower
151+
A_decomp = decompose_A(
152+
A, assume_a=assume_a, check_finite=check_finite_decomp, lower=lower
153+
)
137154

138155
replacements = {}
139156
for client, transposed in A_solve_clients_and_transpose:
140157
_, b = client.inputs
141-
new_x = solve_lu_decomposed_system(
142-
A_decomp, b, transposed=transposed, core_solve_op=client.op.core_op
158+
new_x = solve_decomposed_system(
159+
A_decomp,
160+
b,
161+
transposed=transposed,
162+
lower=lower,
163+
core_solve_op=client.op.core_op,
143164
)
144165
[old_x] = client.outputs
145166
new_x = atleast_Nd(new_x, n=old_x.type.ndim).astype(old_x.type.dtype)
@@ -149,7 +170,7 @@ def find_solve_clients(var, assume_a):
149170
return replacements
150171

151172

152-
def _scan_split_non_sequence_lu_decomposition_solve(
173+
def _scan_split_non_sequence_decomposition_and_solve(
153174
fgraph, node, *, allowed_assume_a: Container[str]
154175
):
155176
"""If the A of a Solve within a Scan is a function of non-sequences, split the LU decomposition step.
@@ -179,7 +200,7 @@ def _scan_split_non_sequence_lu_decomposition_solve(
179200
non_sequences = {equiv[non_seq] for non_seq in non_sequences}
180201
inner_node = equiv[inner_node] # type: ignore
181202

182-
replace_dict = _split_lu_solve_steps(
203+
replace_dict = _split_decomp_and_solve_steps(
183204
new_scan_fgraph,
184205
inner_node,
185206
eager=True,
@@ -207,22 +228,22 @@ def _scan_split_non_sequence_lu_decomposition_solve(
207228

208229
@register_specialize
209230
@node_rewriter([Blockwise])
210-
def reuse_lu_decomposition_multiple_solves(fgraph, node):
211-
return _split_lu_solve_steps(
212-
fgraph, node, eager=False, allowed_assume_a={"gen", "tridiagonal"}
231+
def reuse_decomposition_multiple_solves(fgraph, node):
232+
return _split_decomp_and_solve_steps(
233+
fgraph, node, eager=False, allowed_assume_a={"gen", "tridiagonal", "pos"}
213234
)
214235

215236

216237
@node_rewriter([Scan])
217-
def scan_split_non_sequence_lu_decomposition_solve(fgraph, node):
218-
return _scan_split_non_sequence_lu_decomposition_solve(
219-
fgraph, node, allowed_assume_a={"gen", "tridiagonal"}
238+
def scan_split_non_sequence_decomposition_and_solve(fgraph, node):
239+
return _scan_split_non_sequence_decomposition_and_solve(
240+
fgraph, node, allowed_assume_a={"gen", "tridiagonal", "pos"}
220241
)
221242

222243

223244
scan_seqopt1.register(
224-
"scan_split_non_sequence_lu_decomposition_solve",
225-
in2out(scan_split_non_sequence_lu_decomposition_solve, ignore_newtrees=True),
245+
scan_split_non_sequence_decomposition_and_solve.__name__,
246+
in2out(scan_split_non_sequence_decomposition_and_solve, ignore_newtrees=True),
226247
"fast_run",
227248
"scan",
228249
"scan_pushout",
@@ -231,28 +252,30 @@ def scan_split_non_sequence_lu_decomposition_solve(fgraph, node):
231252

232253

233254
@node_rewriter([Blockwise])
234-
def reuse_lu_decomposition_multiple_solves_jax(fgraph, node):
235-
return _split_lu_solve_steps(fgraph, node, eager=False, allowed_assume_a={"gen"})
255+
def reuse_decomposition_multiple_solves_jax(fgraph, node):
256+
return _split_decomp_and_solve_steps(
257+
fgraph, node, eager=False, allowed_assume_a={"gen", "pos"}
258+
)
236259

237260

238261
optdb["specialize"].register(
239-
reuse_lu_decomposition_multiple_solves_jax.__name__,
240-
in2out(reuse_lu_decomposition_multiple_solves_jax, ignore_newtrees=True),
262+
reuse_decomposition_multiple_solves_jax.__name__,
263+
in2out(reuse_decomposition_multiple_solves_jax, ignore_newtrees=True),
241264
"jax",
242265
use_db_name_as_tag=False,
243266
)
244267

245268

246269
@node_rewriter([Scan])
247-
def scan_split_non_sequence_lu_decomposition_solve_jax(fgraph, node):
248-
return _scan_split_non_sequence_lu_decomposition_solve(
249-
fgraph, node, allowed_assume_a={"gen"}
270+
def scan_split_non_sequence_decomposition_and_solve_jax(fgraph, node):
271+
return _scan_split_non_sequence_decomposition_and_solve(
272+
fgraph, node, allowed_assume_a={"gen", "pos"}
250273
)
251274

252275

253276
scan_seqopt1.register(
254-
scan_split_non_sequence_lu_decomposition_solve_jax.__name__,
255-
in2out(scan_split_non_sequence_lu_decomposition_solve_jax, ignore_newtrees=True),
277+
scan_split_non_sequence_decomposition_and_solve_jax.__name__,
278+
in2out(scan_split_non_sequence_decomposition_and_solve_jax, ignore_newtrees=True),
256279
"jax",
257280
use_db_name_as_tag=False,
258281
position=2,

0 commit comments

Comments
 (0)