Skip to content

Commit c396a8a

Browse files
Adjust numba test
1 parent 30fece4 commit c396a8a

File tree

1 file changed

+21
-9
lines changed

1 file changed

+21
-9
lines changed

tests/link/numba/test_slinalg.py

Lines changed: 21 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import re
22
from typing import Literal
3+
from typing import cast as typing_cast
34

45
import numpy as np
56
import pytest
@@ -724,25 +725,36 @@ def test_lu_solve(b_func, b_shape: tuple[int, ...], trans: bool, overwrite_b: bo
724725
np.testing.assert_allclose(b_val_not_contig, b_val)
725726

726727

727-
@pytest.mark.parametrize("stride", [1, 2, -1, -2], ids=lambda x: f"stride={x}")
728-
def test_banded_dot(stride):
728+
def test_banded_dot():
729729
rng = np.random.default_rng()
730730

731+
A = pt.tensor("A", shape=(10, 10), dtype=config.floatX)
731732
A_val = _make_banded_A(rng.normal(size=(10, 10)), kl=1, ku=1).astype(config.floatX)
732733

733-
x_shape = (10 * abs(stride),)
734-
x_val = rng.normal(size=x_shape).astype(config.floatX)
735-
x_val = x_val[::stride]
736-
737-
A = pt.tensor("A", shape=A_val.shape, dtype=A_val.dtype)
738-
x = pt.tensor("x", shape=x_val.shape, dtype=x_val.dtype)
734+
x = pt.tensor("x", shape=(10,), dtype=config.floatX)
735+
x_val = rng.normal(size=(10,)).astype(config.floatX)
739736

740737
output = banded_dot(A, x, upper_diags=1, lower_diags=1)
741738

742-
compare_numba_and_py(
739+
fn, _ = compare_numba_and_py(
743740
[A, x],
744741
output,
745742
test_inputs=[A_val, x_val],
746743
numba_mode=numba_inplace_mode,
747744
eval_obj_mode=False,
748745
)
746+
747+
for stride in [2, -1, -2]:
748+
x_shape = (10 * abs(stride),)
749+
x_val = rng.normal(size=x_shape).astype(config.floatX)
750+
x_val = x_val[::stride]
751+
752+
nb_output = typing_cast(np.ndarray, fn(A_val, x_val))
753+
expected = A_val @ x_val
754+
755+
np.testing.assert_allclose(
756+
nb_output,
757+
expected,
758+
strict=True,
759+
err_msg=f"Test failed for stride = {stride}",
760+
)

0 commit comments

Comments
 (0)