Skip to content

Commit 2282161

Browse files
Implement suggestions
1 parent bbf3141 commit 2282161

File tree

2 files changed

+15
-26
lines changed

2 files changed

+15
-26
lines changed

pytensor/tensor/slinalg.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1699,6 +1699,10 @@ def _to_banded_form(A, kl, ku):
16991699
return ab
17001700

17011701

1702+
_dgbmv = scipy_linalg.get_blas_funcs("gbmv", dtype="float64")
1703+
_sgbmv = scipy_linalg.get_blas_funcs("gbmv", dtype="float32")
1704+
1705+
17021706
class BandedDot(Op):
17031707
__props__ = ("lower_diags", "upper_diags")
17041708
gufunc_signature = "(m,n),(n)->(n)"
@@ -1726,7 +1730,7 @@ def perform(self, node, inputs, outputs_storage):
17261730

17271731
A_banded = _to_banded_form(A, kl, ku)
17281732

1729-
fn = scipy_linalg.get_blas_funcs("gbmv", (A, b))
1733+
fn = _dgbmv if A.dtype == "float64" else _sgbmv
17301734
outputs_storage[0][0] = fn(m, n, kl, ku, alpha, A_banded, b)
17311735

17321736

tests/tensor/test_slinalg.py

Lines changed: 10 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -1081,42 +1081,22 @@ def test_banded_dot(A_shape, kl, ku):
10811081
res = banded_dot(A, b, kl, ku)
10821082
res_2 = A @ b
10831083

1084-
fn = function([A, b], [res, res_2])
1084+
fn = function([A, b], [res, res_2], trust_input=True)
10851085
assert any(isinstance(node.op, BandedDot) for node in fn.maker.fgraph.apply_nodes)
10861086

10871087
x_val, x2_val = fn(A_val, b_val)
10881088

10891089
np.testing.assert_allclose(x_val, x2_val)
10901090

10911091

1092+
@pytest.mark.parametrize("op", ["dot", "banded_dot"], ids=str)
10921093
@pytest.mark.parametrize(
10931094
"A_shape", [(10, 10), (100, 100), (1000, 1000)], ids=["10", "100", "1000"]
10941095
)
10951096
@pytest.mark.parametrize(
10961097
"kl, ku", [(1, 1), (0, 1), (2, 2)], ids=["tridiag", "upper-only", "banded"]
10971098
)
1098-
def test_banded_dot_perf(A_shape, kl, ku, benchmark):
1099-
rng = np.random.default_rng()
1100-
1101-
A_val = _make_banded_A(rng.normal(size=A_shape), kl=kl, ku=ku)
1102-
b_val = rng.normal(size=(A_shape[-1],))
1103-
1104-
A = pt.tensor("A", shape=A_val.shape, dtype=A_val.dtype)
1105-
b = pt.tensor("b", shape=b_val.shape, dtype=b_val.dtype)
1106-
1107-
res = banded_dot(A, b, kl, ku)
1108-
fn = function([A, b], res, trust_input=True)
1109-
1110-
benchmark(fn, A_val, b_val)
1111-
1112-
1113-
@pytest.mark.parametrize(
1114-
"A_shape", [(10, 10), (100, 100), (1000, 1000)], ids=["10", "100", "1000"]
1115-
)
1116-
@pytest.mark.parametrize(
1117-
"kl, ku", [(1, 1), (0, 1), (2, 2)], ids=["tridiag", "upper-only", "banded"]
1118-
)
1119-
def test_dot_perf(A_shape, kl, ku, benchmark):
1099+
def test_banded_dot_perf(op, A_shape, kl, ku, benchmark):
11201100
rng = np.random.default_rng()
11211101

11221102
A_val = _make_banded_A(rng.normal(size=A_shape), kl=kl, ku=ku)
@@ -1125,7 +1105,12 @@ def test_dot_perf(A_shape, kl, ku, benchmark):
11251105
A = pt.tensor("A", shape=A_val.shape)
11261106
b = pt.tensor("b", shape=b_val.shape)
11271107

1128-
res = A @ b
1129-
fn = function([A, b], res)
1108+
if op == "dot":
1109+
f = pt.dot
1110+
elif op == "banded_dot":
1111+
f = functools.partial(banded_dot, lower_diags=kl, upper_diags=ku)
1112+
1113+
res = f(A, b)
1114+
fn = function([A, b], res, trust_input=True)
11301115

11311116
benchmark(fn, A_val, b_val)

0 commit comments

Comments
 (0)