From bbf3141d31d5ed913d7a5226283fa72b50cc6574 Mon Sep 17 00:00:00 2001 From: jessegrabowski Date: Fri, 23 May 2025 16:38:14 +0800 Subject: [PATCH 01/25] Naive implementation, do not merge Co-authored-by: Ricardo Vieira <28983449+ricardoV94@users.noreply.github.com> --- pytensor/tensor/slinalg.py | 90 ++++++++++++++++++++++++++++++++++++ tests/tensor/test_slinalg.py | 78 +++++++++++++++++++++++++++++++ 2 files changed, 168 insertions(+) diff --git a/pytensor/tensor/slinalg.py b/pytensor/tensor/slinalg.py index bbdc9cbba7..ff94170386 100644 --- a/pytensor/tensor/slinalg.py +++ b/pytensor/tensor/slinalg.py @@ -1669,6 +1669,95 @@ def block_diag(*matrices: TensorVariable): return _block_diagonal_matrix(*matrices) +def _to_banded_form(A, kl, ku): + """ + Convert a full matrix A to LAPACK banded form for gbmv. + + Parameters + ---------- + A: np.ndarray + (m, n) banded matrix with nonzero values on the diagonals + kl: int + Number of nonzero lower diagonals of A + ku: int + Number of nonzero upper diagonals of A + + Returns + ------- + ab: np.ndarray + (kl + ku + 1, n) banded matrix suitable for LAPACK + """ + A = np.asarray(A) + m, n = A.shape + ab = np.zeros((kl + ku + 1, n), dtype=A.dtype, order="C") + + for i, k in enumerate(range(ku, -kl - 1, -1)): + padding = (k, 0) if k >= 0 else (0, -k) + diag = np.pad(np.diag(A, k=k), padding) + ab[i, :] = diag + + return ab + + +class BandedDot(Op): + __props__ = ("lower_diags", "upper_diags") + gufunc_signature = "(m,n),(n)->(n)" + + def __init__(self, lower_diags, upper_diags): + self.lower_diags = lower_diags + self.upper_diags = upper_diags + + def make_node(self, A, b): + A = as_tensor_variable(A) + B = as_tensor_variable(b) + + out_dtype = pytensor.scalar.upcast(A.dtype, B.dtype) + output = b.type().astype(out_dtype) + + return pytensor.graph.basic.Apply(self, [A, B], [output]) + + def perform(self, node, inputs, outputs_storage): + A, b = inputs + m, n = A.shape + alpha = 1 + + kl = self.lower_diags + ku = self.upper_diags + + A_banded = _to_banded_form(A, kl, ku) + + fn = scipy_linalg.get_blas_funcs("gbmv", (A, b)) + outputs_storage[0][0] = fn(m, n, kl, ku, alpha, A_banded, b) + + +def banded_dot(A: TensorLike, b: TensorLike, lower_diags: int, upper_diags: int): + """ + Specialized matrix-vector multiplication for cases when A is a banded matrix + + No type-checking is done on A at runtime, so all data in A off the banded diagonals will be ignored. This will lead + to incorrect results if A is not actually a banded matrix. + + Unlike dot, this function is only valid if b is a vector. + + Parameters + ---------- + A: Tensorlike + Matrix to perform banded dot on. + b: Tensorlike + Vector to perform banded dot on. 
+ lower_diags: int + Number of nonzero lower diagonals of A + upper_diags: int + Number of nonzero upper diagonals of A + + Returns + ------- + out: Tensor + The matrix multiplication result + """ + return Blockwise(BandedDot(lower_diags, upper_diags))(A, b) + + __all__ = [ "cholesky", "solve", @@ -1683,4 +1772,5 @@ def block_diag(*matrices: TensorVariable): "lu", "lu_factor", "lu_solve", + "banded_dot", ] diff --git a/tests/tensor/test_slinalg.py b/tests/tensor/test_slinalg.py index f18f514244..98e5464402 100644 --- a/tests/tensor/test_slinalg.py +++ b/tests/tensor/test_slinalg.py @@ -12,11 +12,13 @@ from pytensor.graph.basic import equal_computations from pytensor.tensor import TensorVariable from pytensor.tensor.slinalg import ( + BandedDot, Cholesky, CholeskySolve, Solve, SolveBase, SolveTriangular, + banded_dot, block_diag, cho_solve, cholesky, @@ -1051,3 +1053,79 @@ def test_block_diagonal_blockwise(): B = np.random.normal(size=(1, batch_size, 4, 4)).astype(config.floatX) result = block_diag(A, B).eval() assert result.shape == (10, batch_size, 6, 6) + + +def _make_banded_A(A, kl, ku): + diag_idxs = range(-kl, ku + 1) + diags = (np.diag(A, k=k) for k in diag_idxs) + return sum(np.diag(d, k=k) for k, d in zip(diag_idxs, diags)) + + +@pytest.mark.parametrize( + "A_shape", + [ + (10, 10), + ], +) +@pytest.mark.parametrize( + "kl, ku", [(1, 1), (0, 1), (2, 2)], ids=["tridiag", "upper-only", "banded"] +) +def test_banded_dot(A_shape, kl, ku): + rng = np.random.default_rng() + + A_val = _make_banded_A(rng.normal(size=A_shape), kl=kl, ku=ku) + b_val = rng.normal(size=(A_shape[-1],)) + + A = pt.tensor("A", shape=A_val.shape) + b = pt.tensor("b", shape=b_val.shape) + res = banded_dot(A, b, kl, ku) + res_2 = A @ b + + fn = function([A, b], [res, res_2]) + assert any(isinstance(node.op, BandedDot) for node in fn.maker.fgraph.apply_nodes) + + x_val, x2_val = fn(A_val, b_val) + + np.testing.assert_allclose(x_val, x2_val) + + +@pytest.mark.parametrize( + "A_shape", [(10, 10), (100, 100), (1000, 1000)], ids=["10", "100", "1000"] +) +@pytest.mark.parametrize( + "kl, ku", [(1, 1), (0, 1), (2, 2)], ids=["tridiag", "upper-only", "banded"] +) +def test_banded_dot_perf(A_shape, kl, ku, benchmark): + rng = np.random.default_rng() + + A_val = _make_banded_A(rng.normal(size=A_shape), kl=kl, ku=ku) + b_val = rng.normal(size=(A_shape[-1],)) + + A = pt.tensor("A", shape=A_val.shape, dtype=A_val.dtype) + b = pt.tensor("b", shape=b_val.shape, dtype=b_val.dtype) + + res = banded_dot(A, b, kl, ku) + fn = function([A, b], res, trust_input=True) + + benchmark(fn, A_val, b_val) + + +@pytest.mark.parametrize( + "A_shape", [(10, 10), (100, 100), (1000, 1000)], ids=["10", "100", "1000"] +) +@pytest.mark.parametrize( + "kl, ku", [(1, 1), (0, 1), (2, 2)], ids=["tridiag", "upper-only", "banded"] +) +def test_dot_perf(A_shape, kl, ku, benchmark): + rng = np.random.default_rng() + + A_val = _make_banded_A(rng.normal(size=A_shape), kl=kl, ku=ku) + b_val = rng.normal(size=(A_shape[-1],)) + + A = pt.tensor("A", shape=A_val.shape) + b = pt.tensor("b", shape=b_val.shape) + + res = A @ b + fn = function([A, b], res) + + benchmark(fn, A_val, b_val) From 22821616bdabf304862021aff2e25dd5f18ae913 Mon Sep 17 00:00:00 2001 From: Jesse Grabowski Date: Fri, 23 May 2025 17:41:09 +0800 Subject: [PATCH 02/25] Implement suggestions --- pytensor/tensor/slinalg.py | 6 +++++- tests/tensor/test_slinalg.py | 35 ++++++++++------------------------- 2 files changed, 15 insertions(+), 26 deletions(-) diff --git a/pytensor/tensor/slinalg.py 
b/pytensor/tensor/slinalg.py index ff94170386..74c592afb9 100644 --- a/pytensor/tensor/slinalg.py +++ b/pytensor/tensor/slinalg.py @@ -1699,6 +1699,10 @@ def _to_banded_form(A, kl, ku): return ab +_dgbmv = scipy_linalg.get_blas_funcs("gbmv", dtype="float64") +_sgbmv = scipy_linalg.get_blas_funcs("gbmv", dtype="float32") + + class BandedDot(Op): __props__ = ("lower_diags", "upper_diags") gufunc_signature = "(m,n),(n)->(n)" @@ -1726,7 +1730,7 @@ def perform(self, node, inputs, outputs_storage): A_banded = _to_banded_form(A, kl, ku) - fn = scipy_linalg.get_blas_funcs("gbmv", (A, b)) + fn = _dgbmv if A.dtype == "float64" else _sgbmv outputs_storage[0][0] = fn(m, n, kl, ku, alpha, A_banded, b) diff --git a/tests/tensor/test_slinalg.py b/tests/tensor/test_slinalg.py index 98e5464402..ba4023463b 100644 --- a/tests/tensor/test_slinalg.py +++ b/tests/tensor/test_slinalg.py @@ -1081,7 +1081,7 @@ def test_banded_dot(A_shape, kl, ku): res = banded_dot(A, b, kl, ku) res_2 = A @ b - fn = function([A, b], [res, res_2]) + fn = function([A, b], [res, res_2], trust_input=True) assert any(isinstance(node.op, BandedDot) for node in fn.maker.fgraph.apply_nodes) x_val, x2_val = fn(A_val, b_val) @@ -1089,34 +1089,14 @@ def test_banded_dot(A_shape, kl, ku): np.testing.assert_allclose(x_val, x2_val) +@pytest.mark.parametrize("op", ["dot", "banded_dot"], ids=str) @pytest.mark.parametrize( "A_shape", [(10, 10), (100, 100), (1000, 1000)], ids=["10", "100", "1000"] ) @pytest.mark.parametrize( "kl, ku", [(1, 1), (0, 1), (2, 2)], ids=["tridiag", "upper-only", "banded"] ) -def test_banded_dot_perf(A_shape, kl, ku, benchmark): - rng = np.random.default_rng() - - A_val = _make_banded_A(rng.normal(size=A_shape), kl=kl, ku=ku) - b_val = rng.normal(size=(A_shape[-1],)) - - A = pt.tensor("A", shape=A_val.shape, dtype=A_val.dtype) - b = pt.tensor("b", shape=b_val.shape, dtype=b_val.dtype) - - res = banded_dot(A, b, kl, ku) - fn = function([A, b], res, trust_input=True) - - benchmark(fn, A_val, b_val) - - -@pytest.mark.parametrize( - "A_shape", [(10, 10), (100, 100), (1000, 1000)], ids=["10", "100", "1000"] -) -@pytest.mark.parametrize( - "kl, ku", [(1, 1), (0, 1), (2, 2)], ids=["tridiag", "upper-only", "banded"] -) -def test_dot_perf(A_shape, kl, ku, benchmark): +def test_banded_dot_perf(op, A_shape, kl, ku, benchmark): rng = np.random.default_rng() A_val = _make_banded_A(rng.normal(size=A_shape), kl=kl, ku=ku) @@ -1125,7 +1105,12 @@ def test_dot_perf(A_shape, kl, ku, benchmark): A = pt.tensor("A", shape=A_val.shape) b = pt.tensor("b", shape=b_val.shape) - res = A @ b - fn = function([A, b], res) + if op == "dot": + f = pt.dot + elif op == "banded_dot": + f = functools.partial(banded_dot, lower_diags=kl, upper_diags=ku) + + res = f(A, b) + fn = function([A, b], res, trust_input=True) benchmark(fn, A_val, b_val) From ae8eff67a016028e864826cbd80fe74957202874 Mon Sep 17 00:00:00 2001 From: Jesse Grabowski Date: Fri, 23 May 2025 17:45:58 +0800 Subject: [PATCH 03/25] Simplify perf test --- tests/tensor/test_slinalg.py | 13 ++++++------- 1 file changed, 6 insertions(+), 7 deletions(-) diff --git a/tests/tensor/test_slinalg.py b/tests/tensor/test_slinalg.py index ba4023463b..cf0b138163 100644 --- a/tests/tensor/test_slinalg.py +++ b/tests/tensor/test_slinalg.py @@ -1091,15 +1091,14 @@ def test_banded_dot(A_shape, kl, ku): @pytest.mark.parametrize("op", ["dot", "banded_dot"], ids=str) @pytest.mark.parametrize( - "A_shape", [(10, 10), (100, 100), (1000, 1000)], ids=["10", "100", "1000"] -) -@pytest.mark.parametrize( - "kl, ku", [(1, 1), 
(0, 1), (2, 2)], ids=["tridiag", "upper-only", "banded"] + "A_shape", + [(10, 10), (100, 100), (1000, 1000), (10_000, 10_000)], + ids=["10", "100", "1000", "10_000"], ) -def test_banded_dot_perf(op, A_shape, kl, ku, benchmark): +def test_banded_dot_perf(op, A_shape, benchmark): rng = np.random.default_rng() - A_val = _make_banded_A(rng.normal(size=A_shape), kl=kl, ku=ku) + A_val = _make_banded_A(rng.normal(size=A_shape), kl=1, ku=1) b_val = rng.normal(size=(A_shape[-1],)) A = pt.tensor("A", shape=A_val.shape) @@ -1108,7 +1107,7 @@ def test_banded_dot_perf(op, A_shape, kl, ku, benchmark): if op == "dot": f = pt.dot elif op == "banded_dot": - f = functools.partial(banded_dot, lower_diags=kl, upper_diags=ku) + f = functools.partial(banded_dot, lower_diags=1, upper_diags=1) res = f(A, b) fn = function([A, b], res, trust_input=True) From b2f68a525ee09981a8a31923673b27537f94a7ab Mon Sep 17 00:00:00 2001 From: jessegrabowski Date: Fri, 23 May 2025 17:55:56 +0800 Subject: [PATCH 04/25] float32 compat in tests --- tests/tensor/test_slinalg.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/tests/tensor/test_slinalg.py b/tests/tensor/test_slinalg.py index cf0b138163..f8bd7d8e43 100644 --- a/tests/tensor/test_slinalg.py +++ b/tests/tensor/test_slinalg.py @@ -1073,11 +1073,11 @@ def _make_banded_A(A, kl, ku): def test_banded_dot(A_shape, kl, ku): rng = np.random.default_rng() - A_val = _make_banded_A(rng.normal(size=A_shape), kl=kl, ku=ku) - b_val = rng.normal(size=(A_shape[-1],)) + A_val = _make_banded_A(rng.normal(size=A_shape), kl=kl, ku=ku).astype(config.floatX) + b_val = rng.normal(size=(A_shape[-1],)).astype(config.floatX) - A = pt.tensor("A", shape=A_val.shape) - b = pt.tensor("b", shape=b_val.shape) + A = pt.tensor("A", shape=A_val.shape, dtype=A_val.dtype) + b = pt.tensor("b", shape=b_val.shape, dtype=b_val.dtype) res = banded_dot(A, b, kl, ku) res_2 = A @ b @@ -1098,11 +1098,11 @@ def test_banded_dot(A_shape, kl, ku): def test_banded_dot_perf(op, A_shape, benchmark): rng = np.random.default_rng() - A_val = _make_banded_A(rng.normal(size=A_shape), kl=1, ku=1) - b_val = rng.normal(size=(A_shape[-1],)) + A_val = _make_banded_A(rng.normal(size=A_shape), kl=1, ku=1).astype(config.floatX) + b_val = rng.normal(size=(A_shape[-1],)).astype(config.floatX) - A = pt.tensor("A", shape=A_val.shape) - b = pt.tensor("b", shape=b_val.shape) + A = pt.tensor("A", shape=A_val.shape, dtype=A_val.dtype) + b = pt.tensor("b", shape=b_val.shape, dtype=A_val.dtype) if op == "dot": f = pt.dot From e64d4d30414c2738d30e09d516d701c6fb386874 Mon Sep 17 00:00:00 2001 From: Jesse Grabowski Date: Fri, 23 May 2025 18:21:13 +0800 Subject: [PATCH 05/25] Remove np.pad --- pytensor/tensor/slinalg.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/pytensor/tensor/slinalg.py b/pytensor/tensor/slinalg.py index 74c592afb9..f89b4ae4fc 100644 --- a/pytensor/tensor/slinalg.py +++ b/pytensor/tensor/slinalg.py @@ -1692,9 +1692,8 @@ def _to_banded_form(A, kl, ku): ab = np.zeros((kl + ku + 1, n), dtype=A.dtype, order="C") for i, k in enumerate(range(ku, -kl - 1, -1)): - padding = (k, 0) if k >= 0 else (0, -k) - diag = np.pad(np.diag(A, k=k), padding) - ab[i, :] = diag + col_slice = slice(k, None) if k >= 0 else slice(None, n + k) + ab[i, col_slice] = np.diag(A, k=k) return ab From c979e9d532272607fe90631160ae1c5c72a7ef1f Mon Sep 17 00:00:00 2001 From: jessegrabowski Date: Fri, 23 May 2025 18:31:54 +0800 Subject: [PATCH 06/25] set dtype correctly --- pytensor/tensor/slinalg.py | 2 +- 
1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytensor/tensor/slinalg.py b/pytensor/tensor/slinalg.py index f89b4ae4fc..9fbf6ead54 100644 --- a/pytensor/tensor/slinalg.py +++ b/pytensor/tensor/slinalg.py @@ -1715,7 +1715,7 @@ def make_node(self, A, b): B = as_tensor_variable(b) out_dtype = pytensor.scalar.upcast(A.dtype, B.dtype) - output = b.type().astype(out_dtype) + output = b.type.clone(dtype=out_dtype)() return pytensor.graph.basic.Apply(self, [A, B], [output]) From f1066a9c5c2b315237beba4a7a2373c1fc978010 Mon Sep 17 00:00:00 2001 From: jessegrabowski Date: Fri, 23 May 2025 18:57:07 +0800 Subject: [PATCH 07/25] fix signature, add infer_shape --- pytensor/tensor/slinalg.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/pytensor/tensor/slinalg.py b/pytensor/tensor/slinalg.py index 9fbf6ead54..11c1b7ae0f 100644 --- a/pytensor/tensor/slinalg.py +++ b/pytensor/tensor/slinalg.py @@ -1704,7 +1704,7 @@ def _to_banded_form(A, kl, ku): class BandedDot(Op): __props__ = ("lower_diags", "upper_diags") - gufunc_signature = "(m,n),(n)->(n)" + gufunc_signature = "(m,n),(n)->(m)" def __init__(self, lower_diags, upper_diags): self.lower_diags = lower_diags @@ -1719,6 +1719,9 @@ def make_node(self, A, b): return pytensor.graph.basic.Apply(self, [A, B], [output]) + def infer_shape(self, fgraph, nodes, shapes): + return [shapes[0][:-1]] + def perform(self, node, inputs, outputs_storage): A, b = inputs m, n = A.shape From 0302faca68e582bd18b5ff000c57da1bb9e92f3f Mon Sep 17 00:00:00 2001 From: jessegrabowski Date: Fri, 23 May 2025 18:57:25 +0800 Subject: [PATCH 08/25] micro-optimizations --- pytensor/tensor/slinalg.py | 44 ++++++-------------------------------- 1 file changed, 7 insertions(+), 37 deletions(-) diff --git a/pytensor/tensor/slinalg.py b/pytensor/tensor/slinalg.py index 11c1b7ae0f..fc259ece12 100644 --- a/pytensor/tensor/slinalg.py +++ b/pytensor/tensor/slinalg.py @@ -6,6 +6,7 @@ import numpy as np import scipy.linalg as scipy_linalg +from numpy import diag, zeros from numpy.exceptions import ComplexWarning import pytensor @@ -1669,39 +1670,6 @@ def block_diag(*matrices: TensorVariable): return _block_diagonal_matrix(*matrices) -def _to_banded_form(A, kl, ku): - """ - Convert a full matrix A to LAPACK banded form for gbmv. 
- - Parameters - ---------- - A: np.ndarray - (m, n) banded matrix with nonzero values on the diagonals - kl: int - Number of nonzero lower diagonals of A - ku: int - Number of nonzero upper diagonals of A - - Returns - ------- - ab: np.ndarray - (kl + ku + 1, n) banded matrix suitable for LAPACK - """ - A = np.asarray(A) - m, n = A.shape - ab = np.zeros((kl + ku + 1, n), dtype=A.dtype, order="C") - - for i, k in enumerate(range(ku, -kl - 1, -1)): - col_slice = slice(k, None) if k >= 0 else slice(None, n + k) - ab[i, col_slice] = np.diag(A, k=k) - - return ab - - -_dgbmv = scipy_linalg.get_blas_funcs("gbmv", dtype="float64") -_sgbmv = scipy_linalg.get_blas_funcs("gbmv", dtype="float32") - - class BandedDot(Op): __props__ = ("lower_diags", "upper_diags") gufunc_signature = "(m,n),(n)->(m)" @@ -1725,15 +1693,17 @@ def infer_shape(self, fgraph, nodes, shapes): def perform(self, node, inputs, outputs_storage): A, b = inputs m, n = A.shape - alpha = 1 kl = self.lower_diags ku = self.upper_diags - A_banded = _to_banded_form(A, kl, ku) + A_banded = zeros((kl + ku + 1, n), dtype=A.dtype, order="C") + + for i, k in enumerate(range(ku, -kl - 1, -1)): + A_banded[i, slice(k, None) if k >= 0 else slice(None, n + k)] = diag(A, k=k) - fn = _dgbmv if A.dtype == "float64" else _sgbmv - outputs_storage[0][0] = fn(m, n, kl, ku, alpha, A_banded, b) + fn = scipy_linalg.get_blas_funcs("gbmv", dtype=A.dtype) + outputs_storage[0][0] = fn(m=m, n=n, kl=kl, ku=ku, alpha=1, a=A_banded, x=b) def banded_dot(A: TensorLike, b: TensorLike, lower_diags: int, upper_diags: int): From f47d88b0b8134e5eec9875730ed8e077d9d2c85a Mon Sep 17 00:00:00 2001 From: jessegrabowski Date: Sat, 24 May 2025 12:41:29 +0800 Subject: [PATCH 09/25] Rename b to x, matching BLAS docs --- pytensor/tensor/slinalg.py | 20 ++++++++++---------- tests/tensor/test_slinalg.py | 24 ++++++++++++------------ 2 files changed, 22 insertions(+), 22 deletions(-) diff --git a/pytensor/tensor/slinalg.py b/pytensor/tensor/slinalg.py index fc259ece12..bb9687f54a 100644 --- a/pytensor/tensor/slinalg.py +++ b/pytensor/tensor/slinalg.py @@ -1678,20 +1678,20 @@ def __init__(self, lower_diags, upper_diags): self.lower_diags = lower_diags self.upper_diags = upper_diags - def make_node(self, A, b): + def make_node(self, A, x): A = as_tensor_variable(A) - B = as_tensor_variable(b) + x = as_tensor_variable(x) - out_dtype = pytensor.scalar.upcast(A.dtype, B.dtype) - output = b.type.clone(dtype=out_dtype)() + out_dtype = pytensor.scalar.upcast(A.dtype, x.dtype) + output = x.type.clone(dtype=out_dtype)() - return pytensor.graph.basic.Apply(self, [A, B], [output]) + return pytensor.graph.basic.Apply(self, [A, x], [output]) def infer_shape(self, fgraph, nodes, shapes): return [shapes[0][:-1]] def perform(self, node, inputs, outputs_storage): - A, b = inputs + A, x = inputs m, n = A.shape kl = self.lower_diags @@ -1703,10 +1703,10 @@ def perform(self, node, inputs, outputs_storage): A_banded[i, slice(k, None) if k >= 0 else slice(None, n + k)] = diag(A, k=k) fn = scipy_linalg.get_blas_funcs("gbmv", dtype=A.dtype) - outputs_storage[0][0] = fn(m=m, n=n, kl=kl, ku=ku, alpha=1, a=A_banded, x=b) + outputs_storage[0][0] = fn(m=m, n=n, kl=kl, ku=ku, alpha=1, a=A_banded, x=x) -def banded_dot(A: TensorLike, b: TensorLike, lower_diags: int, upper_diags: int): +def banded_dot(A: TensorLike, x: TensorLike, lower_diags: int, upper_diags: int): """ Specialized matrix-vector multiplication for cases when A is a banded matrix @@ -1719,7 +1719,7 @@ def banded_dot(A: TensorLike, b: TensorLike, 
lower_diags: int, upper_diags: int) ---------- A: Tensorlike Matrix to perform banded dot on. - b: Tensorlike + x: Tensorlike Vector to perform banded dot on. lower_diags: int Number of nonzero lower diagonals of A @@ -1731,7 +1731,7 @@ def banded_dot(A: TensorLike, b: TensorLike, lower_diags: int, upper_diags: int) out: Tensor The matrix multiplication result """ - return Blockwise(BandedDot(lower_diags, upper_diags))(A, b) + return Blockwise(BandedDot(lower_diags, upper_diags))(A, x) __all__ = [ diff --git a/tests/tensor/test_slinalg.py b/tests/tensor/test_slinalg.py index f8bd7d8e43..9fa8df768e 100644 --- a/tests/tensor/test_slinalg.py +++ b/tests/tensor/test_slinalg.py @@ -1074,19 +1074,19 @@ def test_banded_dot(A_shape, kl, ku): rng = np.random.default_rng() A_val = _make_banded_A(rng.normal(size=A_shape), kl=kl, ku=ku).astype(config.floatX) - b_val = rng.normal(size=(A_shape[-1],)).astype(config.floatX) + x_val = rng.normal(size=(A_shape[-1],)).astype(config.floatX) A = pt.tensor("A", shape=A_val.shape, dtype=A_val.dtype) - b = pt.tensor("b", shape=b_val.shape, dtype=b_val.dtype) - res = banded_dot(A, b, kl, ku) - res_2 = A @ b + x = pt.tensor("x", shape=x_val.shape, dtype=x_val.dtype) + res = banded_dot(A, x, kl, ku) + res_2 = A @ x - fn = function([A, b], [res, res_2], trust_input=True) + fn = function([A, x], [res, res_2], trust_input=True) assert any(isinstance(node.op, BandedDot) for node in fn.maker.fgraph.apply_nodes) - x_val, x2_val = fn(A_val, b_val) + out_val, out_2_val = fn(A_val, x_val) - np.testing.assert_allclose(x_val, x2_val) + np.testing.assert_allclose(out_val, out_2_val) @pytest.mark.parametrize("op", ["dot", "banded_dot"], ids=str) @@ -1099,17 +1099,17 @@ def test_banded_dot_perf(op, A_shape, benchmark): rng = np.random.default_rng() A_val = _make_banded_A(rng.normal(size=A_shape), kl=1, ku=1).astype(config.floatX) - b_val = rng.normal(size=(A_shape[-1],)).astype(config.floatX) + x_val = rng.normal(size=(A_shape[-1],)).astype(config.floatX) A = pt.tensor("A", shape=A_val.shape, dtype=A_val.dtype) - b = pt.tensor("b", shape=b_val.shape, dtype=A_val.dtype) + x = pt.tensor("x", shape=x_val.shape, dtype=x_val.dtype) if op == "dot": f = pt.dot elif op == "banded_dot": f = functools.partial(banded_dot, lower_diags=1, upper_diags=1) - res = f(A, b) - fn = function([A, b], res, trust_input=True) + res = f(A, x) + fn = function([A, x], res, trust_input=True) - benchmark(fn, A_val, b_val) + benchmark(fn, A_val, x_val) From 157345c76c99aabd0c0378775c687624b872cd4f Mon Sep 17 00:00:00 2001 From: jessegrabowski Date: Sat, 24 May 2025 16:36:12 +0800 Subject: [PATCH 10/25] Add numba dispatch for banded_dot --- pytensor/link/numba/dispatch/basic.py | 2 +- pytensor/link/numba/dispatch/linalg/_BLAS.py | 64 +++++++++++++ .../numba/dispatch/linalg/dot/__init__.py | 0 .../link/numba/dispatch/linalg/dot/banded.py | 93 +++++++++++++++++++ pytensor/link/numba/dispatch/slinalg.py | 18 ++++ tests/link/numba/test_slinalg.py | 23 +++++ 6 files changed, 199 insertions(+), 1 deletion(-) create mode 100644 pytensor/link/numba/dispatch/linalg/_BLAS.py create mode 100644 pytensor/link/numba/dispatch/linalg/dot/__init__.py create mode 100644 pytensor/link/numba/dispatch/linalg/dot/banded.py diff --git a/pytensor/link/numba/dispatch/basic.py b/pytensor/link/numba/dispatch/basic.py index 845d6afc7a..2963c030d1 100644 --- a/pytensor/link/numba/dispatch/basic.py +++ b/pytensor/link/numba/dispatch/basic.py @@ -75,7 +75,7 @@ def numba_njit(*args, fastmath=None, **kwargs): message=( "(\x1b\\[1m)*" # ansi 
escape code for bold text "Cannot cache compiled function " - '"(numba_funcified_fgraph|store_core_outputs|cholesky|solve|solve_triangular|cho_solve|lu_factor)" ' + '"(numba_funcified_fgraph|store_core_outputs|cholesky|solve|solve_triangular|cho_solve|lu_factor|banded_dot)" ' "as it uses dynamic globals" ), category=NumbaWarning, diff --git a/pytensor/link/numba/dispatch/linalg/_BLAS.py b/pytensor/link/numba/dispatch/linalg/_BLAS.py new file mode 100644 index 0000000000..e6416637cf --- /dev/null +++ b/pytensor/link/numba/dispatch/linalg/_BLAS.py @@ -0,0 +1,64 @@ +import ctypes + +from numba.core.extending import get_cython_function_address +from numba.np.linalg import ensure_blas, ensure_lapack, get_blas_kind + +from pytensor.link.numba.dispatch.linalg._LAPACK import ( + _get_float_pointer_for_dtype, + _ptr_int, +) + + +def _get_blas_ptr_and_ptr_type(dtype, name): + d = get_blas_kind(dtype) + func_name = f"{d}{name}" + float_pointer = _get_float_pointer_for_dtype(d) + lapack_ptr = get_cython_function_address("scipy.linalg.cython_blas", func_name) + + return lapack_ptr, float_pointer + + +class _BLAS: + """ + Functions to return type signatures for wrapped BLAS functions. + + Here we are specifically concered with BLAS functions exposed by scipy, and not used by numpy. + + Patterned after https://github.com/numba/numba/blob/bd7ebcfd4b850208b627a3f75d4706000be36275/numba/np/linalg.py#L74 + """ + + def __init__(self): + ensure_lapack() + ensure_blas() + + @classmethod + def numba_xgbmv(cls, dtype): + """ + xGBMV performs one of the following matrix operations: + + y = alpha * A @ x + beta * y, or y = alpha * A.T @ x + beta * y + + Where alpha and beta are scalars, x and y are vectors, and A is a band matrix with kl sub-diagonals and ku + super-diagonals. 
+ """ + + blas_ptr, float_pointer = _get_blas_ptr_and_ptr_type(dtype, "gbmv") + + functype = ctypes.CFUNCTYPE( + None, + _ptr_int, # TRANS + _ptr_int, # M + _ptr_int, # N + _ptr_int, # KL + _ptr_int, # KU + float_pointer, # ALPHA + float_pointer, # A + _ptr_int, # LDA + float_pointer, # X + _ptr_int, # INCX + float_pointer, # BETA + float_pointer, # Y + _ptr_int, # INCY + ) + + return functype(blas_ptr) diff --git a/pytensor/link/numba/dispatch/linalg/dot/__init__.py b/pytensor/link/numba/dispatch/linalg/dot/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/pytensor/link/numba/dispatch/linalg/dot/banded.py b/pytensor/link/numba/dispatch/linalg/dot/banded.py new file mode 100644 index 0000000000..b6458c72bc --- /dev/null +++ b/pytensor/link/numba/dispatch/linalg/dot/banded.py @@ -0,0 +1,93 @@ +from collections.abc import Callable + +import numpy as np +from numba import njit as numba_njit +from numba.core.extending import overload +from numba.np.linalg import _copy_to_fortran_order, ensure_blas, ensure_lapack +from scipy import linalg + +from pytensor.link.numba.dispatch.linalg._BLAS import _BLAS +from pytensor.link.numba.dispatch.linalg._LAPACK import ( + _get_underlying_float, + val_to_int_ptr, +) +from pytensor.link.numba.dispatch.linalg.utils import _check_scipy_linalg_matrix + + +@numba_njit(inline="always") +def A_to_banded(A: np.ndarray, kl: int, ku: int) -> np.ndarray: + m, n = A.shape + A_banded = np.zeros((kl + ku + 1, n), dtype=A.dtype) + + for i, k in enumerate(range(ku, -kl - 1, -1)): + if k >= 0: + A_banded[i, k:] = np.diag(A, k=k) + else: + A_banded[i, : n + k] = np.diag(A, k=k) + + return A_banded + + +def _dot_banded(A: np.ndarray, x: np.ndarray, kl: int, ku: int) -> np.ndarray: + """ + Thin wrapper around gmbv. This code will only be called if njit is disabled globally + (e.g. during testing) + """ + fn = linalg.get_blas_funcs("gbmv", (A, x)) + m, n = A.shape + A_banded = A_to_banded(A, kl=kl, ku=ku) + + return fn(m=m, n=n, kl=kl, ku=ku, alpha=1, a=A_banded, x=x) + + +@overload(_dot_banded) +def dot_banded_impl( + A: np.ndarray, x: np.ndarray, kl: int, ku: int +) -> Callable[[np.ndarray, np.ndarray, int, int], np.ndarray]: + ensure_lapack() + ensure_blas() + _check_scipy_linalg_matrix(A, "dot_banded") + dtype = A.dtype + w_type = _get_underlying_float(dtype) + numba_gbmv = _BLAS().numba_xgbmv(dtype) + + def impl(A: np.ndarray, x: np.ndarray, kl: int, ku: int) -> np.ndarray: + m, n = A.shape + + # TODO: Can we avoid this copy? 
+ A_banded = A_to_banded(A, kl=kl, ku=ku) + A_banded = _copy_to_fortran_order(A_banded) + + TRANS = val_to_int_ptr(ord("N")) + M = val_to_int_ptr(m) + N = val_to_int_ptr(n) + LDA = val_to_int_ptr(A_banded.shape[0]) + + KL = val_to_int_ptr(kl) + KU = val_to_int_ptr(ku) + + ALPHA = np.array(1.0, dtype=dtype) + INCX = val_to_int_ptr(1) + BETA = np.array(0.0, dtype=dtype) + Y = np.empty(m, dtype=dtype) + INCY = val_to_int_ptr(1) + + numba_gbmv( + TRANS, + M, + N, + KL, + KU, + ALPHA.view(w_type).ctypes, + A_banded.view(w_type).ctypes, + LDA, + x.view(w_type).ctypes, + INCX, + BETA.view(w_type).ctypes, + Y.view(w_type).ctypes, + INCY, + ) + + return Y + + return impl diff --git a/pytensor/link/numba/dispatch/slinalg.py b/pytensor/link/numba/dispatch/slinalg.py index 4630224f02..bd4f098c1d 100644 --- a/pytensor/link/numba/dispatch/slinalg.py +++ b/pytensor/link/numba/dispatch/slinalg.py @@ -11,6 +11,7 @@ _pivot_to_permutation, ) from pytensor.link.numba.dispatch.linalg.decomposition.lu_factor import _lu_factor +from pytensor.link.numba.dispatch.linalg.dot.banded import _dot_banded from pytensor.link.numba.dispatch.linalg.solve.cholesky import _cho_solve from pytensor.link.numba.dispatch.linalg.solve.general import _solve_gen from pytensor.link.numba.dispatch.linalg.solve.posdef import _solve_psd @@ -19,6 +20,7 @@ from pytensor.link.numba.dispatch.linalg.solve.tridiagonal import _solve_tridiagonal from pytensor.tensor.slinalg import ( LU, + BandedDot, BlockDiagonal, Cholesky, CholeskySolve, @@ -311,3 +313,19 @@ def cho_solve(c, b): ) return cho_solve + + +@numba_funcify.register(BandedDot) +def numba_funcify_BandedDot(op, node, **kwargs): + kl = op.lower_diags + ku = op.upper_diags + dtype = node.inputs[0].dtype + + if dtype in complex_dtypes: + raise NotImplementedError(_COMPLEX_DTYPE_NOT_SUPPORTED_MSG.format(op=op)) + + @numba_njit + def banded_dot(A, x): + return _dot_banded(A, x, kl=kl, ku=ku) + + return banded_dot diff --git a/tests/link/numba/test_slinalg.py b/tests/link/numba/test_slinalg.py index 3880cca3c6..0998c66aae 100644 --- a/tests/link/numba/test_slinalg.py +++ b/tests/link/numba/test_slinalg.py @@ -15,8 +15,10 @@ LUFactor, Solve, SolveTriangular, + banded_dot, ) from tests.link.numba.test_basic import compare_numba_and_py, numba_inplace_mode +from tests.tensor.test_slinalg import _make_banded_A pytestmark = pytest.mark.filterwarnings("error") @@ -720,3 +722,24 @@ def test_lu_solve(b_func, b_shape: tuple[int, ...], trans: bool, overwrite_b: bo # Can never destroy non-contiguous inputs np.testing.assert_allclose(b_val_not_contig, b_val) + + +def test_banded_dot(): + rng = np.random.default_rng() + + A_val = _make_banded_A(rng.normal(size=(10, 10)), kl=1, ku=1).astype(config.floatX) + x_val = rng.normal(size=(10,)).astype(config.floatX) + + A = pt.tensor("A", shape=A_val.shape, dtype=A_val.dtype) + x = pt.tensor("x", shape=x_val.shape, dtype=x_val.dtype) + + output = banded_dot(A, x, upper_diags=1, lower_diags=1) + + compare_numba_and_py( + [A, x], + output, + test_inputs=[A_val, x_val], + inplace=True, + numba_mode=numba_inplace_mode, + eval_obj_mode=False, + ) From 7d109b97c18dc4f12e5b031d90aff16d3472d466 Mon Sep 17 00:00:00 2001 From: jessegrabowski Date: Sat, 24 May 2025 17:54:07 +0800 Subject: [PATCH 11/25] Eliminate extra copy in numba impl --- pytensor/link/numba/dispatch/linalg/dot/banded.py | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/pytensor/link/numba/dispatch/linalg/dot/banded.py b/pytensor/link/numba/dispatch/linalg/dot/banded.py 
index b6458c72bc..3ec2710ecc 100644 --- a/pytensor/link/numba/dispatch/linalg/dot/banded.py +++ b/pytensor/link/numba/dispatch/linalg/dot/banded.py @@ -3,7 +3,7 @@ import numpy as np from numba import njit as numba_njit from numba.core.extending import overload -from numba.np.linalg import _copy_to_fortran_order, ensure_blas, ensure_lapack +from numba.np.linalg import ensure_blas, ensure_lapack from scipy import linalg from pytensor.link.numba.dispatch.linalg._BLAS import _BLAS @@ -15,9 +15,12 @@ @numba_njit(inline="always") -def A_to_banded(A: np.ndarray, kl: int, ku: int) -> np.ndarray: +def A_to_banded(A: np.ndarray, kl: int, ku: int, order="C") -> np.ndarray: m, n = A.shape - A_banded = np.zeros((kl + ku + 1, n), dtype=A.dtype) + if order == "C": + A_banded = np.zeros((kl + ku + 1, n), dtype=A.dtype) + else: + A_banded = np.zeros((n, kl + ku + 1), dtype=A.dtype).T for i, k in enumerate(range(ku, -kl - 1, -1)): if k >= 0: @@ -35,7 +38,7 @@ def _dot_banded(A: np.ndarray, x: np.ndarray, kl: int, ku: int) -> np.ndarray: """ fn = linalg.get_blas_funcs("gbmv", (A, x)) m, n = A.shape - A_banded = A_to_banded(A, kl=kl, ku=ku) + A_banded = A_to_banded(A, kl=kl, ku=ku, order="C") return fn(m=m, n=n, kl=kl, ku=ku, alpha=1, a=A_banded, x=x) @@ -54,9 +57,7 @@ def dot_banded_impl( def impl(A: np.ndarray, x: np.ndarray, kl: int, ku: int) -> np.ndarray: m, n = A.shape - # TODO: Can we avoid this copy? - A_banded = A_to_banded(A, kl=kl, ku=ku) - A_banded = _copy_to_fortran_order(A_banded) + A_banded = A_to_banded(A, kl=kl, ku=ku, order="F") TRANS = val_to_int_ptr(ord("N")) M = val_to_int_ptr(m) From c18f09501883a02d6ecc5c3b5f641cd71cda7d2a Mon Sep 17 00:00:00 2001 From: jessegrabowski Date: Sat, 24 May 2025 18:41:38 +0800 Subject: [PATCH 12/25] Create `A_banded` as F-contiguous array --- pytensor/link/numba/dispatch/linalg/dot/banded.py | 2 +- pytensor/tensor/slinalg.py | 9 ++++++--- 2 files changed, 7 insertions(+), 4 deletions(-) diff --git a/pytensor/link/numba/dispatch/linalg/dot/banded.py b/pytensor/link/numba/dispatch/linalg/dot/banded.py index 3ec2710ecc..6ceca5a0e4 100644 --- a/pytensor/link/numba/dispatch/linalg/dot/banded.py +++ b/pytensor/link/numba/dispatch/linalg/dot/banded.py @@ -38,7 +38,7 @@ def _dot_banded(A: np.ndarray, x: np.ndarray, kl: int, ku: int) -> np.ndarray: """ fn = linalg.get_blas_funcs("gbmv", (A, x)) m, n = A.shape - A_banded = A_to_banded(A, kl=kl, ku=ku, order="C") + A_banded = A_to_banded(A, kl=kl, ku=ku, order="F") return fn(m=m, n=n, kl=kl, ku=ku, alpha=1, a=A_banded, x=x) diff --git a/pytensor/tensor/slinalg.py b/pytensor/tensor/slinalg.py index bb9687f54a..07027f1868 100644 --- a/pytensor/tensor/slinalg.py +++ b/pytensor/tensor/slinalg.py @@ -6,7 +6,7 @@ import numpy as np import scipy.linalg as scipy_linalg -from numpy import diag, zeros +from numpy import zeros from numpy.exceptions import ComplexWarning import pytensor @@ -1697,10 +1697,13 @@ def perform(self, node, inputs, outputs_storage): kl = self.lower_diags ku = self.upper_diags - A_banded = zeros((kl + ku + 1, n), dtype=A.dtype, order="C") + A_banded = zeros((kl + ku + 1, n), dtype=A.dtype, order="F") for i, k in enumerate(range(ku, -kl - 1, -1)): - A_banded[i, slice(k, None) if k >= 0 else slice(None, n + k)] = diag(A, k=k) + if k >= 0: + A_banded[i, k:] = np.diag(A, k=k) + else: + A_banded[i, : n + k] = np.diag(A, k=k) fn = scipy_linalg.get_blas_funcs("gbmv", dtype=A.dtype) outputs_storage[0][0] = fn(m=m, n=n, kl=kl, ku=ku, alpha=1, a=A_banded, x=x) From 607a8718249b1d13384d5f41cd80da88460e67bb Mon Sep 
17 00:00:00 2001 From: jessegrabowski Date: Sat, 24 May 2025 18:41:43 +0800 Subject: [PATCH 13/25] Remove benchmark --- tests/tensor/test_slinalg.py | 26 -------------------------- 1 file changed, 26 deletions(-) diff --git a/tests/tensor/test_slinalg.py b/tests/tensor/test_slinalg.py index 9fa8df768e..f75f2b8466 100644 --- a/tests/tensor/test_slinalg.py +++ b/tests/tensor/test_slinalg.py @@ -1087,29 +1087,3 @@ def test_banded_dot(A_shape, kl, ku): out_val, out_2_val = fn(A_val, x_val) np.testing.assert_allclose(out_val, out_2_val) - - -@pytest.mark.parametrize("op", ["dot", "banded_dot"], ids=str) -@pytest.mark.parametrize( - "A_shape", - [(10, 10), (100, 100), (1000, 1000), (10_000, 10_000)], - ids=["10", "100", "1000", "10_000"], -) -def test_banded_dot_perf(op, A_shape, benchmark): - rng = np.random.default_rng() - - A_val = _make_banded_A(rng.normal(size=A_shape), kl=1, ku=1).astype(config.floatX) - x_val = rng.normal(size=(A_shape[-1],)).astype(config.floatX) - - A = pt.tensor("A", shape=A_val.shape, dtype=A_val.dtype) - x = pt.tensor("x", shape=x_val.shape, dtype=x_val.dtype) - - if op == "dot": - f = pt.dot - elif op == "banded_dot": - f = functools.partial(banded_dot, lower_diags=1, upper_diags=1) - - res = f(A, x) - fn = function([A, x], res, trust_input=True) - - benchmark(fn, A_val, x_val) From f6f12aa4c547ecc42af2c5a11375f5cb10791121 Mon Sep 17 00:00:00 2001 From: jessegrabowski Date: Sat, 24 May 2025 18:53:36 +0800 Subject: [PATCH 14/25] Don't cache numba function --- pytensor/link/numba/dispatch/slinalg.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytensor/link/numba/dispatch/slinalg.py b/pytensor/link/numba/dispatch/slinalg.py index bd4f098c1d..b2893a6ecb 100644 --- a/pytensor/link/numba/dispatch/slinalg.py +++ b/pytensor/link/numba/dispatch/slinalg.py @@ -324,7 +324,7 @@ def numba_funcify_BandedDot(op, node, **kwargs): if dtype in complex_dtypes: raise NotImplementedError(_COMPLEX_DTYPE_NOT_SUPPORTED_MSG.format(op=op)) - @numba_njit + @numba_njit(cache=False) def banded_dot(A, x): return _dot_banded(A, x, kl=kl, ku=ku) From e8fe5e35e8566327adff34fe46ee2050b4ae8823 Mon Sep 17 00:00:00 2001 From: jessegrabowski Date: Sat, 24 May 2025 18:58:38 +0800 Subject: [PATCH 15/25] all hail mypy --- pytensor/link/numba/dispatch/linalg/dot/banded.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/pytensor/link/numba/dispatch/linalg/dot/banded.py b/pytensor/link/numba/dispatch/linalg/dot/banded.py index 6ceca5a0e4..8946eef0d3 100644 --- a/pytensor/link/numba/dispatch/linalg/dot/banded.py +++ b/pytensor/link/numba/dispatch/linalg/dot/banded.py @@ -1,4 +1,5 @@ from collections.abc import Callable +from typing import Any import numpy as np from numba import njit as numba_njit @@ -31,7 +32,7 @@ def A_to_banded(A: np.ndarray, kl: int, ku: int, order="C") -> np.ndarray: return A_banded -def _dot_banded(A: np.ndarray, x: np.ndarray, kl: int, ku: int) -> np.ndarray: +def _dot_banded(A: np.ndarray, x: np.ndarray, kl: int, ku: int) -> Any: """ Thin wrapper around gmbv. This code will only be called if njit is disabled globally (e.g. 
during testing) From 5344c276e4b0b977ec8218226380bf9a3065984d Mon Sep 17 00:00:00 2001 From: jessegrabowski Date: Sat, 24 May 2025 21:01:08 +0800 Subject: [PATCH 16/25] set INCX by strides --- pytensor/link/numba/dispatch/linalg/dot/banded.py | 2 +- tests/link/numba/test_slinalg.py | 12 ++++++++++++ 2 files changed, 13 insertions(+), 1 deletion(-) diff --git a/pytensor/link/numba/dispatch/linalg/dot/banded.py b/pytensor/link/numba/dispatch/linalg/dot/banded.py index 8946eef0d3..8c04c007b0 100644 --- a/pytensor/link/numba/dispatch/linalg/dot/banded.py +++ b/pytensor/link/numba/dispatch/linalg/dot/banded.py @@ -69,7 +69,7 @@ def impl(A: np.ndarray, x: np.ndarray, kl: int, ku: int) -> np.ndarray: KU = val_to_int_ptr(ku) ALPHA = np.array(1.0, dtype=dtype) - INCX = val_to_int_ptr(1) + INCX = val_to_int_ptr(x.strides[0] // x.itemsize) BETA = np.array(0.0, dtype=dtype) Y = np.empty(m, dtype=dtype) INCY = val_to_int_ptr(1) diff --git a/tests/link/numba/test_slinalg.py b/tests/link/numba/test_slinalg.py index 0998c66aae..8153f91636 100644 --- a/tests/link/numba/test_slinalg.py +++ b/tests/link/numba/test_slinalg.py @@ -743,3 +743,15 @@ def test_banded_dot(): numba_mode=numba_inplace_mode, eval_obj_mode=False, ) + + # Test non-contiguous x input + x_val = rng.normal(size=(20,))[::2] + + compare_numba_and_py( + [A, x], + output, + test_inputs=[A_val, x_val], + inplace=True, + numba_mode=numba_inplace_mode, + eval_obj_mode=False, + ) From 31e9a29ac20e86a2b3afbe1d94b514bfe1a42e81 Mon Sep 17 00:00:00 2001 From: jessegrabowski Date: Sat, 24 May 2025 21:01:23 +0800 Subject: [PATCH 17/25] relax tolerance of float32 test --- tests/tensor/test_slinalg.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/tests/tensor/test_slinalg.py b/tests/tensor/test_slinalg.py index f75f2b8466..dc4252afe3 100644 --- a/tests/tensor/test_slinalg.py +++ b/tests/tensor/test_slinalg.py @@ -1086,4 +1086,7 @@ def test_banded_dot(A_shape, kl, ku): out_val, out_2_val = fn(A_val, x_val) - np.testing.assert_allclose(out_val, out_2_val) + atol = 1e-4 if config.floatX == "float32" else 1e-8 + rtol = 1e-4 if config.floatX == "float32" else 1e-8 + + np.testing.assert_allclose(out_val, out_2_val, atol=atol, rtol=rtol) From 0505c571370d763cea6bd640b323df780c366b8a Mon Sep 17 00:00:00 2001 From: jessegrabowski Date: Sun, 25 May 2025 12:36:18 +0800 Subject: [PATCH 18/25] Add suggestions --- pytensor/tensor/slinalg.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/pytensor/tensor/slinalg.py b/pytensor/tensor/slinalg.py index 07027f1868..0ba750ba65 100644 --- a/pytensor/tensor/slinalg.py +++ b/pytensor/tensor/slinalg.py @@ -1679,6 +1679,11 @@ def __init__(self, lower_diags, upper_diags): self.upper_diags = upper_diags def make_node(self, A, x): + if A.ndim != 2: + raise TypeError("A must be a 2D tensor") + if x.ndim != 1: + raise TypeError("x must be a 1D tensor") + A = as_tensor_variable(A) x = as_tensor_variable(x) @@ -1688,7 +1693,8 @@ def make_node(self, A, x): return pytensor.graph.basic.Apply(self, [A, x], [output]) def infer_shape(self, fgraph, nodes, shapes): - return [shapes[0][:-1]] + A_shape, _ = shapes + return [(A_shape[0],)] def perform(self, node, inputs, outputs_storage): A, x = inputs From 2b5c51d7e48d4f01b868d731f7137215bea04e38 Mon Sep 17 00:00:00 2001 From: jessegrabowski Date: Sun, 25 May 2025 12:36:25 +0800 Subject: [PATCH 19/25] Test strides --- tests/link/numba/test_slinalg.py | 8 ++++++-- tests/tensor/test_slinalg.py | 18 +++++++++--------- 2 files changed, 15 
insertions(+), 11 deletions(-) diff --git a/tests/link/numba/test_slinalg.py b/tests/link/numba/test_slinalg.py index 8153f91636..ead78691d8 100644 --- a/tests/link/numba/test_slinalg.py +++ b/tests/link/numba/test_slinalg.py @@ -724,11 +724,15 @@ def test_lu_solve(b_func, b_shape: tuple[int, ...], trans: bool, overwrite_b: bo np.testing.assert_allclose(b_val_not_contig, b_val) -def test_banded_dot(): +@pytest.mark.parametrize("stride", [1, 2, -1], ids=lambda x: f"stride={x}") +def test_banded_dot(stride): rng = np.random.default_rng() A_val = _make_banded_A(rng.normal(size=(10, 10)), kl=1, ku=1).astype(config.floatX) - x_val = rng.normal(size=(10,)).astype(config.floatX) + + x_shape = (10 * abs(stride),) + x_val = rng.normal(size=x_shape).astype(config.floatX) + x_val = x_val[::stride] A = pt.tensor("A", shape=A_val.shape, dtype=A_val.dtype) x = pt.tensor("x", shape=x_val.shape, dtype=x_val.dtype) diff --git a/tests/tensor/test_slinalg.py b/tests/tensor/test_slinalg.py index dc4252afe3..e8a6407b6c 100644 --- a/tests/tensor/test_slinalg.py +++ b/tests/tensor/test_slinalg.py @@ -1061,20 +1061,20 @@ def _make_banded_A(A, kl, ku): return sum(np.diag(d, k=k) for k, d in zip(diag_idxs, diags)) -@pytest.mark.parametrize( - "A_shape", - [ - (10, 10), - ], -) @pytest.mark.parametrize( "kl, ku", [(1, 1), (0, 1), (2, 2)], ids=["tridiag", "upper-only", "banded"] ) -def test_banded_dot(A_shape, kl, ku): +@pytest.mark.parametrize("stride", [1, 2, -1], ids=lambda x: f"stride={x}") +def test_banded_dot(kl, ku, stride): rng = np.random.default_rng() - A_val = _make_banded_A(rng.normal(size=A_shape), kl=kl, ku=ku).astype(config.floatX) - x_val = rng.normal(size=(A_shape[-1],)).astype(config.floatX) + size = 10 + + A_val = _make_banded_A(rng.normal(size=(size, size)), kl=kl, ku=ku).astype( + config.floatX + ) + x_val = rng.normal(size=(size * abs(stride),)).astype(config.floatX) + x_val = x_val[::stride] A = pt.tensor("A", shape=A_val.shape, dtype=A_val.dtype) x = pt.tensor("x", shape=x_val.shape, dtype=x_val.dtype) From 519c9335f1cbc42a90094c07f45be4b1f4eda5de Mon Sep 17 00:00:00 2001 From: jessegrabowski Date: Sun, 25 May 2025 12:55:13 +0800 Subject: [PATCH 20/25] Add L_op --- pytensor/tensor/slinalg.py | 19 +++++++++++++++++++ tests/tensor/test_slinalg.py | 29 +++++++++++++++++++++++++++++ 2 files changed, 48 insertions(+) diff --git a/pytensor/tensor/slinalg.py b/pytensor/tensor/slinalg.py index 0ba750ba65..2dbd0a6143 100644 --- a/pytensor/tensor/slinalg.py +++ b/pytensor/tensor/slinalg.py @@ -11,6 +11,7 @@ import pytensor import pytensor.tensor as pt +from pytensor import Variable from pytensor.gradient import DisconnectedType from pytensor.graph.basic import Apply from pytensor.graph.op import Op @@ -1714,6 +1715,24 @@ def perform(self, node, inputs, outputs_storage): fn = scipy_linalg.get_blas_funcs("gbmv", dtype=A.dtype) outputs_storage[0][0] = fn(m=m, n=n, kl=kl, ku=ku, alpha=1, a=A_banded, x=x) + def L_op( + self, + inputs: Sequence[Variable], + outputs: Sequence[Variable], + output_grads: Sequence[Variable], + ) -> list[Variable]: + # This is exactly the same as the usual gradient of a matrix-vector product, except that the banded structure + # is exploited. 
+ A, x = inputs + (G_bar,) = output_grads + + A_bar = pt.outer(G_bar, x.T) + x_bar = banded_dot( + A.T, G_bar, lower_diags=self.lower_diags, upper_diags=self.upper_diags + ) + + return [A_bar, x_bar] + + def banded_dot(A: TensorLike, x: TensorLike, lower_diags: int, upper_diags: int): """ diff --git a/tests/tensor/test_slinalg.py b/tests/tensor/test_slinalg.py index e8a6407b6c..2ca55dac1f 100644 --- a/tests/tensor/test_slinalg.py +++ b/tests/tensor/test_slinalg.py @@ -1090,3 +1090,32 @@ def test_banded_dot(kl, ku, stride): rtol = 1e-4 if config.floatX == "float32" else 1e-8 np.testing.assert_allclose(out_val, out_2_val, atol=atol, rtol=rtol) + + +def test_banded_dot_grad(): + rng = np.random.default_rng() + size = 10 + + A_val = _make_banded_A(rng.normal(size=(size, size)), kl=1, ku=1).astype( + config.floatX + ) + x_val = rng.normal(size=(size,)).astype(config.floatX) + + def make_banded_pt(A): + # Like structured solve Ops, we have to include the transformation from an unconstrained matrix A to a banded + # matrix on the compute graph. Otherwise, the random perturbations used by verify_grad will result in invalid + # input matrices. + + diag_idxs = range(-1, 2) + diags = (pt.diag(A, k=k) for k in diag_idxs) + return sum(pt.diag(d, k=k) for k, d in zip(diag_idxs, diags)) + + def test_fn(A, x): + return banded_dot(make_banded_pt(A), x, lower_diags=1, upper_diags=1).sum() + + utt.verify_grad( + test_fn, + [A_val, x_val], + rng=rng, + eps=1e-4 if config.floatX == "float32" else 1e-8, + ) From 5754f931715dd7c382ea28c5265d45ba4ebd7d7f Mon Sep 17 00:00:00 2001 From: jessegrabowski Date: Sun, 25 May 2025 13:04:54 +0800 Subject: [PATCH 21/25] *remove* type hints to make mypy happy --- pytensor/tensor/slinalg.py | 7 +------ 1 file changed, 1 insertion(+), 6 deletions(-) diff --git a/pytensor/tensor/slinalg.py b/pytensor/tensor/slinalg.py index 2dbd0a6143..4f2d7ed7d3 100644 --- a/pytensor/tensor/slinalg.py +++ b/pytensor/tensor/slinalg.py @@ -1715,12 +1715,7 @@ def perform(self, node, inputs, outputs_storage): fn = scipy_linalg.get_blas_funcs("gbmv", dtype=A.dtype) outputs_storage[0][0] = fn(m=m, n=n, kl=kl, ku=ku, alpha=1, a=A_banded, x=x) - def L_op( - self, - inputs: Sequence[Variable], - outputs: Sequence[Variable], - output_grads: Sequence[Variable], - ) -> list[Variable]: + def L_op(self, inputs, outputs, output_grads) -> list[Variable]: # This is exactly the same as the usual gradient of a matrix-vector product, except that the banded structure # is exploited.
A, x = inputs From 481814f43d465cb0c6f584c2d84500c815cd5c4c Mon Sep 17 00:00:00 2001 From: jessegrabowski Date: Sun, 25 May 2025 15:42:23 +0800 Subject: [PATCH 22/25] Remove order argument from numba A_to_banded --- pytensor/link/numba/dispatch/linalg/dot/banded.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/pytensor/link/numba/dispatch/linalg/dot/banded.py b/pytensor/link/numba/dispatch/linalg/dot/banded.py index 8c04c007b0..f9be2ec30e 100644 --- a/pytensor/link/numba/dispatch/linalg/dot/banded.py +++ b/pytensor/link/numba/dispatch/linalg/dot/banded.py @@ -16,12 +16,12 @@ @numba_njit(inline="always") -def A_to_banded(A: np.ndarray, kl: int, ku: int, order="C") -> np.ndarray: +def A_to_banded(A: np.ndarray, kl: int, ku: int) -> np.ndarray: m, n = A.shape - if order == "C": - A_banded = np.zeros((kl + ku + 1, n), dtype=A.dtype) - else: - A_banded = np.zeros((n, kl + ku + 1), dtype=A.dtype).T + + # This matrix is built backwards then transposed to get it into Fortran order + # (order="F" is not allowed in Numba land) + A_banded = np.zeros((n, kl + ku + 1), dtype=A.dtype).T for i, k in enumerate(range(ku, -kl - 1, -1)): if k >= 0: @@ -39,7 +39,7 @@ def _dot_banded(A: np.ndarray, x: np.ndarray, kl: int, ku: int) -> Any: """ fn = linalg.get_blas_funcs("gbmv", (A, x)) m, n = A.shape - A_banded = A_to_banded(A, kl=kl, ku=ku, order="F") + A_banded = A_to_banded(A, kl=kl, ku=ku) return fn(m=m, n=n, kl=kl, ku=ku, alpha=1, a=A_banded, x=x) @@ -58,7 +58,7 @@ def dot_banded_impl( def impl(A: np.ndarray, x: np.ndarray, kl: int, ku: int) -> np.ndarray: m, n = A.shape - A_banded = A_to_banded(A, kl=kl, ku=ku, order="F") + A_banded = A_to_banded(A, kl=kl, ku=ku) TRANS = val_to_int_ptr(ord("N")) M = val_to_int_ptr(m) From 30fece459cb1c1e6d8485c1ce8d3ce72baa3a6dd Mon Sep 17 00:00:00 2001 From: Jesse Grabowski Date: Sun, 25 May 2025 21:07:08 +0800 Subject: [PATCH 23/25] Incorporate feedback --- pytensor/link/numba/dispatch/linalg/dot/banded.py | 4 +++- pytensor/tensor/slinalg.py | 4 +--- tests/link/numba/test_slinalg.py | 15 +--------------- 3 files changed, 5 insertions(+), 18 deletions(-) diff --git a/pytensor/link/numba/dispatch/linalg/dot/banded.py b/pytensor/link/numba/dispatch/linalg/dot/banded.py index f9be2ec30e..0740ad4d71 100644 --- a/pytensor/link/numba/dispatch/linalg/dot/banded.py +++ b/pytensor/link/numba/dispatch/linalg/dot/banded.py @@ -59,6 +59,7 @@ def impl(A: np.ndarray, x: np.ndarray, kl: int, ku: int) -> np.ndarray: m, n = A.shape A_banded = A_to_banded(A, kl=kl, ku=ku) + stride = x.strides[0] // x.itemsize TRANS = val_to_int_ptr(ord("N")) M = val_to_int_ptr(m) @@ -69,7 +70,8 @@ def impl(A: np.ndarray, x: np.ndarray, kl: int, ku: int) -> np.ndarray: KU = val_to_int_ptr(ku) ALPHA = np.array(1.0, dtype=dtype) - INCX = val_to_int_ptr(x.strides[0] // x.itemsize) + + INCX = val_to_int_ptr(stride) BETA = np.array(0.0, dtype=dtype) Y = np.empty(m, dtype=dtype) INCY = val_to_int_ptr(1) diff --git a/pytensor/tensor/slinalg.py b/pytensor/tensor/slinalg.py index 4f2d7ed7d3..fc2142365b 100644 --- a/pytensor/tensor/slinalg.py +++ b/pytensor/tensor/slinalg.py @@ -1722,9 +1722,7 @@ def L_op(self, inputs, outputs, output_grads) -> list[Variable]: (G_bar,) = output_grads A_bar = pt.outer(G_bar, x.T) - x_bar = banded_dot( - A.T, G_bar, lower_diags=self.lower_diags, upper_diags=self.upper_diags - ) + x_bar = self(A.T, G_bar) return [A_bar, x_bar] diff --git a/tests/link/numba/test_slinalg.py b/tests/link/numba/test_slinalg.py index 8153f91636..3dd23d04ef 100644 ---
a/tests/link/numba/test_slinalg.py +++ b/tests/link/numba/test_slinalg.py @@ -724,7 +724,7 @@ def test_lu_solve(b_func, b_shape: tuple[int, ...], trans: bool, overwrite_b: bo np.testing.assert_allclose(b_val_not_contig, b_val) -@pytest.mark.parametrize("stride", [1, 2, -1], ids=lambda x: f"stride={x}") +@pytest.mark.parametrize("stride", [1, 2, -1, -2], ids=lambda x: f"stride={x}") def test_banded_dot(stride): rng = np.random.default_rng() @@ -743,19 +743,6 @@ def test_banded_dot(stride): [A, x], output, test_inputs=[A_val, x_val], - inplace=True, - numba_mode=numba_inplace_mode, - eval_obj_mode=False, - ) - - # Test non-contiguous x input - x_val = rng.normal(size=(20,))[::2] - - compare_numba_and_py( - [A, x], - output, - test_inputs=[A_val, x_val], - inplace=True, numba_mode=numba_inplace_mode, eval_obj_mode=False, ) From 4bd259ca09418b5aac988ec60476986dc0337321 Mon Sep 17 00:00:00 2001 From: Jesse Grabowski Date: Sun, 25 May 2025 21:23:00 +0800 Subject: [PATCH 24/25] Adjust numba test --- tests/link/numba/test_slinalg.py | 29 ++++++++++++++++++++--------- 1 file changed, 20 insertions(+), 9 deletions(-) diff --git a/tests/link/numba/test_slinalg.py b/tests/link/numba/test_slinalg.py index 3dd23d04ef..886612a4aa 100644 --- a/tests/link/numba/test_slinalg.py +++ b/tests/link/numba/test_slinalg.py @@ -724,25 +724,36 @@ def test_lu_solve(b_func, b_shape: tuple[int, ...], trans: bool, overwrite_b: bo np.testing.assert_allclose(b_val_not_contig, b_val) -@pytest.mark.parametrize("stride", [1, 2, -1, -2], ids=lambda x: f"stride={x}") -def test_banded_dot(stride): +def test_banded_dot(): rng = np.random.default_rng() + A = pt.tensor("A", shape=(10, 10), dtype=config.floatX) A_val = _make_banded_A(rng.normal(size=(10, 10)), kl=1, ku=1).astype(config.floatX) - x_shape = (10 * abs(stride),) - x_val = rng.normal(size=x_shape).astype(config.floatX) - x_val = x_val[::stride] - - A = pt.tensor("A", shape=A_val.shape, dtype=A_val.dtype) - x = pt.tensor("x", shape=x_val.shape, dtype=x_val.dtype) + x = pt.tensor("x", shape=(10,), dtype=config.floatX) + x_val = rng.normal(size=(10,)).astype(config.floatX) output = banded_dot(A, x, upper_diags=1, lower_diags=1) - compare_numba_and_py( + fn, _ = compare_numba_and_py( [A, x], output, test_inputs=[A_val, x_val], numba_mode=numba_inplace_mode, eval_obj_mode=False, ) + + for stride in [2, -1, -2]: + x_shape = (10 * abs(stride),) + x_val = rng.normal(size=x_shape).astype(config.floatX) + x_val = x_val[::stride] + + nb_output = fn(A_val, x_val) + expected = A_val @ x_val + + np.testing.assert_allclose( + nb_output, + expected, + strict=True, + err_msg=f"Test failed for stride = {stride}", + ) From 497721efa8a8a1a065a97e0460edc64e11bf4024 Mon Sep 17 00:00:00 2001 From: jessegrabowski Date: Sun, 25 May 2025 21:31:40 +0800 Subject: [PATCH 25/25] Remove more useful type information for mypy --- pytensor/tensor/slinalg.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/pytensor/tensor/slinalg.py b/pytensor/tensor/slinalg.py index fc2142365b..a1663cf462 100644 --- a/pytensor/tensor/slinalg.py +++ b/pytensor/tensor/slinalg.py @@ -11,7 +11,6 @@ import pytensor import pytensor.tensor as pt -from pytensor import Variable from pytensor.gradient import DisconnectedType from pytensor.graph.basic import Apply from pytensor.graph.op import Op @@ -1715,7 +1714,7 @@ def perform(self, node, inputs, outputs_storage): fn = scipy_linalg.get_blas_funcs("gbmv", dtype=A.dtype) outputs_storage[0][0] = fn(m=m, n=n, kl=kl, ku=ku, alpha=1, a=A_banded, x=x) - def 
L_op(self, inputs, outputs, output_grads) -> list[Variable]: + def L_op(self, inputs, outputs, output_grads): # This is exactly the same as the usual gradient of a matrix-vector product, except that the banded structure # is exploited. A, x = inputs