Implement BandedDot Op #1416
Conversation
Force-pushed from 912a88d to 2282161
I added trust_input, and I also load the BLAS functions once on import and save them, so that should reduce some of the most obvious sources of Python overhead. New benchmarks (note that they're in ns now, not us):
pytensor/tensor/slinalg.py
Outdated
A = np.asarray(A)
m, n = A.shape
ab = np.zeros((kl + ku + 1, n), dtype=A.dtype, order="C")

for i, k in enumerate(range(ku, -kl - 1, -1)):
    padding = (k, 0) if k >= 0 else (0, -k)
    diag = np.pad(np.diag(A, k=k), padding)
    ab[i, :] = diag

return ab
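For reference, a tiny worked example of the band layout this builds (values chosen by hand): diagonal k of A lands in row ku - k of ab.

import numpy as np

A = np.array([[1.0, 4.0, 0.0],
              [2.0, 5.0, 7.0],
              [0.0, 3.0, 6.0]])
# With kl = ku = 1 the storage built above is:
# ab = [[0, 4, 7],   <- superdiagonal (k=1), left-padded
#       [1, 5, 6],   <- main diagonal (k=0)
#       [2, 3, 0]]   <- subdiagonal (k=-1), right-padded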
I imagine this explains most of the python overhead for small cases?
One way or another we have to do that as part of the cost of the Op, unless we demand users have inputs ready in that form.
Yeah it's fine, I was just thinking out loud.
This rearrangement could be done symbolically in a wrapper Op that calls the blas Op (which expects things to be ready in the correct form).

It might also be better to do smart column indexing on ab instead of using pad.
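For instance, a minimal sketch of that column-indexing variant (essentially what the make_thunk version further down does):

for i, k in enumerate(range(ku, -kl - 1, -1)):
    # Write diagonal k directly into its slice of ab, skipping the
    # temporary arrays that np.pad allocates.
    if k >= 0:
        ab[i, k:] = np.diag(A, k=k)
    else:
        ab[i, : n + k] = np.diag(A, k=k)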
Yeah, it's similar to Solve in that you can also do it once and reuse it many times, but I think that's too much micro-optimization for now. We also don't want to autodiff through it.
pytensor/tensor/slinalg.py
Outdated
_dgbmv = scipy_linalg.get_blas_funcs("gbmv", dtype="float64")
_sgbmv = scipy_linalg.get_blas_funcs("gbmv", dtype="float32")
This will add import-time overhead to PyTensor. I'm okay paying the extra 3us at runtime instead, since virtually nobody will ever use this (or use it in a case where they need those extra us).
I thought about this as well. It won't stay in the final version.
You can exploit prepare_node and add the function to node.tag, which the perform method can then retrieve it from. That's two attribute accesses instead of a string check / scipy caching. A sketch is below.
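A minimal sketch of that approach (names are illustrative, not from this PR):

import scipy.linalg as scipy_linalg
from pytensor.graph.op import Op

class BandedDot(Op):
    def prepare_node(self, node, storage_map, compute_map, impl):
        # Resolve the dtype-specific gbmv once and stash it on the node,
        # so perform() only pays two attribute accesses per call.
        if not hasattr(node.tag, "gbmv"):
            node.tag.gbmv = scipy_linalg.get_blas_funcs(
                "gbmv", dtype=node.outputs[0].dtype
            )

    def perform(self, node, inputs, output_storage):
        gbmv = node.tag.gbmv  # cached by prepare_node
        ...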
Or you can sidestep perform and use make_thunk instead.
I think the Op is fine, especially if we are not trying to introduce it automatically via rewrites. If we are, we may consider the backend (once we have it in numba I suspect it will win for smaller matrices) and/or static shapes if we think the worst-case penalty is still too big.
Benchmark after tuning up the
That looks much better!
I agree numba will probably be better across the board. I'd really like this Op to win on the 100x100 case; that's already a pretty big matrix. 1000x1000 and 10,000x10,000 don't really show up in nature too often.
100x100 is 1us; you are at the edge of Python overhead there. Calling an identity PyTensor function without trust_input is 300-500ns. Calling np.zeros is like 100-200ns. That means you would basically need to have no Python overhead whatsoever. Edit: those numbers are from my machine, don't know about yours.
This is the best I think we can get out of this in Python?

def make_thunk(self, node, storage_map, compute_map, no_recycling, impl):
    kl = self.lower_diags
    ku = self.upper_diags

    # Resolve the dtype-specific BLAS routine once, at thunk-creation time
    if node.outputs[0].dtype == "float64":
        gbmv = scipy_linalg.get_blas_funcs("gbmv", dtype="float64")
    else:
        gbmv = scipy_linalg.get_blas_funcs("gbmv", dtype="float32")

    ab_size = kl + ku + 1

    a_storage = storage_map[node.inputs[0]]
    b_storage = storage_map[node.inputs[1]]
    out_storage = storage_map[node.outputs[0]]
    out_computed = compute_map[node.outputs[0]] if compute_map is not None else [False]

    def thunk(
        a_storage=a_storage,
        b_storage=b_storage,
        out_storage=out_storage,
        out_computed=out_computed,
        kl=kl,
        ku=ku,
        ab_size=ab_size,
        gbmv=gbmv,
    ):
        A = a_storage[0]
        b = b_storage[0]
        m, n = A.shape

        # Build the banded storage by direct column slicing (no np.pad)
        ab = np.zeros((ab_size, n), dtype=A.dtype, order="C")
        for i, k in enumerate(range(ku, -kl - 1, -1)):
            if k > 0:
                ab[i, k:] = np.diag(A, k=k)
            else:
                ab[i, : n + k] = np.diag(A, k=k)

        out_storage[0] = gbmv(m, n, kl, ku, 1, ab, b)
        out_computed[0] = True

    return thunk
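(Binding everything through default arguments is the point of the thunk signature: every name the hot path touches becomes a local, so each call skips global and attribute lookups.)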
pytensor/tensor/slinalg.py
Outdated
A = as_tensor_variable(A)
B = as_tensor_variable(b)

out_dtype = pytensor.scalar.upcast(A.dtype, B.dtype)
I suspect this is wrong for integer types
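To illustrate the concern: upcasting two integer dtypes stays integral, but BLAS gbmv kernels only exist for float32/float64 (and complex), so perform would receive a dtype it can't service. For example:

import pytensor

assert pytensor.scalar.upcast("int32", "int64") == "int64"  # still integer, no gbmv for it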
That's much more palatable. The difference between the numba and Python gbmv is also what you should expect to see if you implemented gbmv in C, so you don't have to wonder.
I'd like to call this one done for now, although there are three major things left to do:

I want to merge this and then do those three things, because I want to do #1418 first and put the resulting function into the new … I also need to think about how to handle splitting out the BandedMatrix Op, because it destroys information about how many rows the input matrix has (gemv needs to know this).
Codecov Report

Attention: Patch coverage is 72.72%.

❌ Your patch check has failed because the patch coverage (72.72%) is below the target coverage (100.00%). You can increase the patch coverage or adjust the target coverage.

Additional details and impacted files:

@@            Coverage Diff             @@
##             main    #1416      +/-   ##
==========================================
- Coverage   82.11%   82.09%   -0.02%
==========================================
  Files         211      213       +2
  Lines       49686    49843     +157
  Branches     8813     8827      +14
==========================================
+ Hits        40798    40920     +122
- Misses       6710     6740      +30
- Partials     2178     2183       +5
A = as_tensor_variable(A)
x = as_tensor_variable(x)

out_dtype = pytensor.scalar.upcast(A.dtype, x.dtype)
wrong for integers/should raise. Also reject complex?
I copied this from other make_node implementations in slinalg (eigvalsh, eigvalsh grad, the solve Lyapunov stuff). What's the right way to upcast here?
The right way is to predict what scipy outputs. Some Ops are lazy and just call scipy with a minimal input case to find out the output type. I don't love that.

Which makes me wonder: I guess numba / a direct call to xbmv doesn't work with integer arrays, so we may need to cast/raise?
What does JAX do on integer inputs?
Also it's not that onerous to just try every combination of input pairs on the scipy function, write it in a dictionary, and just look it up. Is that too crazy?
What does JAX do on integer inputs?
No idea, cast them to float or call a dot function that works on integers?
Also it's not that onerous to just try every combination of input pairs on the scipy function, write it in a dictionary, and just look it up. Is that too crazy?
I think it's a bit crazy. You could add a function with lru_cache on the dtypes that tries it and stores the result; most combinations will never be needed, and we don't want to do it at import time. Something like the sketch below:
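A minimal sketch of that cached probe (the function name and the 1x1 probe are illustrative, not from the PR):

from functools import lru_cache

import numpy as np
import scipy.linalg as scipy_linalg

@lru_cache(maxsize=None)
def gbmv_out_dtype(a_dtype: str, x_dtype: str) -> str:
    # Probe gbmv once per dtype pair with a trivial 1x1 problem and
    # cache the output dtype; only pairs that actually occur get probed.
    a = np.zeros((1, 1), dtype=a_dtype)  # band storage for kl = ku = 0
    x = np.zeros(1, dtype=x_dtype)
    gbmv = scipy_linalg.get_blas_funcs("gbmv", (a, x))
    return gbmv(1, 1, 0, 0, 1.0, a, x).dtype.name

gbmv_out_dtype("float32", "float64")  # -> "float64", computed once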
def make_node(self, A, x):
    A = as_tensor_variable(A)
    x = as_tensor_variable(x)
Raise ValueError for non-core ndims; for instance:
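A minimal sketch of that validation (error messages are illustrative):

def make_node(self, A, x):
    A = as_tensor_variable(A)
    x = as_tensor_variable(x)
    # gbmv needs a matrix and a vector; fail loudly on anything else.
    if A.type.ndim != 2:
        raise ValueError(f"A must be a matrix, got {A.type.ndim} dimensions")
    if x.type.ndim != 1:
        raise ValueError(f"x must be a vector, got {x.type.ndim} dimensions")
    ...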
@@ -1669,6 +1670,73 @@ def block_diag(*matrices: TensorVariable):
    return _block_diagonal_matrix(*matrices)


class BandedDot(Op):
Put in blas.py?
I saw your message, fine
You mean in pytensor.tensor.blas? I can do that if you think it's better.
KU = val_to_int_ptr(ku)

ALPHA = np.array(1.0, dtype=dtype)
INCX = val_to_int_ptr(x.strides[0] // x.itemsize)
Please test non-unit positive and negative strides for x. In the C gemv, for instance, we need to point to the last memory position when the stride is negative. A sketch of such a test follows.
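A sketch of the kind of test being requested, assuming the PR exposes a banded_dot helper with this name and signature (not confirmed here):

import numpy as np
import pytest
import pytensor
import pytensor.tensor as pt
from pytensor.tensor.slinalg import banded_dot  # assumed name/location

@pytest.mark.parametrize("stride", [2, -1, -2])
def test_banded_dot_x_strides(stride):
    rng = np.random.default_rng(0)
    A_val = np.diag(rng.normal(size=10))  # kl = ku = 0 band
    x_val = rng.normal(size=10 * abs(stride))[::stride]  # strided, possibly reversed view

    A = pt.matrix("A")
    x = pt.vector("x")
    fn = pytensor.function([A, x], banded_dot(A, x, lower_diags=0, upper_diags=0))
    np.testing.assert_allclose(fn(A_val, x_val), A_val @ x_val)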
You can, but need not, also test strides for A and y. Since we're creating them ourselves now, we know they're always correct. But once we split the Op we will need to worry about those as well.
Let me know what you think about the way I'm testing strides now, and I can expand it if it's adequate.
Unresolving this because the negative stride tests are failing
If it's like the C blas: when you have negative strides, you have to point to the end of the numpy array (x[-1]). BLAS wants to know where the block of memory starts, even if it iterates in reverse, but numpy's pointer is to the end of that block when strides are negative.
pytensor/pytensor/tensor/blas_c.py
Lines 243 to 249 in 2d414d4

// gemv expects pointers to the beginning of memory arrays,
// but numpy provides a pointer to the first element,
// so when the stride is negative, we need to get the last one.
if (Sx < 0)
    x_data += (Nz0 - 1) * Sx;
if (Sy < 0)
    y_data += (Nz1 - 1) * Sy;
Force-pushed from c396a8a to 4bd259c
@ricardoV94 since #1418 got resolved without adding the GEMV rewrite to numba, how should I handle expanding this Op to include rank-1 updates?
We may still link directly to BLAS for the full update; I'm not sure numba does it beyond dispatching the matrix/vector dot part.
I would start by benchmarking directly with numba to see if we get a speedup from calling the fused gemv routine directly, or whether numba already does it (the regular gemv, that is; it for sure doesn't do it for gbmv). Something like the sketch below.
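One way to start that benchmark (the jitted wrapper is mine, purely illustrative):

import numba
import numpy as np

@numba.njit(fastmath=True)
def gemv_update(alpha, A, x, beta, y):
    # Does numba lower this to a single BLAS gemv call, or dispatch
    # the dot and the scaled update separately? Time it to find out.
    return alpha * (A @ x) + beta * y

rng = np.random.default_rng(0)
A = rng.normal(size=(512, 512))
x = rng.normal(size=512)
y = rng.normal(size=512)
gemv_update(1.0, A, x, 0.5, y)  # first call compiles; time subsequent calls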
Description

This PR adds a BandedDot Op that uses gbmv to do matrix-vector multiplication for the case where A is a banded matrix. In my testing, I found that this sped up computation significantly. Benchmarking against PyTensor's dot, however, the current implementation is significantly slower:
I guess there's some major overhead from doing the diagonal extractions and looking up the BLAS function in Python? This could and probably should be a C Op, but I'm not sure I realistically have time to dig into all that anytime soon. Help wanted, at any rate. A minimal standalone comparison is sketched below.
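A rough standalone version of the comparison described above (sizes and band construction are illustrative):

import numpy as np
import scipy.linalg as scipy_linalg

n, kl, ku = 1000, 1, 1
rng = np.random.default_rng(0)
A = np.tril(np.triu(rng.normal(size=(n, n)), -kl), ku)  # tridiagonal matrix
x = rng.normal(size=n)

# Convert to LAPACK band storage: diagonal k of A goes to row ku - k of ab.
ab = np.zeros((kl + ku + 1, n), order="C")
for i, k in enumerate(range(ku, -kl - 1, -1)):
    ab[i, max(k, 0) : n + min(k, 0)] = np.diag(A, k=k)

gbmv = scipy_linalg.get_blas_funcs("gbmv", (A, x))
np.testing.assert_allclose(gbmv(n, n, kl, ku, 1.0, ab, x), A @ x)
# Time gbmv(...) against A @ x to reproduce the comparison.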
Related Issue

linalg.BandedDot #1415

Checklist
Type of change
📚 Documentation preview 📚: https://pytensor--1416.org.readthedocs.build/en/1416/