Implement BandedDot Op #1416
Conversation
I added trust_input, and I also load the BLAS functions once on import and save them, so that should reduce some of the most obvious sources of Python overhead. New benchmarks (note that they're in ns now, not µs):
pytensor/tensor/slinalg.py
Outdated
A = np.asarray(A)
m, n = A.shape
ab = np.zeros((kl + ku + 1, n), dtype=A.dtype, order="C")

for i, k in enumerate(range(ku, -kl - 1, -1)):
    padding = (k, 0) if k >= 0 else (0, -k)
    diag = np.pad(np.diag(A, k=k), padding)
    ab[i, :] = diag

return ab
I imagine this explains most of the python overhead for small cases?
One way or another we have to do that as part of the cost of the Op, though, unless we demand users have inputs ready in that form.
Yeah it's fine, I was just thinking out loud.
This rearrangement could be done symbolically in a wrapper Op that calls the BLAS Op (which expects things to already be in the correct form).
It might also be better to do smart column indexing on ab instead of using pad.
Yeah, it's similar to the Solve Op, in that you can also do it once and possibly reuse it many times, but I think that's too much micro-optimization for now. We also don't want to autodiff through it.
pytensor/tensor/slinalg.py
Outdated
_dgbmv = scipy_linalg.get_blas_funcs("gbmv", dtype="float64")
_sgbmv = scipy_linalg.get_blas_funcs("gbmv", dtype="float32")
This will add import-time overhead to PyTensor.
I'm okay paying the extra 3µs at runtime instead, since virtually nobody will ever use this (or use it in a case where they need those extra µs).
I thought about this as well. It won't stay in the final version.
You can exploit prepare_node and add the function to node.tag, from which the perform method can then retrieve it. That's two attribute accesses instead of a string check / scipy caching...
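For illustration, a rough sketch of that prepare_node idea (the surrounding Op and names are assumed from the PR; only the caching hook is shown):

from scipy import linalg as scipy_linalg
from pytensor.graph.op import Op

class BandedDot(Op):
    # make_node / perform bodies as in the PR; only the caching hook is sketched here

    def prepare_node(self, node, storage_map, compute_map, impl):
        # Resolve gbmv once when the function is compiled and stash it on the node's tag
        node.tag.gbmv = scipy_linalg.get_blas_funcs("gbmv", dtype=node.outputs[0].dtype)

    def perform(self, node, inputs, output_storage):
        gbmv = node.tag.gbmv  # two attribute accesses, no scipy lookup per call
        ...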
Or you can sidestep perform and use make_thunk instead.
I think the Op is fine, especially if we are not trying to introduce it automatically via rewrites. If we are, we may consider the backend (once we have it in numba I suspect it will win for smaller matrices) and/or static shapes if we think the worst-case penalty is still too big.
_sgbmv = scipy_linalg.get_blas_funcs("gbmv", dtype="float32")


class BandedDot(Op):
infer_shape / L_op?
infer_shape yes, L_op I'm waiting to make the forward pass not suck first
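For reference, a hedged sketch of what infer_shape could look like here, assuming the output of A @ b is a vector with as many entries as A has rows:

# Inside the BandedDot Op (sketch only):
def infer_shape(self, fgraph, node, shapes):
    # The banded matrix-vector product A @ b has length m, the row count of A
    A_shape, _b_shape = shapes
    return [(A_shape[0],)]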
pytensor/tensor/slinalg.py
Outdated
B = as_tensor_variable(b)

out_dtype = pytensor.scalar.upcast(A.dtype, B.dtype)
output = b.type().astype(out_dtype)
This is doing a symbolic cast; the dtype should be set on the type directly.
I got an error that b.type doesn't take a dtype argument.
It doesn't; b.type is a concrete type already, and you call it to make a variable. You want to create a new type that's almost like b.type, but possibly with a different dtype.
Btw, this is not just a nitpick: if the astype were actually needed, make_node would fail, because it requires output variables to have no owner.
You can do b.type.clone(dtype=out_dtype)() instead, I think.
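Putting that together, a hedged sketch of the make_node change being suggested (structure assumed from the PR diff above):

import pytensor.scalar
from pytensor.graph.basic import Apply
from pytensor.tensor import as_tensor_variable

# Inside the BandedDot Op (sketch only):
def make_node(self, A, b):
    A = as_tensor_variable(A)
    B = as_tensor_variable(b)
    out_dtype = pytensor.scalar.upcast(A.dtype, B.dtype)
    # Clone b's type with the upcast dtype, then call the type to get a fresh
    # variable with no owner (no symbolic astype needed)
    output = B.type.clone(dtype=out_dtype)()
    return Apply(self, [A, B], [output])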
Benchmark after tuning up the
That looks much better!
I agree numba will probably be better across the board. I'd really like this Op to win on the 100x100 case; that's already a pretty big matrix. 1000x1000 and 10,000x10,000 don't really show up in nature too often.
At 100x100 it's 1µs; you are at the edge of Python overhead there. Calling an identity PyTensor function without trust_input is 300-500ns. Calling np.zeros is like 100-200ns. That means you would basically need to have no Python overhead whatsoever. Edit: those numbers are from my machine, don't know about yours.
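For context, a rough sketch of how that call-overhead floor can be measured (setup assumed; numbers are machine-dependent, as noted):

import timeit
import numpy as np
import pytensor
import pytensor.tensor as pt

x = pt.dvector("x")
identity = pytensor.function([x], x)   # identity PyTensor function
identity.trust_input = True            # skip per-call input validation

xv = np.zeros(100)
# Total seconds for 100k calls; divide by 100_000 for the per-call overhead
print(timeit.timeit(lambda: identity(xv), number=100_000))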
This is the best I think we can get out of this in Python?

def make_thunk(self, node, storage_map, compute_map, no_recycling, impl):
    kl = self.lower_diags
    ku = self.upper_diags
    if node.outputs[0].dtype == "float64":
        gbmv = scipy_linalg.get_blas_funcs("gbmv", dtype="float64")
    else:
        gbmv = scipy_linalg.get_blas_funcs("gbmv", dtype="float32")
    ab_size = kl + ku + 1
    a_storage = storage_map[node.inputs[0]]
    b_storage = storage_map[node.inputs[1]]
    out_storage = storage_map[node.outputs[0]]
    out_computed = compute_map[node.outputs[0]] if compute_map is not None else [False]

    def thunk(
        a_storage=a_storage,
        b_storage=b_storage,
        out_storage=out_storage,
        out_computed=out_computed,
        kl=kl,
        ku=ku,
        ab_size=ab_size,
        gbmv=gbmv,
    ):
        A = a_storage[0]
        b = b_storage[0]
        m, n = A.shape
        ab = np.zeros((ab_size, n), dtype=A.dtype, order="C")
        for i, k in enumerate(range(ku, -kl - 1, -1)):
            if k > 0:
                ab[i, k:] = np.diag(A, k=k)
            else:
                ab[i, : n + k] = np.diag(A, k=k)
        out_storage[0] = gbmv(m, n, kl, ku, 1, ab, b)
        out_computed[0] = True

    return thunk
A = as_tensor_variable(A)
B = as_tensor_variable(b)

out_dtype = pytensor.scalar.upcast(A.dtype, B.dtype)
I suspect this is wrong for integer types
I'm not saying we should do that, but it gives you a lower bound on what to expect from your micro-optimizations.
Here's what the thunk version benchmarks as for me:
I'm curious if it's possible to destroy A and turn it into A_banded in-place. If it's possible, it doesn't seem trivial; BLAS doesn't have an
Frankly, my time would be better served thinking about how to do this in C at this point.
Also, we should probably be benchmarking this against sparse_dot -- this all might be a waste of time?
Well, SparseDot doesn't work with batch inputs, but I'm curious. Also, I don't think the code is too complex or performing too badly. I don't agree with your sentiment that we should be thinking of a C impl; a numba one is more interesting...
Why not both? Seriously though, my feeling is that if we're putting this stuff into a PyMC model the code has to be ultra-performant. It's going to be called umptillion times: the inner loop of a PDE solver times the MCMC loop. I'll work on the numba dispatch next at any rate.
By that argument you can't really add any specialized Op that doesn't have a C implementation (unless it's replacing an Op that also doesn't have one). Ignoring the general user, you can have code that decides whether to use this Op based on the size (or a rewrite). Also, how are you sampling / getting A? Can you avoid the boxing/unboxing of the diagonals?
Well, the point is the specialization isn't adding anything over good ol'
Description
This PR adds a BandedDot Op that uses gbmv to do matrix-vector multiplication for the case that A is a banded matrix. In my testing, I found that this case sped up computation significantly. Benchmarking against PyTensor's dot, however, the current implementation is significantly slower:
I guess there's some major overhead from doing the diagonal extractions and looking up the BLAS function in Python? This could and probably should be a C Op, but I'm not sure I realistically have time to dig into all that anytime soon. Help wanted, at any rate.
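For anyone unfamiliar with gbmv, here is a standalone sketch (illustrative only, not the PR code) of the banded matrix-vector product it computes, using LAPACK band storage:

import numpy as np
from scipy import linalg as scipy_linalg

rng = np.random.default_rng(0)
n, kl, ku = 6, 1, 2                                      # 1 sub-diagonal, 2 super-diagonals
A = np.tril(np.triu(rng.normal(size=(n, n)), -kl), ku)   # dense matrix that happens to be banded
x = rng.normal(size=n)

# Pack A into band storage: ab[ku + i - j, j] = A[i, j] for entries inside the band
ab = np.zeros((kl + ku + 1, n), dtype=A.dtype)
rows, cols = np.nonzero(A)
ab[ku + rows - cols, cols] = A[rows, cols]

gbmv = scipy_linalg.get_blas_funcs("gbmv", dtype=A.dtype)
y = gbmv(n, n, kl, ku, 1.0, ab, x)                       # y = 1.0 * A @ x
np.testing.assert_allclose(y, A @ x)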
Related Issue
linalg.BandedDot #1415
Checklist
Type of change
📚 Documentation preview 📚: https://pytensor--1416.org.readthedocs.build/en/1416/