Implement BandedDot Op #1416

Open
jessegrabowski wants to merge 25 commits into main
Changes from 22 commits

Commits
bbf3141
Naive implementation, do not merge
jessegrabowski May 23, 2025
2282161
Implement suggestions
jessegrabowski May 23, 2025
ae8eff6
Simplify perf test
jessegrabowski May 23, 2025
b2f68a5
float32 compat in tests
jessegrabowski May 23, 2025
e64d4d3
Remove np.pad
jessegrabowski May 23, 2025
c979e9d
set dtype correctly
jessegrabowski May 23, 2025
f1066a9
fix signature, add infer_shape
jessegrabowski May 23, 2025
0302fac
micro-optimizations
jessegrabowski May 23, 2025
f47d88b
Rename b to x, matching BLAS docs
jessegrabowski May 24, 2025
157345c
Add numba dispatch for banded_dot
jessegrabowski May 24, 2025
7d109b9
Eliminate extra copy in numba impl
jessegrabowski May 24, 2025
c18f095
Create `A_banded` as F-contiguous array
jessegrabowski May 24, 2025
607a871
Remove benchmark
jessegrabowski May 24, 2025
f6f12aa
Don't cache numba function
jessegrabowski May 24, 2025
e8fe5e3
all hail mypy
jessegrabowski May 24, 2025
5344c27
set INCX by strides
jessegrabowski May 24, 2025
31e9a29
relax tolerance of float32 test
jessegrabowski May 24, 2025
0505c57
Add suggestions
jessegrabowski May 25, 2025
2b5c51d
Test strides
jessegrabowski May 25, 2025
519c933
Add L_op
jessegrabowski May 25, 2025
5754f93
*remove* type hints to make mypy happy
jessegrabowski May 25, 2025
481814f
Remove order argument from numba A_to_banded
jessegrabowski May 25, 2025
30fece4
Incorporate feedback
jessegrabowski May 25, 2025
4bd259c
Adjust numba test
jessegrabowski May 25, 2025
497721e
Remove more useful type information for mypy
jessegrabowski May 25, 2025
Files changed
2 changes: 1 addition & 1 deletion pytensor/link/numba/dispatch/basic.py
@@ -75,7 +75,7 @@ def numba_njit(*args, fastmath=None, **kwargs):
message=(
"(\x1b\\[1m)*" # ansi escape code for bold text
"Cannot cache compiled function "
'"(numba_funcified_fgraph|store_core_outputs|cholesky|solve|solve_triangular|cho_solve|lu_factor)" '
'"(numba_funcified_fgraph|store_core_outputs|cholesky|solve|solve_triangular|cho_solve|lu_factor|banded_dot)" '
"as it uses dynamic globals"
),
category=NumbaWarning,
64 changes: 64 additions & 0 deletions pytensor/link/numba/dispatch/linalg/_BLAS.py
@@ -0,0 +1,64 @@
import ctypes

from numba.core.extending import get_cython_function_address
from numba.np.linalg import ensure_blas, ensure_lapack, get_blas_kind

from pytensor.link.numba.dispatch.linalg._LAPACK import (
_get_float_pointer_for_dtype,
_ptr_int,
)


def _get_blas_ptr_and_ptr_type(dtype, name):
d = get_blas_kind(dtype)
func_name = f"{d}{name}"
float_pointer = _get_float_pointer_for_dtype(d)
lapack_ptr = get_cython_function_address("scipy.linalg.cython_blas", func_name)

return lapack_ptr, float_pointer


class _BLAS:
"""
Functions to return type signatures for wrapped BLAS functions.

Here we are specifically concerned with BLAS functions exposed by scipy, and not used by numpy.

Patterned after https://github.com/numba/numba/blob/bd7ebcfd4b850208b627a3f75d4706000be36275/numba/np/linalg.py#L74
"""

def __init__(self):
ensure_lapack()
ensure_blas()

@classmethod
def numba_xgbmv(cls, dtype):
"""
xGBMV performs one of the following matrix operations:

y = alpha * A @ x + beta * y, or y = alpha * A.T @ x + beta * y

Where alpha and beta are scalars, x and y are vectors, and A is a band matrix with kl sub-diagonals and ku
super-diagonals.
"""

blas_ptr, float_pointer = _get_blas_ptr_and_ptr_type(dtype, "gbmv")

functype = ctypes.CFUNCTYPE(
None,
_ptr_int, # TRANS
_ptr_int, # M
_ptr_int, # N
_ptr_int, # KL
_ptr_int, # KU
float_pointer, # ALPHA
float_pointer, # A
_ptr_int, # LDA
float_pointer, # X
_ptr_int, # INCX
float_pointer, # BETA
float_pointer, # Y
_ptr_int, # INCY
)

return functype(blas_ptr)
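For reference, a standalone sketch (editorial, not part of the diff) of the operation this wrapper exposes, shown through SciPy's Python-level gbmv; the bidiagonal example and its hand-packed band storage are illustrative only.

import numpy as np
from scipy import linalg

# gbmv computes y = alpha * A @ x + beta * y, with A supplied in band storage.
# Here A is upper-bidiagonal (kl = 0 sub-diagonals, ku = 1 super-diagonal).
n, kl, ku = 3, 0, 1
A = np.array([[1.0, 4.0, 0.0],
              [0.0, 2.0, 5.0],
              [0.0, 0.0, 3.0]])
A_banded = np.array([[0.0, 4.0, 5.0],   # super-diagonal, left-padded
                     [1.0, 2.0, 3.0]])  # main diagonal
x = np.ones(n)
y = np.full(n, 10.0)

gbmv = linalg.get_blas_funcs("gbmv", (A_banded, x))
out = gbmv(m=n, n=n, kl=kl, ku=ku, alpha=2.0, a=A_banded, x=x, beta=0.5, y=y)
np.testing.assert_allclose(out, 2.0 * A @ x + 0.5 * y)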
95 changes: 95 additions & 0 deletions pytensor/link/numba/dispatch/linalg/dot/banded.py
@@ -0,0 +1,95 @@
from collections.abc import Callable
from typing import Any

import numpy as np
from numba import njit as numba_njit
from numba.core.extending import overload
from numba.np.linalg import ensure_blas, ensure_lapack
from scipy import linalg

from pytensor.link.numba.dispatch.linalg._BLAS import _BLAS
from pytensor.link.numba.dispatch.linalg._LAPACK import (
_get_underlying_float,
val_to_int_ptr,
)
from pytensor.link.numba.dispatch.linalg.utils import _check_scipy_linalg_matrix


@numba_njit(inline="always")
def A_to_banded(A: np.ndarray, kl: int, ku: int) -> np.ndarray:
m, n = A.shape

# This matrix is built backwards then transposed to get it into Fortran order
# (order="F" is not allowed in Numba land)
A_banded = np.zeros((n, kl + ku + 1), dtype=A.dtype).T

for i, k in enumerate(range(ku, -kl - 1, -1)):
if k >= 0:
A_banded[i, k:] = np.diag(A, k=k)

else:
A_banded[i, : n + k] = np.diag(A, k=k)

return A_banded


def _dot_banded(A: np.ndarray, x: np.ndarray, kl: int, ku: int) -> Any:
"""
Thin wrapper around gbmv. This code will only be called if njit is disabled globally
(e.g. during testing)
"""
fn = linalg.get_blas_funcs("gbmv", (A, x))
m, n = A.shape
A_banded = A_to_banded(A, kl=kl, ku=ku)

return fn(m=m, n=n, kl=kl, ku=ku, alpha=1, a=A_banded, x=x)


@overload(_dot_banded)
def dot_banded_impl(
A: np.ndarray, x: np.ndarray, kl: int, ku: int
) -> Callable[[np.ndarray, np.ndarray, int, int], np.ndarray]:
ensure_lapack()
ensure_blas()
_check_scipy_linalg_matrix(A, "dot_banded")
dtype = A.dtype
w_type = _get_underlying_float(dtype)
numba_gbmv = _BLAS().numba_xgbmv(dtype)

def impl(A: np.ndarray, x: np.ndarray, kl: int, ku: int) -> np.ndarray:
m, n = A.shape

A_banded = A_to_banded(A, kl=kl, ku=ku)

TRANS = val_to_int_ptr(ord("N"))
M = val_to_int_ptr(m)
N = val_to_int_ptr(n)
LDA = val_to_int_ptr(A_banded.shape[0])

KL = val_to_int_ptr(kl)
KU = val_to_int_ptr(ku)

ALPHA = np.array(1.0, dtype=dtype)
INCX = val_to_int_ptr(x.strides[0] // x.itemsize)
ricardoV94 (Member) · May 24, 2025
Please test non-unit positive and negative strides for x. In the C Gemv, for instance, we need to point to the last memory position when the stride is negative.

ricardoV94 (Member) · May 24, 2025
You can, but need not, also test strides for A and y. Since we're creating them ourselves now, we know they're always correct. But once we split the Op we will need to worry about those as well.

jessegrabowski (Member, Author)
Let me know what you think about the way I'm testing strides now, and I can expand it if it's adequate.

jessegrabowski (Member, Author)
Unresolving this because the negative stride tests are failing.

ricardoV94 (Member) · May 27, 2025
If it's like the C BLAS, when you have negative strides you have to point to the end of the numpy array (x[-1]). BLAS wants to know where the block of memory starts, even if it iterates in reverse, but numpy points to the end of that block when the array has negative strides.

Member
// gemv expects pointers to the beginning of memory arrays,
// but numpy provides a pointer to the first element,
// so when the stride is negative, we need to get the last one.
if (Sx < 0)
    x_data += (Nz0 - 1) * Sx;
if (Sy < 0)
    y_data += (Nz1 - 1) * Sy;
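A standalone illustration (editorial, not from the PR discussion) of the pointer adjustment described above; copying to a contiguous buffer is shown only as a possible fallback, not as the PR's fix.

import numpy as np

base = np.arange(5.0)            # contiguous block; addresses increase left to right
x = base[::-1]                   # negative-stride view; x[0] is base[4]

# numpy's data pointer refers to the view's first element (the highest address here),
# while BLAS wants the lowest address of the block it will walk through in reverse.
start_of_block = x.ctypes.data + (x.shape[0] - 1) * x.strides[0]
assert start_of_block == base.ctypes.data

# Copying is a simpler (if slower) fallback: the copy has a positive unit stride.
x_contig = np.ascontiguousarray(x)
assert x_contig.strides[0] == x_contig.itemsize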

BETA = np.array(0.0, dtype=dtype)
Y = np.empty(m, dtype=dtype)
INCY = val_to_int_ptr(1)

numba_gbmv(

TRANS,
M,
N,
KL,
KU,
ALPHA.view(w_type).ctypes,
A_banded.view(w_type).ctypes,
LDA,
x.view(w_type).ctypes,
INCX,
BETA.view(w_type).ctypes,
Y.view(w_type).ctypes,
INCY,
)

return Y

return impl
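As a standalone reference (editorial, not part of the diff), the band-storage layout built by A_to_banded above, and by BandedDot.perform below, looks like this for a small tridiagonal matrix:

import numpy as np

A = np.array([[1.0, 5.0, 0.0, 0.0],
              [8.0, 2.0, 6.0, 0.0],
              [0.0, 9.0, 3.0, 7.0],
              [0.0, 0.0, 10.0, 4.0]])
kl = ku = 1
n = A.shape[1]

# Same packing loop as in the diff: row i holds diagonal k = ku - i.
A_banded = np.zeros((kl + ku + 1, n))
for i, k in enumerate(range(ku, -kl - 1, -1)):
    if k >= 0:
        A_banded[i, k:] = np.diag(A, k=k)
    else:
        A_banded[i, : n + k] = np.diag(A, k=k)

# A_banded is now:
# [[ 0.  5.  6.  7.]   <- super-diagonal (k = +1), left-padded
#  [ 1.  2.  3.  4.]   <- main diagonal  (k =  0)
#  [ 8.  9. 10.  0.]]  <- sub-diagonal   (k = -1), right-padded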
18 changes: 18 additions & 0 deletions pytensor/link/numba/dispatch/slinalg.py
@@ -11,6 +11,7 @@
_pivot_to_permutation,
)
from pytensor.link.numba.dispatch.linalg.decomposition.lu_factor import _lu_factor
from pytensor.link.numba.dispatch.linalg.dot.banded import _dot_banded
from pytensor.link.numba.dispatch.linalg.solve.cholesky import _cho_solve
from pytensor.link.numba.dispatch.linalg.solve.general import _solve_gen
from pytensor.link.numba.dispatch.linalg.solve.posdef import _solve_psd
@@ -19,6 +20,7 @@
from pytensor.link.numba.dispatch.linalg.solve.tridiagonal import _solve_tridiagonal
from pytensor.tensor.slinalg import (
LU,
BandedDot,
BlockDiagonal,
Cholesky,
CholeskySolve,
@@ -311,3 +313,19 @@
)

return cho_solve


@numba_funcify.register(BandedDot)
def numba_funcify_BandedDot(op, node, **kwargs):
kl = op.lower_diags
ku = op.upper_diags
dtype = node.inputs[0].dtype

if dtype in complex_dtypes:
raise NotImplementedError(_COMPLEX_DTYPE_NOT_SUPPORTED_MSG.format(op=op))

@numba_njit(cache=False)
def banded_dot(A, x):
return _dot_banded(A, x, kl=kl, ku=ku)

return banded_dot
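A minimal usage sketch (editorial, not part of the diff) that exercises the dispatcher registered above; it assumes a PyTensor build with the numba backend installed.

import numpy as np
import pytensor
import pytensor.tensor as pt
from pytensor.tensor.slinalg import banded_dot

A = pt.matrix("A", dtype="float64")
x = pt.vector("x", dtype="float64")
y = banded_dot(A, x, lower_diags=1, upper_diags=1)

# mode="NUMBA" routes compilation through numba_funcify, hitting the new dispatcher.
fn = pytensor.function([A, x], y, mode="NUMBA")

A_val = np.diag(np.full(5, 2.0)) + np.diag(np.ones(4), 1) + np.diag(np.ones(4), -1)
x_val = np.arange(5.0)
np.testing.assert_allclose(fn(A_val, x_val), A_val @ x_val)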
89 changes: 89 additions & 0 deletions pytensor/tensor/slinalg.py
@@ -6,10 +6,12 @@

import numpy as np
import scipy.linalg as scipy_linalg
from numpy import zeros
from numpy.exceptions import ComplexWarning

import pytensor
import pytensor.tensor as pt
from pytensor import Variable
from pytensor.gradient import DisconnectedType
from pytensor.graph.basic import Apply
from pytensor.graph.op import Op
@@ -1669,6 +1671,92 @@
return _block_diagonal_matrix(*matrices)


class BandedDot(Op):
Member
Put in blas.py?

Member
I saw your message, fine

jessegrabowski (Member, Author)
You mean in pytensor.tensor.blas? I can do that if you think it's better

__props__ = ("lower_diags", "upper_diags")
gufunc_signature = "(m,n),(n)->(m)"

def __init__(self, lower_diags, upper_diags):
self.lower_diags = lower_diags
self.upper_diags = upper_diags

def make_node(self, A, x):
if A.ndim != 2:
raise TypeError("A must be a 2D tensor")

if x.ndim != 1:
raise TypeError("x must be a 1D tensor")

A = as_tensor_variable(A)
x = as_tensor_variable(x)
Member
Raise ValueError for non-core ndims

out_dtype = pytensor.scalar.upcast(A.dtype, x.dtype)
Member
Wrong for integers / should raise. Also reject complex?

jessegrabowski (Member, Author)
I copied this from other make_node implementations in slinalg (eigvalsh, eigvalsh grad, solve lyapunov stuff). What's the right way to upcast here?

Member
The right way is to predict what scipy outputs. Some Ops are lazy and just call scipy with a minimal input case to find out the output type. I don't love that.
Which makes me wonder: I guess numba / a direct call to xgbmv doesn't work with integer arrays, so we may need to cast/raise?

jessegrabowski (Member, Author)
What does JAX do on integer inputs?
Also, it's not that onerous to just try every combination of input pairs on the scipy function, write it in a dictionary, and just look it up. Is that too crazy?

Member
> What does JAX do on integer inputs?
No idea; cast them to float or call a dot function that works on integers?
> Also it's not that onerous to just try every combination of input pairs on the scipy function, write it in a dictionary, and just look it up. Is that too crazy?
I think it's a bit crazy. You could add a function with lru_cache on the dtypes that tries it and stores the result. Most combinations will never be needed, and we don't want to do it at import time.
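For concreteness, a sketch of that lru_cache idea (editorial; the helper name is hypothetical and not part of the PR):

from functools import lru_cache

import numpy as np
from scipy import linalg


@lru_cache(maxsize=None)
def _gbmv_output_dtype(a_dtype: str, x_dtype: str) -> np.dtype:
    # Probe gbmv once per dtype pair and memoize what scipy actually returns,
    # instead of hard-coding an upcast rule or building a full table at import time.
    A = np.ones((1, 1), dtype=a_dtype)  # 1x1 "banded" matrix with kl = ku = 0
    x = np.ones(1, dtype=x_dtype)
    fn = linalg.get_blas_funcs("gbmv", (A, x))
    return fn(m=1, n=1, kl=0, ku=0, alpha=1, a=A, x=x).dtype

# Integer inputs resolve to a float flavor, since gbmv has no integer kernels.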

output = x.type.clone(dtype=out_dtype)()

return pytensor.graph.basic.Apply(self, [A, x], [output])

def infer_shape(self, fgraph, nodes, shapes):
A_shape, _ = shapes
return [(A_shape[0],)]

def perform(self, node, inputs, outputs_storage):
A, x = inputs
m, n = A.shape

kl = self.lower_diags
ku = self.upper_diags

A_banded = zeros((kl + ku + 1, n), dtype=A.dtype, order="F")

for i, k in enumerate(range(ku, -kl - 1, -1)):
if k >= 0:
A_banded[i, k:] = np.diag(A, k=k)
else:
A_banded[i, : n + k] = np.diag(A, k=k)

fn = scipy_linalg.get_blas_funcs("gbmv", dtype=A.dtype)
outputs_storage[0][0] = fn(m=m, n=n, kl=kl, ku=ku, alpha=1, a=A_banded, x=x)

def L_op(self, inputs, outputs, output_grads) -> list[Variable]:
# This is exactly the same as the usual gradient of a matrix-vector product, except that the banded structure
# is exploited.
A, x = inputs
(G_bar,) = output_grads

A_bar = pt.outer(G_bar, x.T)
x_bar = banded_dot(
A.T, G_bar, lower_diags=self.upper_diags, upper_diags=self.lower_diags  # bands swap under transposition
)

return [A_bar, x_bar]


def banded_dot(A: TensorLike, x: TensorLike, lower_diags: int, upper_diags: int):
"""
Specialized matrix-vector multiplication for cases when A is a banded matrix

The structure of A is not checked at runtime: any data outside the specified diagonals is simply ignored, which
will give incorrect results if A is not actually a banded matrix.

Unlike dot, this function is only valid if x is a vector.

Parameters
----------
A: TensorLike
Matrix to perform banded dot on.
x: TensorLike
Vector to perform banded dot on.
lower_diags: int
Number of nonzero lower diagonals of A
upper_diags: int
Number of nonzero upper diagonals of A

Returns
-------
out: Tensor
The matrix-vector product A @ x
"""
return Blockwise(BandedDot(lower_diags, upper_diags))(A, x)


__all__ = [
"cholesky",
"solve",
@@ -1683,4 +1771,5 @@
"lu",
"lu_factor",
"lu_solve",
"banded_dot",
]
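An editorial sketch (not part of the diff) of checking the L_op above numerically with pytensor.gradient.verify_grad:

import numpy as np
from pytensor.gradient import verify_grad
from pytensor.tensor.slinalg import banded_dot

rng = np.random.default_rng(42)
kl = ku = 1
n = 6

# Zero everything outside the band so the Op's structural assumption holds.
A_val = np.triu(np.tril(rng.normal(size=(n, n)), k=ku), k=-kl)
x_val = rng.normal(size=n)

# Compares the symbolic gradient against a finite-difference approximation.
verify_grad(
    lambda A, x: banded_dot(A, x, lower_diags=kl, upper_diags=ku),
    [A_val, x_val],
    rng=rng,
)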
39 changes: 39 additions & 0 deletions tests/link/numba/test_slinalg.py
@@ -15,8 +15,10 @@
LUFactor,
Solve,
SolveTriangular,
banded_dot,
)
from tests.link.numba.test_basic import compare_numba_and_py, numba_inplace_mode
from tests.tensor.test_slinalg import _make_banded_A


pytestmark = pytest.mark.filterwarnings("error")
@@ -720,3 +722,40 @@ def test_lu_solve(b_func, b_shape: tuple[int, ...], trans: bool, overwrite_b: bool):

# Can never destroy non-contiguous inputs
np.testing.assert_allclose(b_val_not_contig, b_val)


@pytest.mark.parametrize("stride", [1, 2, -1], ids=lambda x: f"stride={x}")
def test_banded_dot(stride):
rng = np.random.default_rng()

A_val = _make_banded_A(rng.normal(size=(10, 10)), kl=1, ku=1).astype(config.floatX)

x_shape = (10 * abs(stride),)
x_val = rng.normal(size=x_shape).astype(config.floatX)
x_val = x_val[::stride]

A = pt.tensor("A", shape=A_val.shape, dtype=A_val.dtype)
x = pt.tensor("x", shape=x_val.shape, dtype=x_val.dtype)

output = banded_dot(A, x, upper_diags=1, lower_diags=1)

compare_numba_and_py(
[A, x],
output,
test_inputs=[A_val, x_val],
inplace=True,
numba_mode=numba_inplace_mode,
eval_obj_mode=False,
)

# Test non-contiguous x input
x_val = rng.normal(size=(20,)).astype(config.floatX)[::2]

compare_numba_and_py(
[A, x],
output,
test_inputs=[A_val, x_val],
inplace=True,
numba_mode=numba_inplace_mode,
eval_obj_mode=False,
)