Skip to content

Commit 481814f

Browse files
Remove order argument from numba A_to_banded
1 parent 5754f93 commit 481814f

File tree

1 file changed

+7
-7
lines changed
  • pytensor/link/numba/dispatch/linalg/dot

1 file changed

+7
-7
lines changed

pytensor/link/numba/dispatch/linalg/dot/banded.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -16,12 +16,12 @@
1616

1717

1818
@numba_njit(inline="always")
19-
def A_to_banded(A: np.ndarray, kl: int, ku: int, order="C") -> np.ndarray:
19+
def A_to_banded(A: np.ndarray, kl: int, ku: int) -> np.ndarray:
2020
m, n = A.shape
21-
if order == "C":
22-
A_banded = np.zeros((kl + ku + 1, n), dtype=A.dtype)
23-
else:
24-
A_banded = np.zeros((n, kl + ku + 1), dtype=A.dtype).T
21+
22+
# This matrix is build backwards then transposed to get it into Fortran order
23+
# (order="F" is not allowed in Numba land)
24+
A_banded = np.zeros((n, kl + ku + 1), dtype=A.dtype).T
2525

2626
for i, k in enumerate(range(ku, -kl - 1, -1)):
2727
if k >= 0:
@@ -39,7 +39,7 @@ def _dot_banded(A: np.ndarray, x: np.ndarray, kl: int, ku: int) -> Any:
3939
"""
4040
fn = linalg.get_blas_funcs("gbmv", (A, x))
4141
m, n = A.shape
42-
A_banded = A_to_banded(A, kl=kl, ku=ku, order="F")
42+
A_banded = A_to_banded(A, kl=kl, ku=ku)
4343

4444
return fn(m=m, n=n, kl=kl, ku=ku, alpha=1, a=A_banded, x=x)
4545

@@ -58,7 +58,7 @@ def dot_banded_impl(
5858
def impl(A: np.ndarray, x: np.ndarray, kl: int, ku: int) -> np.ndarray:
5959
m, n = A.shape
6060

61-
A_banded = A_to_banded(A, kl=kl, ku=ku, order="F")
61+
A_banded = A_to_banded(A, kl=kl, ku=ku)
6262

6363
TRANS = val_to_int_ptr(ord("N"))
6464
M = val_to_int_ptr(m)

0 commit comments

Comments
 (0)