Implement BandedDot Op #1416
base: main
Changes from 22 commits
Commits: bbf3141, 2282161, ae8eff6, b2f68a5, e64d4d3, c979e9d, f1066a9, 0302fac, f47d88b, 157345c, 7d109b9, c18f095, 607a871, f6f12aa, e8fe5e3, 5344c27, 31e9a29, 0505c57, 2b5c51d, 519c933, 5754f93, 481814f, 30fece4, 4bd259c, 497721e
New file
@@ -0,0 +1,64 @@
import ctypes

from numba.core.extending import get_cython_function_address
from numba.np.linalg import ensure_blas, ensure_lapack, get_blas_kind

from pytensor.link.numba.dispatch.linalg._LAPACK import (
    _get_float_pointer_for_dtype,
    _ptr_int,
)


def _get_blas_ptr_and_ptr_type(dtype, name):
    d = get_blas_kind(dtype)
    func_name = f"{d}{name}"
    float_pointer = _get_float_pointer_for_dtype(d)
    lapack_ptr = get_cython_function_address("scipy.linalg.cython_blas", func_name)

    return lapack_ptr, float_pointer


class _BLAS:
    """
    Functions to return type signatures for wrapped BLAS functions.

    Here we are specifically concerned with BLAS functions exposed by scipy, and not used by numpy.

    Patterned after https://github.com/numba/numba/blob/bd7ebcfd4b850208b627a3f75d4706000be36275/numba/np/linalg.py#L74
    """

    def __init__(self):
        ensure_lapack()
        ensure_blas()

    @classmethod
    def numba_xgbmv(cls, dtype):
        """
        xGBMV performs one of the following matrix-vector operations:

            y = alpha * A @ x + beta * y,  or  y = alpha * A.T @ x + beta * y

        where alpha and beta are scalars, x and y are vectors, and A is a band matrix with kl sub-diagonals
        and ku super-diagonals.
        """
        blas_ptr, float_pointer = _get_blas_ptr_and_ptr_type(dtype, "gbmv")

        functype = ctypes.CFUNCTYPE(
            None,
            _ptr_int,  # TRANS
            _ptr_int,  # M
            _ptr_int,  # N
            _ptr_int,  # KL
            _ptr_int,  # KU
            float_pointer,  # ALPHA
            float_pointer,  # A
            _ptr_int,  # LDA
            float_pointer,  # X
            _ptr_int,  # INCX
            float_pointer,  # BETA
            float_pointer,  # Y
            _ptr_int,  # INCY
        )

        return functype(blas_ptr)
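The CFUNCTYPE above mirrors the Fortran GBMV argument list (TRANS, M, N, KL, KU, ALPHA, A, LDA, X, INCX, BETA, Y, INCY). As a quick illustration of the band-storage convention that routine expects, here is a sketch of my own (not part of the diff) that exercises the same computation through scipy's Python-level wrapper and checks it against a dense product:

```python
import numpy as np
from scipy import linalg

rng = np.random.default_rng(0)
n, kl, ku = 5, 1, 1  # tridiagonal test matrix
A = (
    np.diag(rng.normal(size=n))
    + np.diag(rng.normal(size=n - 1), 1)
    + np.diag(rng.normal(size=n - 1), -1)
)
x = rng.normal(size=n)

# Pack A into LAPACK band storage: A_banded[ku + i - j, j] = A[i, j]
A_banded = np.zeros((kl + ku + 1, n))
for i, k in enumerate(range(ku, -kl - 1, -1)):
    A_banded[i, max(k, 0) : n + min(k, 0)] = np.diag(A, k=k)

gbmv = linalg.get_blas_funcs("gbmv", (A, x))  # resolves to dgbmv for float64
y = gbmv(m=n, n=n, kl=kl, ku=ku, alpha=1.0, a=A_banded, x=x)
np.testing.assert_allclose(y, A @ x)
```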
New file
@@ -0,0 +1,95 @@
from collections.abc import Callable
from typing import Any

import numpy as np
from numba import njit as numba_njit
from numba.core.extending import overload
from numba.np.linalg import ensure_blas, ensure_lapack
from scipy import linalg

from pytensor.link.numba.dispatch.linalg._BLAS import _BLAS
from pytensor.link.numba.dispatch.linalg._LAPACK import (
    _get_underlying_float,
    val_to_int_ptr,
)
from pytensor.link.numba.dispatch.linalg.utils import _check_scipy_linalg_matrix


@numba_njit(inline="always")
def A_to_banded(A: np.ndarray, kl: int, ku: int) -> np.ndarray:
    m, n = A.shape

    # This matrix is built backwards then transposed to get it into Fortran order
    # (order="F" is not allowed in Numba land)
    A_banded = np.zeros((n, kl + ku + 1), dtype=A.dtype).T

    for i, k in enumerate(range(ku, -kl - 1, -1)):
        if k >= 0:
            A_banded[i, k:] = np.diag(A, k=k)
        else:
            A_banded[i, : n + k] = np.diag(A, k=k)

    return A_banded


def _dot_banded(A: np.ndarray, x: np.ndarray, kl: int, ku: int) -> Any:
    """
    Thin wrapper around gbmv. This code will only be called if njit is disabled globally
    (e.g. during testing)
    """
    fn = linalg.get_blas_funcs("gbmv", (A, x))
    m, n = A.shape
    A_banded = A_to_banded(A, kl=kl, ku=ku)

    return fn(m=m, n=n, kl=kl, ku=ku, alpha=1, a=A_banded, x=x)


@overload(_dot_banded)
def dot_banded_impl(
    A: np.ndarray, x: np.ndarray, kl: int, ku: int
) -> Callable[[np.ndarray, np.ndarray, int, int], np.ndarray]:
    ensure_lapack()
    ensure_blas()
    _check_scipy_linalg_matrix(A, "dot_banded")
    dtype = A.dtype
    w_type = _get_underlying_float(dtype)
    numba_gbmv = _BLAS().numba_xgbmv(dtype)

    def impl(A: np.ndarray, x: np.ndarray, kl: int, ku: int) -> np.ndarray:
        m, n = A.shape

        A_banded = A_to_banded(A, kl=kl, ku=ku)

        TRANS = val_to_int_ptr(ord("N"))
        M = val_to_int_ptr(m)
        N = val_to_int_ptr(n)
        LDA = val_to_int_ptr(A_banded.shape[0])

        KL = val_to_int_ptr(kl)
        KU = val_to_int_ptr(ku)

        ALPHA = np.array(1.0, dtype=dtype)
        INCX = val_to_int_ptr(x.strides[0] // x.itemsize)
        BETA = np.array(0.0, dtype=dtype)
        Y = np.empty(m, dtype=dtype)
        INCY = val_to_int_ptr(1)

        numba_gbmv(
            TRANS,
            M,
            N,
            KL,
            KU,
            ALPHA.view(w_type).ctypes,
            A_banded.view(w_type).ctypes,
            LDA,
            x.view(w_type).ctypes,
            INCX,
            BETA.view(w_type).ctypes,
            Y.view(w_type).ctypes,
            INCY,
        )

        return Y

    return impl
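A side note on the INCX computed in impl above: it is simply the element stride of x, so strided and reversed views map directly to BLAS increments. A tiny illustration of my own (not part of the diff; negative increments come up again in the stride discussion further below):

```python
import numpy as np

x = np.arange(10.0)
print(x[::2].strides[0] // x.itemsize)   # 2: every other element
print(x[::-1].strides[0] // x.itemsize)  # -1: reversed view
```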
Modified file
@@ -6,10 +6,12 @@
import numpy as np
import scipy.linalg as scipy_linalg
from numpy import zeros
from numpy.exceptions import ComplexWarning

import pytensor
import pytensor.tensor as pt
from pytensor import Variable
from pytensor.gradient import DisconnectedType
from pytensor.graph.basic import Apply
from pytensor.graph.op import Op

@@ -1669,6 +1671,92 @@
    return _block_diagonal_matrix(*matrices)

class BandedDot(Op):

[Review thread, resolved]
  Put in blas.py?
  I saw your message, fine
  You mean in …

    __props__ = ("lower_diags", "upper_diags")
    gufunc_signature = "(m,n),(n)->(m)"

    def __init__(self, lower_diags, upper_diags):
        self.lower_diags = lower_diags
        self.upper_diags = upper_diags

    def make_node(self, A, x):
        if A.ndim != 2:
            raise TypeError("A must be a 2D tensor")
        if x.ndim != 1:
            raise TypeError("x must be a 1D tensor")

        A = as_tensor_variable(A)
        x = as_tensor_variable(x)
[Review thread]
  Raise ValueError for non core ndims

        out_dtype = pytensor.scalar.upcast(A.dtype, x.dtype)

[Review thread]
  Wrong for integers / should raise. Also reject complex?
  I copied this from other make_node in slinalg (eigvalsh, eigvalsh grad, solve lyapunov stuff). What's the right way to upcast here?
  The right way is to predict what scipy outputs. Some Ops are lazy and just call scipy with a minimal input case to find out the output type. I don't love that. Which makes me wonder: I guess a numba/direct call to xgbmv doesn't work with integer arrays, so we may need to cast/raise?
  What does JAX do on integer inputs? Also, it's not that onerous to just try every combination of input pairs on the scipy function, write it in a dictionary, and just look it up. Is that too crazy?
  No idea; cast them to float, or call a dot function that works on integers? I think it's a bit crazy. You could add a function with lru_cache on the dtypes that tries it and stores the result. Most combinations will never be needed, and we don't want to do it at import time.
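A minimal sketch of the lru_cache idea floated above (my own illustration, not code from the PR; the helper name _gbmv_out_dtype is hypothetical): probe scipy's gbmv once per dtype pair and cache the resulting output dtype.

```python
from functools import lru_cache

import numpy as np
from scipy import linalg


@lru_cache(maxsize=None)
def _gbmv_out_dtype(a_dtype: str, x_dtype: str) -> np.dtype:
    """Probe gbmv with a minimal 1x1 problem to learn the output dtype."""
    a = np.zeros((1, 1), dtype=a_dtype)  # band storage of a 1x1 matrix (kl = ku = 0)
    x = np.ones(1, dtype=x_dtype)
    fn = linalg.get_blas_funcs("gbmv", (a, x))
    return fn(m=1, n=1, kl=0, ku=0, alpha=1, a=a, x=x).dtype
```

make_node could then look up the output dtype from this helper instead of calling upcast; whether integer or complex inputs should instead be rejected outright is left to the discussion above.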

        output = x.type.clone(dtype=out_dtype)()

        return pytensor.graph.basic.Apply(self, [A, x], [output])

    def infer_shape(self, fgraph, nodes, shapes):
        A_shape, _ = shapes
        return [(A_shape[0],)]

    def perform(self, node, inputs, outputs_storage):
        A, x = inputs
        m, n = A.shape

        kl = self.lower_diags
        ku = self.upper_diags

        A_banded = zeros((kl + ku + 1, n), dtype=A.dtype, order="F")

        for i, k in enumerate(range(ku, -kl - 1, -1)):
            if k >= 0:
                A_banded[i, k:] = np.diag(A, k=k)
            else:
                A_banded[i, : n + k] = np.diag(A, k=k)

        fn = scipy_linalg.get_blas_funcs("gbmv", dtype=A.dtype)
        outputs_storage[0][0] = fn(m=m, n=n, kl=kl, ku=ku, alpha=1, a=A_banded, x=x)

    def L_op(self, inputs, outputs, output_grads) -> list[Variable]:
        # This is exactly the same as the usual gradient of a matrix-vector product, except that the banded
        # structure of A is exploited when propagating the gradient to x.
        A, x = inputs
        (G_bar,) = output_grads

        A_bar = pt.outer(G_bar, x.T)
        # A.T has ku sub-diagonals and kl super-diagonals, so the diagonal counts swap.
        x_bar = banded_dot(
            A.T, G_bar, lower_diags=self.upper_diags, upper_diags=self.lower_diags
        )

        return [A_bar, x_bar]


def banded_dot(A: TensorLike, x: TensorLike, lower_diags: int, upper_diags: int):
    """
    Specialized matrix-vector multiplication for cases when A is a banded matrix.

    The banded structure of A is not checked at runtime; all data outside the declared diagonals is simply
    ignored, which will lead to incorrect results if A is not actually a banded matrix.

    Unlike dot, this function is only valid when x is a vector.

    Parameters
    ----------
    A: TensorLike
        Matrix to perform banded dot on.
    x: TensorLike
        Vector to perform banded dot on.
    lower_diags: int
        Number of nonzero lower (sub-) diagonals of A.
    upper_diags: int
        Number of nonzero upper (super-) diagonals of A.

    Returns
    -------
    out: Tensor
        The matrix-vector product A @ x.
    """
    return Blockwise(BandedDot(lower_diags, upper_diags))(A, x)


__all__ = [
    "cholesky",
    "solve",

@@ -1683,4 +1771,5 @@
    "lu",
    "lu_factor",
    "lu_solve",
    "banded_dot",
]
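To make the new public API concrete, here is a hedged usage sketch of my own (not part of the PR; it assumes banded_dot ends up importable from pytensor.tensor.slinalg, as the __all__ addition suggests):

```python
import numpy as np

import pytensor
import pytensor.tensor as pt
from pytensor.tensor.slinalg import banded_dot

A = pt.dmatrix("A")
x = pt.dvector("x")

# Declare A tridiagonal: one sub-diagonal and one super-diagonal.
y = banded_dot(A, x, lower_diags=1, upper_diags=1)

# Gradients flow through the L_op defined above.
g_A, g_x = pytensor.grad(y.sum(), [A, x])
fn = pytensor.function([A, x], [y, g_A, g_x])

A_val = np.diag(np.full(4, 2.0)) + np.diag(np.ones(3), 1) + np.diag(np.ones(3), -1)
x_val = np.arange(4.0)
y_val, *_ = fn(A_val, x_val)
np.testing.assert_allclose(y_val, A_val @ x_val)
```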
[Review thread]
  Please test non-unit positive and negative strides for x. In C Gemv, for instance, we need to point to the last memory position when the stride is negative.
  You can, but need not, also test strides for A and y. Since we're creating them ourselves now, we know they're always correct. But once we split the Op we will need to worry about those as well.
  lmk what you think about the way I'm testing strides now, and I can expand it if it's adequate.
  Unresolving this because the negative stride tests are failing.
  If it's like the C BLAS case, when you have negative strides you have to point to the end of the numpy array (x[-1]). BLAS wants to know where the block of memory starts, even if it iterates in reverse, but numpy points to the end of the block when it has negative strides.
  (Reference: pytensor/pytensor/tensor/blas_c.py, lines 243 to 249 at 2d414d4)
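If the behaviour matches the C gemv case described above, the fix is to hand BLAS the lowest address of the block rather than NumPy's data pointer. A small sketch of my own (not part of the PR; the helper name is hypothetical):

```python
import numpy as np


def x_pointer_and_incx(x: np.ndarray) -> tuple[int, int]:
    """Address and increment to hand to a Fortran BLAS routine for a 1D view."""
    incx = x.strides[0] // x.itemsize
    # With a negative stride, NumPy's data pointer is the logical first element,
    # which sits at the highest address; BLAS wants the start of the memory
    # block, i.e. the address of the logical last element x[-1].
    start = x[:1] if incx >= 0 else x[-1:]
    return start.ctypes.data, incx


x = np.arange(5.0)
assert x_pointer_and_incx(x[::-1]) == (x.ctypes.data, -1)
```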