Skip to content

Commit 2e6c90f

Browse files
committed
Decompose Tridiagonal Solve into core steps
1 parent 28789ec commit 2e6c90f

File tree

8 files changed

+524
-32
lines changed

8 files changed

+524
-32
lines changed

pytensor/compile/mode.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -477,6 +477,9 @@ def clone(self, link_kwargs=None, optimizer="", **kwargs):
477477
"fusion",
478478
"inplace",
479479
"scan_save_mem_prealloc",
480+
# There are specific variants for the LU decompositions supported by JAX
481+
"reuse_lu_decomposition_multiple_solves",
482+
"scan_split_non_sequence_lu_decomposition_solve",
480483
],
481484
),
482485
)

pytensor/link/numba/dispatch/linalg/solve/tridiagonal.py

Lines changed: 93 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
from numpy import ndarray
77
from scipy import linalg
88

9+
from pytensor.link.numba.dispatch import numba_funcify
910
from pytensor.link.numba.dispatch.basic import numba_njit
1011
from pytensor.link.numba.dispatch.linalg._LAPACK import (
1112
_LAPACK,
@@ -20,6 +21,10 @@
2021
_solve_check,
2122
_trans_char_to_int,
2223
)
24+
from pytensor.tensor._linalg.solve.tridiagonal import (
25+
LUFactorTridiagonal,
26+
SolveLUFactorTridiagonal,
27+
)
2328

2429

2530
@numba_njit
@@ -34,7 +39,12 @@ def tridiagonal_norm(du, d, dl):
3439

3540

3641
def _gttrf(
37-
dl: ndarray, d: ndarray, du: ndarray
42+
dl: ndarray,
43+
d: ndarray,
44+
du: ndarray,
45+
overwrite_dl: bool,
46+
overwrite_d: bool,
47+
overwrite_du: bool,
3848
) -> tuple[ndarray, ndarray, ndarray, ndarray, ndarray, int]:
3949
"""Placeholder for LU factorization of tridiagonal matrix."""
4050
return # type: ignore
@@ -45,6 +55,9 @@ def gttrf_impl(
4555
dl: ndarray,
4656
d: ndarray,
4757
du: ndarray,
58+
overwrite_dl: bool,
59+
overwrite_d: bool,
60+
overwrite_du: bool,
4861
) -> Callable[
4962
[ndarray, ndarray, ndarray], tuple[ndarray, ndarray, ndarray, ndarray, ndarray, int]
5063
]:
@@ -60,12 +73,24 @@ def impl(
6073
dl: ndarray,
6174
d: ndarray,
6275
du: ndarray,
76+
overwrite_dl: bool,
77+
overwrite_d: bool,
78+
overwrite_du: bool,
6379
) -> tuple[ndarray, ndarray, ndarray, ndarray, ndarray, int]:
6480
n = np.int32(d.shape[-1])
6581
ipiv = np.empty(n, dtype=np.int32)
6682
du2 = np.empty(n - 2, dtype=dtype)
6783
info = val_to_int_ptr(0)
6884

