Skip to content

Commit 30fece4

Browse files
Incorporate feedback
1 parent 481814f commit 30fece4

File tree

3 files changed

+5
-18
lines changed

3 files changed

+5
-18
lines changed

pytensor/link/numba/dispatch/linalg/dot/banded.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,7 @@ def impl(A: np.ndarray, x: np.ndarray, kl: int, ku: int) -> np.ndarray:
5959
m, n = A.shape
6060

6161
A_banded = A_to_banded(A, kl=kl, ku=ku)
62+
stride = x.strides[0] // x.itemsize
6263

6364
TRANS = val_to_int_ptr(ord("N"))
6465
M = val_to_int_ptr(m)
@@ -69,7 +70,8 @@ def impl(A: np.ndarray, x: np.ndarray, kl: int, ku: int) -> np.ndarray:
6970
KU = val_to_int_ptr(ku)
7071

7172
ALPHA = np.array(1.0, dtype=dtype)
72-
INCX = val_to_int_ptr(x.strides[0] // x.itemsize)
73+
74+
INCX = val_to_int_ptr(stride)
7375
BETA = np.array(0.0, dtype=dtype)
7476
Y = np.empty(m, dtype=dtype)
7577
INCY = val_to_int_ptr(1)

pytensor/tensor/slinalg.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1722,9 +1722,7 @@ def L_op(self, inputs, outputs, output_grads) -> list[Variable]:
17221722
(G_bar,) = output_grads
17231723

17241724
A_bar = pt.outer(G_bar, x.T)
1725-
x_bar = banded_dot(
1726-
A.T, G_bar, lower_diags=self.lower_diags, upper_diags=self.upper_diags
1727-
)
1725+
x_bar = self(A.T, G_bar)
17281726

17291727
return [A_bar, x_bar]
17301728

tests/link/numba/test_slinalg.py

Lines changed: 1 addition & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -724,7 +724,7 @@ def test_lu_solve(b_func, b_shape: tuple[int, ...], trans: bool, overwrite_b: bo
724724
np.testing.assert_allclose(b_val_not_contig, b_val)
725725

726726

727-
@pytest.mark.parametrize("stride", [1, 2, -1], ids=lambda x: f"stride={x}")
727+
@pytest.mark.parametrize("stride", [1, 2, -1, -2], ids=lambda x: f"stride={x}")
728728
def test_banded_dot(stride):
729729
rng = np.random.default_rng()
730730

@@ -743,19 +743,6 @@ def test_banded_dot(stride):
743743
[A, x],
744744
output,
745745
test_inputs=[A_val, x_val],
746-
inplace=True,
747-
numba_mode=numba_inplace_mode,
748-
eval_obj_mode=False,
749-
)
750-
751-
# Test non-contiguous x input
752-
x_val = rng.normal(size=(20,))[::2]
753-
754-
compare_numba_and_py(
755-
[A, x],
756-
output,
757-
test_inputs=[A_val, x_val],
758-
inplace=True,
759746
numba_mode=numba_inplace_mode,
760747
eval_obj_mode=False,
761748
)

0 commit comments

Comments
 (0)