Skip to content

Commit 50146f6

Browse files
committed
Reuse LU decomposition in Solve
1 parent 5335a68 commit 50146f6

File tree

9 files changed

+370
-7
lines changed

9 files changed

+370
-7
lines changed

pytensor/scan/rewriting.py

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2561,26 +2561,24 @@ def scan_push_out_dot1(fgraph, node):
25612561
position=1,
25622562
)
25632563

2564-
25652564
scan_seqopt1.register(
25662565
"scan_push_out_non_seq",
25672566
in2out(scan_push_out_non_seq, ignore_newtrees=True),
25682567
"scan_pushout_nonseqs_ops", # For backcompat: so it can be tagged with old name
25692568
"fast_run",
25702569
"scan",
25712570
"scan_pushout",
2572-
position=2,
2571+
position=3,
25732572
)
25742573

2575-
25762574
scan_seqopt1.register(
25772575
"scan_push_out_seq",
25782576
in2out(scan_push_out_seq, ignore_newtrees=True),
25792577
"scan_pushout_seqs_ops", # For backcompat: so it can be tagged with old name
25802578
"fast_run",
25812579
"scan",
25822580
"scan_pushout",
2583-
position=3,
2581+
position=4,
25842582
)
25852583

25862584

@@ -2592,7 +2590,7 @@ def scan_push_out_dot1(fgraph, node):
25922590
"more_mem",
25932591
"scan",
25942592
"scan_pushout",
2595-
position=4,
2593+
position=5,
25962594
)
25972595

25982596

@@ -2605,7 +2603,7 @@ def scan_push_out_dot1(fgraph, node):
26052603
"more_mem",
26062604
"scan",
26072605
"scan_pushout",
2608-
position=5,
2606+
position=6,
26092607
)
26102608

