15
15
from pytensor .tensor .elemwise import DimShuffle
16
16
from pytensor .tensor .rewriting .basic import register_specialize
17
17
from pytensor .tensor .rewriting .linalg import is_matrix_transpose
18
- from pytensor .tensor .slinalg import Solve , lu_factor , lu_solve
18
+ from pytensor .tensor .slinalg import Solve , cho_solve , cholesky , lu_factor , lu_solve
19
19
from pytensor .tensor .variable import TensorVariable
20
20
21
21
22
- def decompose_A (A , assume_a , check_finite ):
22
+ def decompose_A (A , assume_a , check_finite , lower ):
23
23
if assume_a == "gen" :
24
24
return lu_factor (A , check_finite = check_finite )
25
25
elif assume_a == "tridiagonal" :
26
26
# We didn't implement check_finite for tridiagonal LU factorization
27
27
return tridiagonal_lu_factor (A )
28
+ elif assume_a == "pos" :
29
+ return cholesky (A , lower = lower , check_finite = check_finite )
28
30
else :
29
31
raise NotImplementedError
30
32
31
33
32
- def solve_lu_decomposed_system (A_decomp , b , transposed = False , * , core_solve_op : Solve ):
34
+ def solve_decomposed_system (
35
+ A_decomp , b , transposed = False , lower = False , * , core_solve_op : Solve
36
+ ):
33
37
b_ndim = core_solve_op .b_ndim
34
38
check_finite = core_solve_op .check_finite
35
39
assume_a = core_solve_op .assume_a
40
+
36
41
if assume_a == "gen" :
37
42
return lu_solve (
38
43
A_decomp ,
@@ -49,11 +54,19 @@ def solve_lu_decomposed_system(A_decomp, b, transposed=False, *, core_solve_op:
49
54
b_ndim = b_ndim ,
50
55
transposed = transposed ,
51
56
)
57
+ elif assume_a == "pos" :
58
+ # We can ignore the transposed argument here because A is symmetric by assumption
59
+ return cho_solve (
60
+ (A_decomp , lower ),
61
+ b ,
62
+ b_ndim = b_ndim ,
63
+ check_finite = check_finite ,
64
+ )
52
65
else :
53
66
raise NotImplementedError
54
67
55
68
56
- def _split_lu_solve_steps (
69
+ def _split_decomp_and_solve_steps (
57
70
fgraph , node , * , eager : bool , allowed_assume_a : Container [str ]
58
71
):
59
72
if not isinstance (node .op .core_op , Solve ):
@@ -133,13 +146,21 @@ def find_solve_clients(var, assume_a):
133
146
if client .op .core_op .check_finite :
134
147
check_finite_decomp = True
135
148
break
136
- A_decomp = decompose_A (A , assume_a = assume_a , check_finite = check_finite_decomp )
149
+
150
+ lower = node .op .core_op .lower
151
+ A_decomp = decompose_A (
152
+ A , assume_a = assume_a , check_finite = check_finite_decomp , lower = lower
153
+ )
137
154
138
155
replacements = {}
139
156
for client , transposed in A_solve_clients_and_transpose :
140
157
_ , b = client .inputs
141
- new_x = solve_lu_decomposed_system (
142
- A_decomp , b , transposed = transposed , core_solve_op = client .op .core_op
158
+ new_x = solve_decomposed_system (
159
+ A_decomp ,
160
+ b ,
161
+ transposed = transposed ,
162
+ lower = lower ,
163
+ core_solve_op = client .op .core_op ,
143
164
)
144
165
[old_x ] = client .outputs
145
166
new_x = atleast_Nd (new_x , n = old_x .type .ndim ).astype (old_x .type .dtype )
@@ -149,7 +170,7 @@ def find_solve_clients(var, assume_a):
149
170
return replacements
150
171
151
172
152
- def _scan_split_non_sequence_lu_decomposition_solve (
173
+ def _scan_split_non_sequence_decomposition_and_solve (
153
174
fgraph , node , * , allowed_assume_a : Container [str ]
154
175
):
155
176
"""If the A of a Solve within a Scan is a function of non-sequences, split the LU decomposition step.
@@ -179,7 +200,7 @@ def _scan_split_non_sequence_lu_decomposition_solve(
179
200
non_sequences = {equiv [non_seq ] for non_seq in non_sequences }
180
201
inner_node = equiv [inner_node ] # type: ignore
181
202
182
- replace_dict = _split_lu_solve_steps (
203
+ replace_dict = _split_decomp_and_solve_steps (
183
204
new_scan_fgraph ,
184
205
inner_node ,
185
206
eager = True ,
@@ -207,22 +228,22 @@ def _scan_split_non_sequence_lu_decomposition_solve(
207
228
208
229
@register_specialize
209
230
@node_rewriter ([Blockwise ])
210
- def reuse_lu_decomposition_multiple_solves (fgraph , node ):
211
- return _split_lu_solve_steps (
212
- fgraph , node , eager = False , allowed_assume_a = {"gen" , "tridiagonal" }
231
+ def reuse_decomposition_multiple_solves (fgraph , node ):
232
+ return _split_decomp_and_solve_steps (
233
+ fgraph , node , eager = False , allowed_assume_a = {"gen" , "tridiagonal" , "pos" }
213
234
)
214
235
215
236
216
237
@node_rewriter ([Scan ])
217
- def scan_split_non_sequence_lu_decomposition_solve (fgraph , node ):
218
- return _scan_split_non_sequence_lu_decomposition_solve (
219
- fgraph , node , allowed_assume_a = {"gen" , "tridiagonal" }
238
+ def scan_split_non_sequence_decomposition_and_solve (fgraph , node ):
239
+ return _scan_split_non_sequence_decomposition_and_solve (
240
+ fgraph , node , allowed_assume_a = {"gen" , "tridiagonal" , "pos" }
220
241
)
221
242
222
243
223
244
scan_seqopt1 .register (
224
- "scan_split_non_sequence_lu_decomposition_solve" ,
225
- in2out (scan_split_non_sequence_lu_decomposition_solve , ignore_newtrees = True ),
245
+ scan_split_non_sequence_decomposition_and_solve . __name__ ,
246
+ in2out (scan_split_non_sequence_decomposition_and_solve , ignore_newtrees = True ),
226
247
"fast_run" ,
227
248
"scan" ,
228
249
"scan_pushout" ,
@@ -231,28 +252,30 @@ def scan_split_non_sequence_lu_decomposition_solve(fgraph, node):
231
252
232
253
233
254
@node_rewriter ([Blockwise ])
234
- def reuse_lu_decomposition_multiple_solves_jax (fgraph , node ):
235
- return _split_lu_solve_steps (fgraph , node , eager = False , allowed_assume_a = {"gen" })
255
+ def reuse_decomposition_multiple_solves_jax (fgraph , node ):
256
+ return _split_decomp_and_solve_steps (
257
+ fgraph , node , eager = False , allowed_assume_a = {"gen" , "pos" }
258
+ )
236
259
237
260
238
261
optdb ["specialize" ].register (
239
- reuse_lu_decomposition_multiple_solves_jax .__name__ ,
240
- in2out (reuse_lu_decomposition_multiple_solves_jax , ignore_newtrees = True ),
262
+ reuse_decomposition_multiple_solves_jax .__name__ ,
263
+ in2out (reuse_decomposition_multiple_solves_jax , ignore_newtrees = True ),
241
264
"jax" ,
242
265
use_db_name_as_tag = False ,
243
266
)
244
267
245
268
246
269
@node_rewriter ([Scan ])
247
- def scan_split_non_sequence_lu_decomposition_solve_jax (fgraph , node ):
248
- return _scan_split_non_sequence_lu_decomposition_solve (
249
- fgraph , node , allowed_assume_a = {"gen" }
270
+ def scan_split_non_sequence_decomposition_and_solve_jax (fgraph , node ):
271
+ return _scan_split_non_sequence_decomposition_and_solve (
272
+ fgraph , node , allowed_assume_a = {"gen" , "pos" }
250
273
)
251
274
252
275
253
276
scan_seqopt1 .register (
254
- scan_split_non_sequence_lu_decomposition_solve_jax .__name__ ,
255
- in2out (scan_split_non_sequence_lu_decomposition_solve_jax , ignore_newtrees = True ),
277
+ scan_split_non_sequence_decomposition_and_solve_jax .__name__ ,
278
+ in2out (scan_split_non_sequence_decomposition_and_solve_jax , ignore_newtrees = True ),
256
279
"jax" ,
257
280
use_db_name_as_tag = False ,
258
281
position = 2 ,
0 commit comments