16
16
17
17
18
18
@numba_njit (inline = "always" )
19
- def A_to_banded (A : np .ndarray , kl : int , ku : int , order = "C" ) -> np .ndarray :
19
+ def A_to_banded (A : np .ndarray , kl : int , ku : int ) -> np .ndarray :
20
20
m , n = A .shape
21
- if order == "C" :
22
- A_banded = np . zeros (( kl + ku + 1 , n ), dtype = A . dtype )
23
- else :
24
- A_banded = np .zeros ((n , kl + ku + 1 ), dtype = A .dtype ).T
21
+
22
+ # This matrix is build backwards then transposed to get it into Fortran order
23
+ # (order="F" is not allowed in Numba land)
24
+ A_banded = np .zeros ((n , kl + ku + 1 ), dtype = A .dtype ).T
25
25
26
26
for i , k in enumerate (range (ku , - kl - 1 , - 1 )):
27
27
if k >= 0 :
@@ -39,7 +39,7 @@ def _dot_banded(A: np.ndarray, x: np.ndarray, kl: int, ku: int) -> Any:
39
39
"""
40
40
fn = linalg .get_blas_funcs ("gbmv" , (A , x ))
41
41
m , n = A .shape
42
- A_banded = A_to_banded (A , kl = kl , ku = ku , order = "F" )
42
+ A_banded = A_to_banded (A , kl = kl , ku = ku )
43
43
44
44
return fn (m = m , n = n , kl = kl , ku = ku , alpha = 1 , a = A_banded , x = x )
45
45
@@ -58,7 +58,7 @@ def dot_banded_impl(
58
58
def impl (A : np .ndarray , x : np .ndarray , kl : int , ku : int ) -> np .ndarray :
59
59
m , n = A .shape
60
60
61
- A_banded = A_to_banded (A , kl = kl , ku = ku , order = "F" )
61
+ A_banded = A_to_banded (A , kl = kl , ku = ku )
62
62
63
63
TRANS = val_to_int_ptr (ord ("N" ))
64
64
M = val_to_int_ptr (m )
0 commit comments