26112609
scan_eqopt2.register(

pytensor/tensor/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -114,6 +114,7 @@ def _get_vector_length_Constant(op: Op | Variable, var: Constant) -> int:
114114

115115

116116
# isort: off
117+
import pytensor.tensor._linalg
117118
from pytensor.tensor import linalg
118119
from pytensor.tensor import special
119120
from pytensor.tensor import signal

pytensor/tensor/_linalg/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
# Register rewrites
2+
import pytensor.tensor._linalg.solve
Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
# Register rewrites in the database
2+
import pytensor.tensor._linalg.solve.rewriting
Lines changed: 183 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,183 @@
1+
from copy import copy
2+
3+
from pytensor.graph import Constant, graph_inputs
4+
from pytensor.graph.rewriting.basic import copy_stack_trace, in2out, node_rewriter
5+
from pytensor.scan.op import Scan
6+
from pytensor.scan.rewriting import scan_seqopt1
7+
from pytensor.tensor.basic import atleast_Nd
8+
from pytensor.tensor.blockwise import Blockwise
9+
from pytensor.tensor.elemwise import DimShuffle
10+
from pytensor.tensor.rewriting.basic import register_specialize
11+
from pytensor.tensor.rewriting.linalg import is_matrix_transpose
12+
from pytensor.tensor.slinalg import Solve, lu_factor, lu_solve
13+
from pytensor.tensor.variable import TensorVariable
14+
15+
16+
def decompose_A(A, assume_a):
17+
if assume_a == "gen":
18+
return lu_factor(A, check_finite=False)
19+
else:
20+
raise NotImplementedError
21+
22+
23+
def solve_lu_decomposed_system(A_decomp, b, b_ndim, assume_a, transposed=False):
24+
if assume_a == "gen":
25+
return lu_solve(A_decomp, b, b_ndim=b_ndim, trans=transposed)
26+
else:
27+
raise NotImplementedError
28+
29+
30+
def _split_lu_solve_steps(fgraph, node, *, eager: bool):
31+
if not isinstance(node.op.core_op, Solve):
32+
return None
33+
34+
def get_root_A(a: TensorVariable) -> tuple[TensorVariable, bool]:
35+
transposed = False
36+
if a.owner is not None and isinstance(a.owner.op, DimShuffle):
37+
if a.owner.op.is_left_expand_dims:
38+
[a] = a.owner.inputs
39+
elif is_matrix_transpose(a):
40+
[a] = a.owner.inputs
41+
transposed = True
42+
return a, transposed
43+
44+
def find_solve_clients(var, assume_a):
45+
clients = []
46+
for cl, idx in fgraph.clients[var]:
47+
if (
48+
idx == 0
49+
and isinstance(cl.op, Blockwise)
50+
and isinstance(cl.op.core_op, Solve)
51+
and (cl.op.core_op.assume_a == assume_a)
52+
):
53+
clients.append(cl)
54+
elif isinstance(cl.op, DimShuffle) and cl.op.is_left_expand_dims:
55+
# If it's a left expand_dims, recurse on the output
56+
clients.extend(find_solve_clients(cl.outputs[0], assume_a))
57+
return clients
58+
59+
assume_a = node.op.core_op.assume_a
60+
61+
if assume_a != "gen":
62+
return None
63+
64+
A, _ = get_root_A(node.inputs[0])
65+
66+
# Find Solve using A (or left expand_dims of A)
67+
# TODO: We could handle arbitrary shuffle of the batch dimensions, just need to propagate
68+
# that to the A_decomp outputs
69+
A_solve_clients_and_transpose = [
70+
(client, False) for client in find_solve_clients(A, assume_a)
71+
]
72+
73+
# Find Solves using A.T
74+
for cl, _ in fgraph.clients[A]:
75+
if isinstance(cl.op, DimShuffle) and is_matrix_transpose(cl.out):
76+
A_T = cl.out
77+
A_solve_clients_and_transpose.extend(
78+
(client, True) for client in find_solve_clients(A_T, assume_a)
79+
)
80+
81+
if not eager and len(A_solve_clients_and_transpose) == 1:
82+
# If theres' a single use don't do it... unless it's being broadcast in a Blockwise (or we're eager)
83+
# That's a "reuse" inside the inner vectorized loop
84+
batch_ndim = node.op.batch_ndim(node)
85+
(client, _) = A_solve_clients_and_transpose[0]
86+
original_A, b = client.inputs
87+
if not any(
88+
a_bcast and not b_bcast
89+
for a_bcast, b_bcast in zip(
90+
original_A.type.broadcastable[:batch_ndim],
91+
b.type.broadcastable[:batch_ndim],
92+
strict=True,
93+
)
94+
):
95+
return None
96+
97+
A_decomp = decompose_A(A, assume_a=assume_a)
98+
99+
replacements = {}
100+
for client, transposed in A_solve_clients_and_transpose:
101+
_, b = client.inputs
102+
b_ndim = client.op.core_op.b_ndim
103+
new_x = solve_lu_decomposed_system(
104+
A_decomp, b, b_ndim=b_ndim, assume_a=assume_a, transposed=transposed
105+
)
106+
[old_x] = client.outputs
107+
new_x = atleast_Nd(new_x, n=old_x.type.ndim).astype(old_x.type.dtype)
108+
copy_stack_trace(old_x, new_x)
109+
replacements[old_x] = new_x
110+
111+
return replacements
112+
113+
114+
@register_specialize
115+
@node_rewriter([Blockwise])
116+
def reuse_lu_decomposition_multiple_solves(fgraph, node):
117+
return _split_lu_solve_steps(fgraph, node, eager=False)
118+
119+
120+
@node_rewriter([Blockwise])
121+
def eager_split_lu_solve_steps(fgraph, node):
122+
return _split_lu_solve_steps(fgraph, node, eager=True)
123+
124+
125+
@node_rewriter([Scan])
126+
def scan_split_non_sequence_lu_decomposition_solve(fgraph, node):
127+
"""If the A of a Solve within a Scan is a function of non-sequences, split the LU decomposition step.
128+
129+
The LU decomposition step can then be pushed out of the inner loop by the `scan_pushout_non_sequences` rewrite.
130+
"""
131+
scan_op: Scan = node.op
132+
non_sequences = set(scan_op.inner_non_seqs(scan_op.inner_inputs))
133+
new_scan_fgraph = scan_op.fgraph
134+
135+
changed = False
136+
while True:
137+
for inner_node in new_scan_fgraph.toposort():
138+
if (
139+
isinstance(inner_node.op, Blockwise)
140+
and isinstance(inner_node.op.core_op, Solve)
141+
and inner_node.op.core_op.assume_a == "gen"
142+
):
143+
A, b = inner_node.inputs
144+
if all(
145+
(isinstance(root_inp, Constant) or (root_inp in non_sequences))
146+
for root_inp in graph_inputs([A])
147+
):
148+
if new_scan_fgraph is scan_op.fgraph:
149+
# Clone the first time to avoid mutating the original fgraph
150+
new_scan_fgraph, equiv = new_scan_fgraph.clone_get_equiv()
151+
non_sequences = {equiv[non_seq] for non_seq in non_sequences}
152+
inner_node = equiv[inner_node]
153+
154+
replace_dict = eager_split_lu_solve_steps.transform(
155+
new_scan_fgraph, inner_node
156+
)
157+
assert isinstance(replace_dict, dict) and len(replace_dict) > 0
158+
new_scan_fgraph.replace_all(replace_dict.items())
159+
changed = True
160+
break # Break to start over with a fresh toposort
161+
else: # no_break
162+
break # Nothing else changed
163+
164+
if not changed:
165+
return
166+
167+
# Return a new scan to indicate that a rewrite was done
168+
new_scan_op = copy(scan_op)
169+
new_scan_op.fgraph = new_scan_fgraph
170+
new_outs = new_scan_op.make_node(*node.inputs).outputs
171+
copy_stack_trace(node.outputs, new_outs)
172+
return new_outs
173+
174+
175+
# TODO: We need a seqopt that happens after stabilize (so that it's triggered before reasonable to include gradients rewrites (cough PyMC))
176+
scan_seqopt1.register(
177+
scan_split_non_sequence_lu_decomposition_solve.__name__,
178+
in2out(scan_split_non_sequence_lu_decomposition_solve, ignore_newtrees=True),
179+
"fast_run",
180+
"scan",
181+
"scan_pushout",
182+
position=2,
183+
)

pytensor/tensor/rewriting/linalg.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,13 @@ def is_matrix_transpose(x: TensorVariable) -> bool:
7575
if ndims < 2:
7676
return False
7777
transpose_order = (*range(ndims - 2), ndims - 1, ndims - 2)
78+
79+
# Allow expand_dims on the left of the transpose
80+
if (diff := len(transpose_order) - len(node.op.new_order)) > 0:
81+
transpose_order = (
82+
*(["x"] * diff),
83+
*transpose_order,
84+
)
7885
return node.op.new_order == transpose_order
7986
return False
8087

tests/tensor/linalg/__init__.py

Whitespace-only changes.

0 commit comments

Comments
 (0)