
Implement BandedDot Op #1416


Open: wants to merge 25 commits into main

Changes from 5 commits

Commits (25):
bbf3141: Naive implementation, do not merge (jessegrabowski, May 23, 2025)
2282161: Implement suggestions (jessegrabowski, May 23, 2025)
ae8eff6: Simplify perf test (jessegrabowski, May 23, 2025)
b2f68a5: float32 compat in tests (jessegrabowski, May 23, 2025)
e64d4d3: Remove np.pad (jessegrabowski, May 23, 2025)
c979e9d: set dtype correctly (jessegrabowski, May 23, 2025)
f1066a9: fix signature, add infer_shape (jessegrabowski, May 23, 2025)
0302fac: micro-optimizations (jessegrabowski, May 23, 2025)
f47d88b: Rename b to x, matching BLAS docs (jessegrabowski, May 24, 2025)
157345c: Add numba dispatch for banded_dot (jessegrabowski, May 24, 2025)
7d109b9: Eliminate extra copy in numba impl (jessegrabowski, May 24, 2025)
c18f095: Create `A_banded` as F-contiguous array (jessegrabowski, May 24, 2025)
607a871: Remove benchmark (jessegrabowski, May 24, 2025)
f6f12aa: Don't cache numba function (jessegrabowski, May 24, 2025)
e8fe5e3: all hail mypy (jessegrabowski, May 24, 2025)
5344c27: set INCX by strides (jessegrabowski, May 24, 2025)
31e9a29: relax tolerance of float32 test (jessegrabowski, May 24, 2025)
0505c57: Add suggestions (jessegrabowski, May 25, 2025)
2b5c51d: Test strides (jessegrabowski, May 25, 2025)
519c933: Add L_op (jessegrabowski, May 25, 2025)
5754f93: *remove* type hints to make mypy happy (jessegrabowski, May 25, 2025)
481814f: Remove order argument from numba A_to_banded (jessegrabowski, May 25, 2025)
30fece4: Incorporate feedback (jessegrabowski, May 25, 2025)
4bd259c: Adjust numba test (jessegrabowski, May 25, 2025)
497721e: Remove more useful type information for mypy (jessegrabowski, May 25, 2025)
93 changes: 93 additions & 0 deletions pytensor/tensor/slinalg.py
@@ -1669,6 +1669,98 @@ def block_diag(*matrices: TensorVariable):
return _block_diagonal_matrix(*matrices)


def _to_banded_form(A, kl, ku):
"""
Convert a full matrix A to LAPACK banded form for gbmv.

Parameters
----------
A: np.ndarray
(m, n) banded matrix with nonzero values on the diagonals
kl: int
Number of nonzero lower diagonals of A
ku: int
Number of nonzero upper diagonals of A

Returns
-------
ab: np.ndarray
(kl + ku + 1, n) banded matrix suitable for LAPACK
"""
A = np.asarray(A)
m, n = A.shape
ab = np.zeros((kl + ku + 1, n), dtype=A.dtype, order="C")

for i, k in enumerate(range(ku, -kl - 1, -1)):
col_slice = slice(k, None) if k >= 0 else slice(None, n + k)
ab[i, col_slice] = np.diag(A, k=k)

return ab
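
For reference, a small worked example of the layout this helper produces (an illustrative aside, not part of the diff; it assumes `_to_banded_form` as defined above is in scope). Each row of the result holds one diagonal of A, from the highest superdiagonal down to the lowest subdiagonal, padded with zeros where that diagonal is shorter than n:

```python
import numpy as np

# Tridiagonal 4x4 example (kl = ku = 1)
A = np.array([
    [2., 1., 0., 0.],
    [5., 2., 1., 0.],
    [0., 5., 2., 1.],
    [0., 0., 5., 2.],
])

ab = _to_banded_form(A, kl=1, ku=1)
# ab row 0 (k=+1 superdiagonal): [0., 1., 1., 1.]
# ab row 1 (k= 0 main diagonal): [2., 2., 2., 2.]
# ab row 2 (k=-1 subdiagonal):   [5., 5., 5., 0.]
```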


_dgbmv = scipy_linalg.get_blas_funcs("gbmv", dtype="float64")
_sgbmv = scipy_linalg.get_blas_funcs("gbmv", dtype="float32")
Review comments:

Member: This will cause import-time overhead to PyTensor. I'm okay paying the extra 3us at runtime instead, since virtually nobody will ever use this (or use it in a case where they need those extra us).

Member (Author): I thought about this as well. It won't stay in the final version.

ricardoV94 (Member), May 23, 2025: You can exploit prepare_node and add the function to node.tag, which the perform method can then retrieve from. That's two attribute accesses instead of a string check / scipy caching...

Member: Or you can sidestep perform and use make_thunk instead.
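
A minimal sketch of the prepare_node idea (an illustrative aside, not part of the diff; it assumes PyTensor's `Op.prepare_node` hook and the writable `node.tag` scratchpad, and reuses `_to_banded_form` from above):

```python
import scipy.linalg as scipy_linalg
from pytensor.graph.op import Op


class BandedDot(Op):
    # ... __props__, __init__, make_node as in the diff ...

    def prepare_node(self, node, storage_map, compute_map, impl):
        # Look up the dtype-specific gbmv once, at compile time, and stash it
        # on the Apply node so perform() can skip the scipy string lookup.
        A, _ = node.inputs
        node.tag.gbmv = scipy_linalg.get_blas_funcs("gbmv", dtype=A.dtype)

    def perform(self, node, inputs, output_storage):
        A, b = inputs
        m, n = A.shape
        kl, ku = self.lower_diags, self.upper_diags
        A_banded = _to_banded_form(A, kl, ku)
        output_storage[0][0] = node.tag.gbmv(m, n, kl, ku, 1, A_banded, b)
```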



class BandedDot(Op):
Review comments:

Member: Put in blas.py?

Member: I saw your message, fine.

Member (Author): You mean in pytensor.tensor.blas? I can do that if you think it's better.

__props__ = ("lower_diags", "upper_diags")
gufunc_signature = "(m,n),(n)->(n)"

def __init__(self, lower_diags, upper_diags):
self.lower_diags = lower_diags
self.upper_diags = upper_diags

def make_node(self, A, b):
A = as_tensor_variable(A)
B = as_tensor_variable(b)

out_dtype = pytensor.scalar.upcast(A.dtype, B.dtype)
Review comment:

Member: I suspect this is wrong for integer types.

output = b.type().astype(out_dtype)

return pytensor.graph.basic.Apply(self, [A, B], [output])

def perform(self, node, inputs, outputs_storage):
A, b = inputs
m, n = A.shape
alpha = 1

kl = self.lower_diags
ku = self.upper_diags

A_banded = _to_banded_form(A, kl, ku)

fn = _dgbmv if A.dtype == "float64" else _sgbmv
outputs_storage[0][0] = fn(m, n, kl, ku, alpha, A_banded, b)


def banded_dot(A: TensorLike, b: TensorLike, lower_diags: int, upper_diags: int):
"""
Specialized matrix-vector multiplication for the case where A is a banded matrix.

A is not validated at runtime: any entries outside the specified diagonals are ignored, so the result will be
incorrect if A is not actually a banded matrix with the given bandwidths.

Unlike dot, this function is only valid if b is a vector.

Parameters
----------
A: TensorLike
Matrix to perform banded dot on.
b: TensorLike
Vector to perform banded dot on.
lower_diags: int
Number of nonzero lower diagonals of A
upper_diags: int
Number of nonzero upper diagonals of A

Returns
-------
out: Tensor
The matrix-vector product A @ b
"""
return Blockwise(BandedDot(lower_diags, upper_diags))(A, b)
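
A short usage sketch (illustrative only; it assumes the `banded_dot` helper added in this diff, and the example values are arbitrary):

```python
import numpy as np
import pytensor
import pytensor.tensor as pt
from pytensor.tensor.slinalg import banded_dot

A = pt.matrix("A", dtype="float64")
b = pt.vector("b", dtype="float64")

# A is declared tridiagonal: one nonzero diagonal below and one above the main diagonal
out = banded_dot(A, b, lower_diags=1, upper_diags=1)
fn = pytensor.function([A, b], out)

A_val = (
    np.diag(np.full(4, 2.0))
    + np.diag(np.ones(3), k=1)
    + np.diag(np.ones(3), k=-1)
)
b_val = np.arange(4, dtype="float64")

np.testing.assert_allclose(fn(A_val, b_val), A_val @ b_val)
```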


__all__ = [
"cholesky",
"solve",
@@ -1683,4 +1775,5 @@ def block_diag(*matrices: TensorVariable):
"lu",
"lu_factor",
"lu_solve",
"banded_dot",
]
62 changes: 62 additions & 0 deletions tests/tensor/test_slinalg.py
@@ -12,11 +12,13 @@
from pytensor.graph.basic import equal_computations
from pytensor.tensor import TensorVariable
from pytensor.tensor.slinalg import (
BandedDot,
Cholesky,
CholeskySolve,
Solve,
SolveBase,
SolveTriangular,
banded_dot,
block_diag,
cho_solve,
cholesky,
@@ -1051,3 +1053,63 @@ def test_block_diagonal_blockwise():
B = np.random.normal(size=(1, batch_size, 4, 4)).astype(config.floatX)
result = block_diag(A, B).eval()
assert result.shape == (10, batch_size, 6, 6)


def _make_banded_A(A, kl, ku):
diag_idxs = range(-kl, ku + 1)
diags = (np.diag(A, k=k) for k in diag_idxs)
return sum(np.diag(d, k=k) for k, d in zip(diag_idxs, diags))


@pytest.mark.parametrize(
"A_shape",
[
(10, 10),
],
)
@pytest.mark.parametrize(
"kl, ku", [(1, 1), (0, 1), (2, 2)], ids=["tridiag", "upper-only", "banded"]
)
def test_banded_dot(A_shape, kl, ku):
rng = np.random.default_rng()

A_val = _make_banded_A(rng.normal(size=A_shape), kl=kl, ku=ku).astype(config.floatX)
b_val = rng.normal(size=(A_shape[-1],)).astype(config.floatX)

A = pt.tensor("A", shape=A_val.shape, dtype=A_val.dtype)
b = pt.tensor("b", shape=b_val.shape, dtype=b_val.dtype)
res = banded_dot(A, b, kl, ku)
res_2 = A @ b

fn = function([A, b], [res, res_2], trust_input=True)
assert any(isinstance(node.op, BandedDot) for node in fn.maker.fgraph.apply_nodes)

x_val, x2_val = fn(A_val, b_val)

np.testing.assert_allclose(x_val, x2_val)


@pytest.mark.parametrize("op", ["dot", "banded_dot"], ids=str)
@pytest.mark.parametrize(
"A_shape",
[(10, 10), (100, 100), (1000, 1000), (10_000, 10_000)],
ids=["10", "100", "1000", "10_000"],
)
def test_banded_dot_perf(op, A_shape, benchmark):
rng = np.random.default_rng()

A_val = _make_banded_A(rng.normal(size=A_shape), kl=1, ku=1).astype(config.floatX)
b_val = rng.normal(size=(A_shape[-1],)).astype(config.floatX)

A = pt.tensor("A", shape=A_val.shape, dtype=A_val.dtype)
b = pt.tensor("b", shape=b_val.shape, dtype=A_val.dtype)

if op == "dot":
f = pt.dot
elif op == "banded_dot":
f = functools.partial(banded_dot, lower_diags=1, upper_diags=1)

res = f(A, b)
fn = function([A, b], res, trust_input=True)

benchmark(fn, A_val, b_val)