6
6
from numpy import ndarray
7
7
from scipy import linalg
8
8
9
+ from pytensor .link .numba .dispatch import numba_funcify
9
10
from pytensor .link .numba .dispatch .basic import numba_njit
10
11
from pytensor .link .numba .dispatch .linalg ._LAPACK import (
11
12
_LAPACK ,
20
21
_solve_check ,
21
22
_trans_char_to_int ,
22
23
)
24
+ from pytensor .tensor ._linalg .solve .tridiagonal import (
25
+ LUFactorTridiagonal ,
26
+ SolveLUFactorTridiagonal ,
27
+ )
23
28
24
29
25
30
@numba_njit
@@ -34,7 +39,12 @@ def tridiagonal_norm(du, d, dl):
34
39
35
40
36
41
def _gttrf (
37
- dl : ndarray , d : ndarray , du : ndarray
42
+ dl : ndarray ,
43
+ d : ndarray ,
44
+ du : ndarray ,
45
+ overwrite_dl : bool ,
46
+ overwrite_d : bool ,
47
+ overwrite_du : bool ,
38
48
) -> tuple [ndarray , ndarray , ndarray , ndarray , ndarray , int ]:
39
49
"""Placeholder for LU factorization of tridiagonal matrix."""
40
50
return # type: ignore
@@ -45,8 +55,12 @@ def gttrf_impl(
45
55
dl : ndarray ,
46
56
d : ndarray ,
47
57
du : ndarray ,
58
+ overwrite_dl : bool ,
59
+ overwrite_d : bool ,
60
+ overwrite_du : bool ,
48
61
) -> Callable [
49
- [ndarray , ndarray , ndarray ], tuple [ndarray , ndarray , ndarray , ndarray , ndarray , int ]
62
+ [ndarray , ndarray , ndarray , bool , bool , bool ],
63
+ tuple [ndarray , ndarray , ndarray , ndarray , ndarray , int ],
50
64
]:
51
65
ensure_lapack ()
52
66
_check_scipy_linalg_matrix (dl , "gttrf" )
@@ -60,12 +74,24 @@ def impl(
60
74
dl : ndarray ,
61
75
d : ndarray ,
62
76
du : ndarray ,
77
+ overwrite_dl : bool ,
78
+ overwrite_d : bool ,
79
+ overwrite_du : bool ,
63
80
) -> tuple [ndarray , ndarray , ndarray , ndarray , ndarray , int ]:
64
81
n = np .int32 (d .shape [- 1 ])
65
82
ipiv = np .empty (n , dtype = np .int32 )
66
83
du2 = np .empty (n - 2 , dtype = dtype )
67
84
info = val_to_int_ptr (0 )
68
85
86
+ if not overwrite_dl or not dl .flags .f_contiguous :
87
+ dl = dl .copy ()
88
+
89
+ if not overwrite_d or not d .flags .f_contiguous :
90
+ d = d .copy ()
91
+
92
+ if not overwrite_du or not du .flags .f_contiguous :
93
+ du = du .copy ()
94
+
69
95
numba_gttrf (
70
96
val_to_int_ptr (n ),
71
97
dl .view (w_type ).ctypes ,
@@ -133,10 +159,23 @@ def impl(
133
159
nrhs = 1 if b .ndim == 1 else int (b .shape [- 1 ])
134
160
info = val_to_int_ptr (0 )
135
161
136
- if overwrite_b and b .flags .f_contiguous :
137
- b_copy = b
138
- else :
139
- b_copy = _copy_to_fortran_order_even_if_1d (b )
162
+ if not overwrite_b or not b .flags .f_contiguous :
163
+ b = _copy_to_fortran_order_even_if_1d (b )
164
+
165
+ if not dl .flags .f_contiguous :
166
+ dl = dl .copy ()
167
+
168
+ if not d .flags .f_contiguous :
169
+ d = d .copy ()
170
+
171
+ if not du .flags .f_contiguous :
172
+ du = du .copy ()
173
+
174
+ if not du2 .flags .f_contiguous :
175
+ du2 = du2 .copy ()
176
+
177
+ if not ipiv .flags .f_contiguous :
178
+ ipiv = ipiv .copy ()
140
179
141
180
numba_gttrs (
142
181
val_to_int_ptr (_trans_char_to_int (trans )),
@@ -147,12 +186,12 @@ def impl(
147
186
du .view (w_type ).ctypes ,
148
187
du2 .view (w_type ).ctypes ,
149
188
ipiv .ctypes ,
150
- b_copy .view (w_type ).ctypes ,
189
+ b .view (w_type ).ctypes ,
151
190
val_to_int_ptr (n ),
152
191
info ,
153
192
)
154
193
155
- return b_copy , int_ptr_to_val (info )
194
+ return b , int_ptr_to_val (info )
156
195
157
196
return impl
158
197
@@ -283,7 +322,9 @@ def impl(
283
322
284
323
anorm = tridiagonal_norm (du , d , dl )
285
324
286
- dl , d , du , du2 , IPIV , INFO = _gttrf (dl , d , du )
325
+ dl , d , du , du2 , IPIV , INFO = _gttrf (
326
+ dl , d , du , overwrite_dl = True , overwrite_d = True , overwrite_du = True
327
+ )
287
328
_solve_check (n , INFO )
288
329
289
330
X , INFO = _gttrs (
@@ -297,3 +338,48 @@ def impl(
297
338
return X
298
339
299
340
return impl
341
+
342
+
343
+ @numba_funcify .register (LUFactorTridiagonal )
344
+ def numba_funcify_LUFactorTridiagonal (op : LUFactorTridiagonal , node , ** kwargs ):
345
+ overwrite_dl = op .overwrite_dl
346
+ overwrite_d = op .overwrite_d
347
+ overwrite_du = op .overwrite_du
348
+
349
+ @numba_njit (cache = False )
350
+ def lu_factor_tridiagonal (dl , d , du ):
351
+ dl , d , du , du2 , ipiv , _ = _gttrf (
352
+ dl ,
353
+ d ,
354
+ du ,
355
+ overwrite_dl = overwrite_dl ,
356
+ overwrite_d = overwrite_d ,
357
+ overwrite_du = overwrite_du ,
358
+ )
359
+ return dl , d , du , du2 , ipiv
360
+
361
+ return lu_factor_tridiagonal
362
+
363
+
364
+ @numba_funcify .register (SolveLUFactorTridiagonal )
365
+ def numba_funcify_SolveLUFactorTridiagonal (
366
+ op : SolveLUFactorTridiagonal , node , ** kwargs
367
+ ):
368
+ overwrite_b = op .overwrite_b
369
+ transposed = op .transposed
370
+
371
+ @numba_njit (cache = False )
372
+ def solve_lu_factor_tridiagonal (dl , d , du , du2 , ipiv , b ):
373
+ x , _ = _gttrs (
374
+ dl ,
375
+ d ,
376
+ du ,
377
+ du2 ,
378
+ ipiv ,
379
+ b ,
380
+ overwrite_b = overwrite_b ,
381
+ trans = transposed ,
382
+ )
383
+ return x
384
+
385
+ return solve_lu_factor_tridiagonal
0 commit comments