Blockwise improvements #532
Conversation
Codecov Report

Additional details and impacted files:

```diff
@@            Coverage Diff             @@
##             main     #532      +/-   ##
==========================================
+ Coverage   80.85%   80.90%    +0.04%
==========================================
  Files         162      162
  Lines       46246    46393     +147
  Branches    11305    11349      +44
==========================================
+ Hits        37393    37535     +142
- Misses       6631     6635       +4
- Partials     2222     2223       +1
```
The benchmarks look amazing, so obviously I support merging this. That said, I don't think I understand the underlying problems well enough to say whether this is the best possible way to go about things. Were the slowdowns due to allocs being blockwise broadcast? Or because batch dims in matrix matmuls were being broadcast, resulting in huge arrays and slow matrix multiplications?
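To make the second hypothesis concrete, here is a hedged NumPy sketch of the difference between letting `matmul` broadcast the batch dimension itself versus materializing the broadcast first. The array names and sizes are illustrative, not taken from the PR:

```python
import numpy as np

rng = np.random.default_rng(0)
a = rng.normal(size=(1, 3, 4))       # a single batch entry
b = rng.normal(size=(1000, 4, 2))    # 1000 batch entries

# matmul broadcasts the batch dimension internally: no 1000-row copy of `a`.
out = np.matmul(a, b)
assert out.shape == (1000, 3, 2)

# Materializing the broadcast first creates only a stride-0 view...
a_big = np.broadcast_to(a, (1000, 3, 4))
# ...but forcing a contiguous copy blows memory up 1000x, which is the
# kind of huge intermediate array the comment is asking about.
a_copy = np.ascontiguousarray(a_big)
assert a_copy.nbytes == 1000 * a.nbytes
assert np.allclose(np.matmul(a_copy, b), out)
```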
```python
@@ -895,6 +895,8 @@ def numba_funcify_BatchedDot(op, node, **kwargs):

    @numba_njit
    def batched_dot(x, y):
        # Numba does not support 3D matmul
        # https://github.com/numba/numba/issues/3804
        shape = x.shape[:-1] + y.shape[2:]
        z0 = np.empty(shape, dtype=dtype)
        for i in range(z0.shape[0]):
```
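For context, the loop in the diff works around Numba's missing 3D matmul by doing one 2D dot per leading batch index. A minimal self-contained sketch of the same pattern in plain NumPy (no `@numba_njit`, and `np.result_type` standing in for the Op's `dtype`, so it runs without Numba installed):

```python
import numpy as np

def batched_dot(x, y):
    # One 2D dot per batch entry, since Numba lacks 3D matmul
    # (https://github.com/numba/numba/issues/3804).
    shape = x.shape[:-1] + y.shape[2:]  # (b, n, m) @ (b, m, k) -> (b, n, k)
    z0 = np.empty(shape, dtype=np.result_type(x, y))
    for i in range(z0.shape[0]):
        z0[i] = np.dot(x[i], y[i])
    return z0

x = np.random.default_rng(0).normal(size=(4, 3, 5))
y = np.random.default_rng(1).normal(size=(4, 5, 2))
assert np.allclose(batched_dot(x, y), np.matmul(x, y))
```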
Should this be `prange` instead of `range`? I don't know if we allow parallel computation in compiled Numba code, or if doing so would even be useful/interesting.
I don't know either. I'm also not sure how well that would play with further multiprocessing from the outside; we have quite a few issues with BLAS in PyMC because of that.
Perhaps @aseyboldt can weigh in?
```python
    _ = extract_static_dim(x_sum_dim, y_sum_dim)
    out_shape = (batch_dim, x_row_dim, y_col_dim)

    # Change dtype if needed
```
Isn't this unusual in PyTensor code? Usually it refuses to quietly up/downcast. I'm mostly thinking of scan `outputs_info` as an example, but maybe that's not true in general.
An Op is allowed to enforce whatever input/output types it wants. In this case I think it's a requirement for the C code.
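As an illustration of the kind of dtype enforcement being discussed, here is a hedged sketch (the helper name `batched_dot_with_cast` is hypothetical, not from the PR): cast both inputs to a single common dtype before the core operation, the way an Op with a fixed-dtype C implementation would require.

```python
import numpy as np

def batched_dot_with_cast(x, y, out_dtype=None):
    # Hypothetical helper: upcast inputs to one common dtype up front,
    # mirroring an Op that fixes its input/output types for its C code.
    if out_dtype is None:
        out_dtype = np.result_type(x, y)
    x = x.astype(out_dtype, copy=False)  # no copy when dtype already matches
    y = y.astype(out_dtype, copy=False)
    return np.matmul(x, y)

xi = np.arange(6, dtype=np.int32).reshape(1, 2, 3)
yf = np.ones((1, 3, 2), dtype=np.float64)
out = batched_dot_with_cast(xi, yf)
assert out.dtype == np.float64          # int32 input quietly upcast
assert np.allclose(out, [[[3, 3], [12, 12]]])
```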
```
BOp = Blockwise(Op, signature="(x),(x)->(x)")
BOp(vector, alloc(vector, 10, 5)) -> alloc(BOp(vector, vector), 10, 5)
BOp(vector, alloc(scalar, 10, 5)) -> alloc(BOp(vector, alloc(scalar, 5)), 10, 5)
BOp(matrix, alloc(vector, 10, 5)) -> BOp(matrix, vector)
```
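The payoff of lifting the `Alloc` past the Blockwise can be sketched numerically. A hedged NumPy analogue (the `core_op` here is an arbitrary stand-in, not the Op from the rewrite): apply the core computation once and broadcast the result, instead of broadcasting first and repeating the computation per batch row.

```python
import numpy as np

# Stand-in core op applied along the last axis.
def core_op(v):
    return np.cumsum(v)

vector = np.arange(5.0)

# Un-rewritten graph: broadcast first, then run the op once per batch row.
batched_in = np.broadcast_to(vector, (10, 5))
slow = np.stack([core_op(row) for row in batched_in])  # 10 identical computations

# Rewritten graph: run the op once, then broadcast the result.
fast = np.broadcast_to(core_op(vector), (10, 5))

assert np.array_equal(slow, fast)
```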
I'm guessing that this function is the source of the major speedups in this PR -- I have it in my head that the slowdown was due to `Alloc` operations being Blockwise'd. Is that right? If so, why did it cause such terrible graphs?
For one thing, it was causing huge constants during constant folding.
But more importantly, it was causing slow repeated Blockwise operations, like `Arange` with the same inputs. Basically, broadcasting before a Blockwise means doing the same computation the batched number of times, which is just silly.
Remember that Blockwise does not have a C implementation, so it's just a Python loop in `np.vectorize`. We want to get rid of as many as possible, which is what this PR mostly does. Any Blockwise we can avoid in the final function is a win.
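The "`Python loop in np.vectorize`" point can be made concrete. A hedged sketch of how a Blockwise without a C implementation behaves, using NumPy's `np.vectorize` with a gufunc-style `signature` (the `core` function is illustrative):

```python
import numpy as np

# A Blockwise with no C implementation is essentially np.vectorize with a
# core signature: a Python-level loop over the batch dimension.
def core(a, b):
    return a @ b

blockwise_dot = np.vectorize(core, signature="(n,m),(m,k)->(n,k)")

x = np.random.default_rng(0).normal(size=(8, 3, 4))
y = np.random.default_rng(1).normal(size=(8, 4, 2))

# Same result as the natively-looped np.matmul, but each of the 8 batch
# entries goes through a Python function call.
assert np.allclose(blockwise_dot(x, y), np.matmul(x, y))
```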
```python
@@ -1880,76 +1873,26 @@ def contiguous(var, ndim):
    )
    contiguate = "\n".join(contiguate)

def c_dimshuffle(newname, oldname, shape):
```
Why were you able to chop all this out without replacing it? Is it because the logic it handled will now be rewritten before it ever ends up as C code?
Yes, I was able to chop it out by making the Op only support 3d tensors as inputs and doing the dimshuffles and squeezes manually in the helper function that calls this Op.
Since the C implementation added the dimshuffles in all cases, this looked like an easy way to simplify the C code.
@jessegrabowski the biggest speedups come from:
Rewrite was already tagged as "shape_unsafe"
Also extend the eager rewrite to more Ops. The Blockwise MatrixInverse grad test became more sensitive in float32, because desired stabilization rewrites (mainly `inv_as_solve`) that target Dot of Blockwise{MatrixInverse} are now triggered in the default blockwise grad but not in the non-default non-blockwise grad.
Also return matmul for the respective vectorize of dot, to avoid creating redundant Blockwise Ops.
It now supports an arbitrary number of batched dimensions of `b`, by raveling them together.
The Op now always expects rank-3 inputs, and any dimshuffles are added explicitly by the helper function.
Also adds better static shapes.
Also prevents Alloc from constant-folding when it's used by Elemwise and Blockwise, to avoid creating useless large arrays.
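The "raveling batch dimensions together" step can be sketched in NumPy. This is a hedged illustration of the general pattern (the helper name `to_rank3` is hypothetical): collapse all leading batch dims into one so the core op only ever sees rank-3 inputs, then restore the batch shape afterwards.

```python
import numpy as np

def to_rank3(x):
    # Ravel all leading batch dims into a single batch dim, so a core op
    # that only accepts rank-3 inputs can be applied; returns the original
    # batch shape so the caller can reshape the result back.
    batch_shape = x.shape[:-2]
    return x.reshape((-1,) + x.shape[-2:]), batch_shape

x = np.random.default_rng(0).normal(size=(2, 5, 3, 4))
y = np.random.default_rng(1).normal(size=(2, 5, 4, 2))

x3, batch_shape = to_rank3(x)  # (10, 3, 4)
y3, _ = to_rank3(y)            # (10, 4, 2)
out = np.matmul(x3, y3).reshape(batch_shape + (3, 2))

assert out.shape == (2, 5, 3, 2)
assert np.allclose(out, np.matmul(x, y))
```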
This PR adds a bunch of Blockwise improvements, using
`test_batched_mvnormal_logp_and_dlogp`
as a benchmark.
TODO:
It also closes pymc-devs/pymc#7042.
On my machine VI sampling goes down from 5m35s to 53s.