85+
if not overwrite_dl or not dl.flags.f_contiguous:
86+
dl = dl.copy()
87+
88+
if not overwrite_d or not d.flags.f_contiguous:
89+
d = d.copy()
90+
91+
if not overwrite_du or not du.flags.f_contiguous:
92+
du = du.copy()
93+
6994
numba_gttrf(
7095
val_to_int_ptr(n),
7196
dl.view(w_type).ctypes,
@@ -133,10 +158,23 @@ def impl(
133158
nrhs = 1 if b.ndim == 1 else int(b.shape[-1])
134159
info = val_to_int_ptr(0)
135160

136-
if overwrite_b and b.flags.f_contiguous:
137-
b_copy = b
138-
else:
139-
b_copy = _copy_to_fortran_order_even_if_1d(b)
161+
if not overwrite_b or not b.flags.f_contiguous:
162+
b = _copy_to_fortran_order_even_if_1d(b)
163+
164+
if not dl.flags.f_contiguous:
165+
dl = dl.copy()
166+
167+
if not d.flags.f_contiguous:
168+
d = d.copy()
169+
170+
if not du.flags.f_contiguous:
171+
du = du.copy()
172+
173+
if not du2.flags.f_contiguous:
174+
du2 = du2.copy()
175+
176+
if not ipiv.flags.f_contiguous:
177+
ipiv = ipiv.copy()
140178

141179
numba_gttrs(
142180
val_to_int_ptr(_trans_char_to_int(trans)),
@@ -147,12 +185,12 @@ def impl(
147185
du.view(w_type).ctypes,
148186
du2.view(w_type).ctypes,
149187
ipiv.ctypes,
150-
b_copy.view(w_type).ctypes,
188+
b.view(w_type).ctypes,
151189
val_to_int_ptr(n),
152190
info,
153191
)
154192

155-
return b_copy, int_ptr_to_val(info)
193+
return b, int_ptr_to_val(info)
156194

157195
return impl
158196

@@ -283,7 +321,9 @@ def impl(
283321

284322
anorm = tridiagonal_norm(du, d, dl)
285323

286-
dl, d, du, du2, IPIV, INFO = _gttrf(dl, d, du)
324+
dl, d, du, du2, IPIV, INFO = _gttrf(
325+
dl, d, du, overwrite_dl=True, overwrite_d=True, overwrite_du=True
326+
)
287327
_solve_check(n, INFO)
288328

289329
X, INFO = _gttrs(
@@ -297,3 +337,48 @@ def impl(
297337
return X
298338

299339
return impl
340+
341+
342+
@numba_funcify.register(LUFactorTridiagonal)
343+
def numba_funcify_LUFactorTridiagonal(op: LUFactorTridiagonal, node, **kwargs):
344+
overwrite_dl = op.overwrite_dl
345+
overwrite_d = op.overwrite_d
346+
overwrite_du = op.overwrite_du
347+
348+
@numba_njit(cache=False)
349+
def lu_factor_tridiagonal(dl, d, du):
350+
dl, d, du, du2, ipiv, _ = _gttrf(
351+
dl,
352+
d,
353+
du,
354+
overwrite_dl=overwrite_dl,
355+
overwrite_d=overwrite_d,
356+
overwrite_du=overwrite_du,
357+
)
358+
return dl, d, du, du2, ipiv
359+
360+
return lu_factor_tridiagonal
361+
362+
363+
@numba_funcify.register(SolveLUFactorTridiagonal)
364+
def numba_funcify_SolveLUFactorTridiagonal(
365+
op: SolveLUFactorTridiagonal, node, **kwargs
366+
):
367+
overwrite_b = op.overwrite_b
368+
transposed = op.transposed
369+
370+
@numba_njit(cache=False)
371+
def solve_lu_factor_tridiagonal(dl, d, du, du2, ipiv, b):
372+
x, _ = _gttrs(
373+
dl,
374+
d,
375+
du,
376+
du2,
377+
ipiv,
378+
b,
379+
overwrite_b=overwrite_b,
380+
trans=transposed,
381+
)
382+
return x
383+
384+
return solve_lu_factor_tridiagonal

pytensor/tensor/_linalg/solve/rewriting.py

Lines changed: 55 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,15 @@
11
from collections.abc import Container
22
from copy import copy
33

4+
from pytensor.compile import optdb
45
from pytensor.graph import Constant, graph_inputs
56
from pytensor.graph.rewriting.basic import copy_stack_trace, in2out, node_rewriter
67
from pytensor.scan.op import Scan
78
from pytensor.scan.rewriting import scan_seqopt1
9+
from pytensor.tensor._linalg.solve.tridiagonal import (
10+
tridiagonal_lu_factor,
11+
tridiagonal_lu_solve,
12+
)
813
from pytensor.tensor.basic import atleast_Nd
914
from pytensor.tensor.blockwise import Blockwise
1015
from pytensor.tensor.elemwise import DimShuffle
@@ -17,18 +22,32 @@
1722
def decompose_A(A, assume_a, check_finite):
1823
if assume_a == "gen":
1924
return lu_factor(A, check_finite=check_finite)
25+
elif assume_a == "tridiagonal":
26+
# We didn't implement check_finite for tridiagonal LU factorization
27+
return tridiagonal_lu_factor(A)
2028
else:
2129
raise NotImplementedError
2230

2331

2432
def solve_lu_decomposed_system(A_decomp, b, transposed=False, *, core_solve_op: Solve):
25-
if core_solve_op.assume_a == "gen":
33+
b_ndim = core_solve_op.b_ndim
34+
check_finite = core_solve_op.check_finite
35+
assume_a = core_solve_op.assume_a
36+
if assume_a == "gen":
2637
return lu_solve(
2738
A_decomp,
2839
b,
40+
b_ndim=b_ndim,
2941
trans=transposed,
30-
b_ndim=core_solve_op.b_ndim,
31-
check_finite=core_solve_op.check_finite,
42+
check_finite=check_finite,
43+
)
44+
elif assume_a == "tridiagonal":
45+
# We didn't implement check_finite for tridiagonal LU solve
46+
return tridiagonal_lu_solve(
47+
A_decomp,
48+
b,
49+
b_ndim=b_ndim,
50+
transposed=transposed,
3251
)
3352
else:
3453
raise NotImplementedError
@@ -189,13 +208,15 @@ def _scan_split_non_sequence_lu_decomposition_solve(
189208
@register_specialize
190209
@node_rewriter([Blockwise])
191210
def reuse_lu_decomposition_multiple_solves(fgraph, node):
192-
return _split_lu_solve_steps(fgraph, node, eager=False, allowed_assume_a={"gen"})
211+
return _split_lu_solve_steps(
212+
fgraph, node, eager=False, allowed_assume_a={"gen", "tridiagonal"}
213+
)
193214

194215

195216
@node_rewriter([Scan])
196217
def scan_split_non_sequence_lu_decomposition_solve(fgraph, node):
197218
return _scan_split_non_sequence_lu_decomposition_solve(
198-
fgraph, node, allowed_assume_a={"gen"}
219+
fgraph, node, allowed_assume_a={"gen", "tridiagonal"}
199220
)
200221

201222

@@ -207,3 +228,32 @@ def scan_split_non_sequence_lu_decomposition_solve(fgraph, node):
207228
"scan_pushout",
208229
position=2,
209230
)
231+
232+
233+
@node_rewriter([Blockwise])
234+
def reuse_lu_decomposition_multiple_solves_jax(fgraph, node):
235+
return _split_lu_solve_steps(fgraph, node, eager=False, allowed_assume_a={"gen"})
236+
237+
238+
optdb["specialize"].register(
239+
reuse_lu_decomposition_multiple_solves_jax.__name__,
240+
in2out(reuse_lu_decomposition_multiple_solves_jax, ignore_newtrees=True),
241+
"jax",
242+
use_db_name_as_tag=False,
243+
)
244+
245+
246+
@node_rewriter([Scan])
247+
def scan_split_non_sequence_lu_decomposition_solve_jax(fgraph, node):
248+
return _scan_split_non_sequence_lu_decomposition_solve(
249+
fgraph, node, allowed_assume_a={"gen"}
250+
)
251+
252+
253+
scan_seqopt1.register(
254+
scan_split_non_sequence_lu_decomposition_solve_jax.__name__,
255+
in2out(scan_split_non_sequence_lu_decomposition_solve_jax, ignore_newtrees=True),
256+
"jax",
257+
use_db_name_as_tag=False,
258+
position=2,
259+
)

0 commit comments

Comments
 (0)