|
1 | 1 | import re
|
2 | 2 | from typing import Literal
|
| 3 | +from typing import cast as typing_cast |
3 | 4 |
|
4 | 5 | import numpy as np
|
5 | 6 | import pytest
|
@@ -724,25 +725,36 @@ def test_lu_solve(b_func, b_shape: tuple[int, ...], trans: bool, overwrite_b: bo
|
724 | 725 | np.testing.assert_allclose(b_val_not_contig, b_val)
|
725 | 726 |
|
726 | 727 |
|
727 |
| -@pytest.mark.parametrize("stride", [1, 2, -1, -2], ids=lambda x: f"stride={x}") |
728 |
| -def test_banded_dot(stride): |
| 728 | +def test_banded_dot(): |
729 | 729 | rng = np.random.default_rng()
|
730 | 730 |
|
| 731 | + A = pt.tensor("A", shape=(10, 10), dtype=config.floatX) |
731 | 732 | A_val = _make_banded_A(rng.normal(size=(10, 10)), kl=1, ku=1).astype(config.floatX)
|
732 | 733 |
|
733 |
| - x_shape = (10 * abs(stride),) |
734 |
| - x_val = rng.normal(size=x_shape).astype(config.floatX) |
735 |
| - x_val = x_val[::stride] |
736 |
| - |
737 |
| - A = pt.tensor("A", shape=A_val.shape, dtype=A_val.dtype) |
738 |
| - x = pt.tensor("x", shape=x_val.shape, dtype=x_val.dtype) |
| 734 | + x = pt.tensor("x", shape=(10,), dtype=config.floatX) |
| 735 | + x_val = rng.normal(size=(10,)).astype(config.floatX) |
739 | 736 |
|
740 | 737 | output = banded_dot(A, x, upper_diags=1, lower_diags=1)
|
741 | 738 |
|
742 |
| - compare_numba_and_py( |
| 739 | + fn, _ = compare_numba_and_py( |
743 | 740 | [A, x],
|
744 | 741 | output,
|
745 | 742 | test_inputs=[A_val, x_val],
|
746 | 743 | numba_mode=numba_inplace_mode,
|
747 | 744 | eval_obj_mode=False,
|
748 | 745 | )
|
| 746 | + |
| 747 | + for stride in [2, -1, -2]: |
| 748 | + x_shape = (10 * abs(stride),) |
| 749 | + x_val = rng.normal(size=x_shape).astype(config.floatX) |
| 750 | + x_val = x_val[::stride] |
| 751 | + |
| 752 | + nb_output = typing_cast(np.ndarray, fn(A_val, x_val)) |
| 753 | + expected = A_val @ x_val |
| 754 | + |
| 755 | + np.testing.assert_allclose( |
| 756 | + nb_output, |
| 757 | + expected, |
| 758 | + strict=True, |
| 759 | + err_msg=f"Test failed for stride = {stride}", |
| 760 | + ) |
0 commit comments