Blockwise improvements #532


Merged
ricardoV94 merged 14 commits into pymc-devs:main on Dec 10, 2023

Conversation

ricardoV94
Member

@ricardoV94 ricardoV94 commented Dec 6, 2023

This PR adds a bunch of blockwise improvements, using test_batched_mvnormal_logp_and_dlogp as a benchmark.

TODO:

  • Direct test for Blockwise Alloc rewrite
  • Direct test for Blockwise AdvancedIncSubtensor rewrite
BEFORE:
------------------------------------------------------------------------------------------------------------------------- benchmark: 9 tests -------------------------------------------------------------------------------------------------------------------------
Name (time in us)                                                             Min                     Max                    Mean                 StdDev                  Median                     IQR            Outliers         OPS            Rounds  Iterations
----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
test_batched_mvnormal_logp_and_dlogp[cov:()-mu:()]                       241.4050 (1.0)        1,810.0030 (1.0)          267.8159 (1.0)         112.2918 (1.0)          248.6980 (1.0)            6.1395 (1.0)        43;202  3,733.9078 (1.0)        1725           1
test_batched_mvnormal_logp_and_dlogp[cov:()-mu:(1000,)]                7,727.6810 (32.01)     32,625.9060 (18.03)      9,140.6252 (34.13)     3,506.5375 (31.23)      8,056.3195 (32.39)        532.3450 (86.71)        8;14    109.4017 (0.03)        106           1
test_batched_mvnormal_logp_and_dlogp[cov:()-mu:(4, 1000)]             36,519.7910 (151.28)    43,507.6490 (24.04)     37,502.4438 (140.03)    1,357.3882 (12.09)     37,078.4155 (149.09)       775.1030 (126.25)        1;1     26.6649 (0.01)         26           1
test_batched_mvnormal_logp_and_dlogp[cov:(1000,)-mu:()]              204,345.6680 (846.48)   213,811.9620 (118.13)   209,570.1286 (782.52)    3,860.6208 (34.38)    211,285.3960 (849.57)     5,903.3477 (961.54)        2;0      4.7717 (0.00)          5           1
test_batched_mvnormal_logp_and_dlogp[cov:(1000,)-mu:(1000,)]         206,162.0440 (854.01)   246,957.8760 (136.44)   220,043.3346 (821.62)   15,613.3438 (139.04)   215,733.2810 (867.45)    11,453.7725 (>1000.0)       1;1      4.5446 (0.00)          5           1
test_batched_mvnormal_logp_and_dlogp[cov:(1000,)-mu:(4, 1000)]       344,049.8250 (>1000.0)  401,401.8120 (221.77)   371,411.4898 (>1000.0)  23,642.7591 (210.55)   364,708.9210 (>1000.0)   38,862.3615 (>1000.0)       2;0      2.6924 (0.00)          5           1
test_batched_mvnormal_logp_and_dlogp[cov:(4, 1000)-mu:()]            791,623.1920 (>1000.0)  937,836.2080 (518.14)   842,729.1312 (>1000.0)  68,029.4537 (605.83)   798,030.7500 (>1000.0)  110,470.4235 (>1000.0)       1;0      1.1866 (0.00)          5           1
test_batched_mvnormal_logp_and_dlogp[cov:(4, 1000)-mu:(4, 1000)]     806,370.6470 (>1000.0)  988,275.6840 (546.01)   896,315.0546 (>1000.0)  79,760.8230 (710.30)   928,718.6800 (>1000.0)  135,838.3053 (>1000.0)       2;0      1.1157 (0.00)          5           1
test_batched_mvnormal_logp_and_dlogp[cov:(4, 1000)-mu:(1000,)]       820,059.1180 (>1000.0)  877,132.3050 (484.60)   848,327.1440 (>1000.0)  20,863.3722 (185.80)   844,725.1540 (>1000.0)   24,586.2695 (>1000.0)       2;0      1.1788 (0.00)          5           1
----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------

AFTER:
------------------------------------------------------------------------------------------------------------------------- benchmark: 9 tests ------------------------------------------------------------------------------------------------------------------------
Name (time in us)                                                             Min                     Max                    Mean                 StdDev                  Median                    IQR            Outliers         OPS            Rounds  Iterations
---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
test_batched_mvnormal_logp_and_dlogp[cov:()-mu:()]                       223.7400 (1.0)        1,970.9100 (1.0)          239.0524 (1.0)          76.9007 (1.0)          229.3510 (1.0)           3.1460 (1.0)        35;196  4,183.1832 (1.0)        1854           1
test_batched_mvnormal_logp_and_dlogp[cov:()-mu:(1000,)]                  807.4360 (3.61)      19,562.5290 (9.93)         991.2214 (4.15)      1,043.0862 (13.56)        825.7610 (3.60)         22.8352 (7.26)       22;122  1,008.8563 (0.24)        779           1
test_batched_mvnormal_logp_and_dlogp[cov:()-mu:(4, 1000)]              2,792.7740 (12.48)     11,729.2920 (5.95)       3,331.1094 (13.93)       837.6202 (10.89)      3,153.6910 (13.75)       649.0978 (206.32)       14;8    300.2003 (0.07)        225           1
test_batched_mvnormal_logp_and_dlogp[cov:(1000,)-mu:(1000,)]         128,476.3110 (574.22)   131,474.4790 (66.71)    130,098.2186 (544.22)    1,087.8881 (14.15)    129,534.9570 (564.79)    1,687.7140 (536.46)        3;0      7.6865 (0.00)          9           1
test_batched_mvnormal_logp_and_dlogp[cov:(1000,)-mu:()]              128,885.2480 (576.05)   134,904.6290 (68.45)    131,217.4634 (548.91)    2,093.2579 (27.22)    130,645.1510 (569.63)    3,462.7852 (>1000.0)       3;0      7.6209 (0.00)          9           1
test_batched_mvnormal_logp_and_dlogp[cov:(1000,)-mu:(4, 1000)]       253,046.7460 (>1000.0)  267,180.3140 (135.56)   261,121.1530 (>1000.0)   6,518.4257 (84.76)    264,305.8970 (>1000.0)  11,556.8890 (>1000.0)       1;0      3.8296 (0.00)          5           1
test_batched_mvnormal_logp_and_dlogp[cov:(4, 1000)-mu:(1000,)]       478,298.3400 (>1000.0)  539,055.1810 (273.51)   507,703.0552 (>1000.0)  25,177.5062 (327.40)   503,584.5180 (>1000.0)  42,617.3770 (>1000.0)       2;0      1.9697 (0.00)          5           1
test_batched_mvnormal_logp_and_dlogp[cov:(4, 1000)-mu:()]            478,898.5630 (>1000.0)  488,528.9760 (247.87)   482,516.0360 (>1000.0)   3,638.3298 (47.31)    481,567.2750 (>1000.0)   3,815.3637 (>1000.0)       1;0      2.0725 (0.00)          5           1
test_batched_mvnormal_logp_and_dlogp[cov:(4, 1000)-mu:(4, 1000)]     479,454.1990 (>1000.0)  576,023.7790 (292.26)   513,581.6604 (>1000.0)  36,769.4058 (478.14)   501,011.6260 (>1000.0)  32,514.4020 (>1000.0)       1;0      1.9471 (0.00)          5           1
---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------

It also closes pymc-devs/pymc#7042
On my machine, VI sampling goes down from 5m35s to 53s.

@ricardoV94 ricardoV94 added graph rewriting performance compilation bug Something isn't working enhancement New feature or request labels Dec 6, 2023
@ricardoV94 ricardoV94 force-pushed the faster_blockwise_mvnormal branch 2 times, most recently from c49aeb2 to 29cd3ac on December 6, 2023 18:42
@ricardoV94 ricardoV94 mentioned this pull request Dec 7, 2023
8 tasks
@ricardoV94 ricardoV94 force-pushed the faster_blockwise_mvnormal branch 9 times, most recently from 7b58ddd to 44c0cfb on December 9, 2023 15:05
@ricardoV94 ricardoV94 marked this pull request as ready for review December 9, 2023 15:27
@codecov-commenter

codecov-commenter commented Dec 9, 2023

Codecov Report

Merging #532 (505882e) into main (c49e395) will increase coverage by 0.04%.
Report is 1 commit behind head on main.
The diff coverage is 92.30%.

Additional details and impacted files


@@            Coverage Diff             @@
##             main     #532      +/-   ##
==========================================
+ Coverage   80.85%   80.90%   +0.04%     
==========================================
  Files         162      162              
  Lines       46246    46393     +147     
  Branches    11305    11349      +44     
==========================================
+ Hits        37393    37535     +142     
- Misses       6631     6635       +4     
- Partials     2222     2223       +1     
Files Coverage Δ
pytensor/link/jax/dispatch/nlinalg.py 89.74% <100.00%> (-0.26%) ⬇️
pytensor/link/numba/dispatch/basic.py 86.18% <ø> (ø)
pytensor/tensor/basic.py 88.47% <100.00%> (+0.07%) ⬆️
pytensor/tensor/extra_ops.py 88.56% <100.00%> (+0.03%) ⬆️
pytensor/tensor/rewriting/basic.py 94.05% <100.00%> (-0.05%) ⬇️
pytensor/tensor/rewriting/blas.py 88.63% <100.00%> (+0.72%) ⬆️
pytensor/tensor/rewriting/linalg.py 83.78% <100.00%> (+2.53%) ⬆️
pytensor/tensor/shape.py 93.15% <100.00%> (+0.05%) ⬆️
pytensor/tensor/subtensor.py 89.68% <100.00%> (+0.09%) ⬆️
pytensor/graph/basic.py 89.08% <50.00%> (-0.24%) ⬇️
... and 5 more

@jessegrabowski jessegrabowski (Member) left a comment

The benchmarks look amazing, so obviously I support merging this. That said, I don't think I understand the underlying problems well enough to say whether this is the best possible way to go about things. Were the slowdowns due to Allocs being blockwise broadcast? Or because batch dims in matrix matmuls were being broadcast, resulting in huge arrays and slow matrix multiplications?

@@ -895,6 +895,8 @@ def numba_funcify_BatchedDot(op, node, **kwargs):

@numba_njit
def batched_dot(x, y):
    # Numba does not support 3D matmul
    # https://github.com/numba/numba/issues/3804
    shape = x.shape[:-1] + y.shape[2:]
    z0 = np.empty(shape, dtype=dtype)
    for i in range(z0.shape[0]):
Member

Should this be prange instead of range? I don't know if we allow parallel computation in compiled numba code, or if doing so would even be useful/interesting.

Member Author

I don't know either. Also, I am not sure how well that would play with further multiprocessing from the outside. We have quite a few issues with BLAS in PyMC because of that.

Perhaps @aseyboldt can weigh in?
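
For reference, here is a minimal self-contained sketch of what a prange variant could look like. It is purely illustrative, not code from this PR; batched_dot_parallel is a hypothetical name, and parallel=True is required for prange to actually spread iterations across threads.

import numpy as np
import numba

# Hypothetical illustration (not the PR's code): a prange version of the
# batched dot loop. Whether this plays well with outer multiprocessing is
# exactly the open question above.
@numba.njit(parallel=True)
def batched_dot_parallel(x, y):
    # x: (batch, m, k), y: (batch, k, n)
    z = np.empty((x.shape[0], x.shape[1], y.shape[2]), dtype=x.dtype)
    for i in numba.prange(x.shape[0]):
        z[i] = np.dot(x[i], y[i])
    return z

x = np.random.rand(8, 3, 4)
y = np.random.rand(8, 4, 5)
assert np.allclose(batched_dot_parallel(x, y), x @ y)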

_ = extract_static_dim(x_sum_dim, y_sum_dim)
out_shape = (batch_dim, x_row_dim, y_col_dim)

# Change dtype if needed
Member

Isn't this unusual in pytensor code? Usually it refuses to quietly up/downcast. I'm mostly thinking of scan outputs_info as an example, but maybe that's not true in general.

Member Author

An Op is allowed to enforce whatever input/output types it wants. In this case I think it's a requirement for the C code.
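
As a rough sketch of the pattern under discussion (a hypothetical helper, not this PR's code), the wrapper that builds the Op can quietly upcast both inputs to a common dtype before the C implementation ever sees them:

import numpy as np
import pytensor.tensor as pt

def _as_common_dtype(x, y):
    # Cast both inputs to their common dtype; the (assumed) C implementation
    # only supports matching dtypes.
    dtype = np.promote_types(x.dtype, y.dtype).name
    return pt.cast(x, dtype), pt.cast(y, dtype)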

BOp = Blockwise(Op, signature="(x),(x)->(x)")
BOp(vector, alloc(vector, 10, 5)) -> alloc(BOp(vector, vector), 10, 5)
BOp(vector, alloc(scalar, 10, 5)) -> alloc(BOp(vector, alloc(scalar, 5)), 10, 5)
BOp(matrix, alloc(vector, 10, 5)) -> BOp(matrix, vector)
Member

I'm guessing that this function is the source of the major speedups in this PR -- I have it in my head that the slowdown was due to alloc operations being blockwise'd. Is that right? If so, why did it cause such terrible graphs?

Member Author

For one thing, it was causing huge constants during constant folding.

But more importantly, it was causing slow repeated Blockwise operations like Arange with the same inputs. Broadcasting before a Blockwise means doing the same computation a batch number of times, which is just silly.

Remember that Blockwise does not have a C implementation, so it's just a Python loop in np.vectorize. We want to get rid of as many as possible, which is what this PR mostly does. Any Blockwise we can avoid in the final function is a win.
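
To make that concrete, here is a small NumPy-only sketch (expensive_core is just a stand-in, not a PyTensor Op): applying the batched loop to a pre-broadcast input repeats the identical core computation once per batch entry, while pushing the broadcast (Alloc) past the op runs it only once.

import numpy as np

def expensive_core(v):
    # stand-in for any core Op that is costly and has no batched C kernel
    return np.sort(v)

# np.vectorize with a signature is essentially the Python loop Blockwise runs
blockwised = np.vectorize(expensive_core, signature="(n)->(n)")

x = np.random.rand(5)

# Alloc before the Blockwise: the same 5-element vector is processed 1000 times
slow = blockwised(np.broadcast_to(x, (1000, 5)))

# Alloc after the Blockwise: the core op runs once, then the result is broadcast
fast = np.broadcast_to(expensive_core(x), (1000, 5))

assert np.allclose(slow, fast)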

@@ -1880,76 +1873,26 @@ def contiguous(var, ndim):
)
contiguate = "\n".join(contiguate)

def c_dimshuffle(newname, oldname, shape):
Member

Why were you able to chop all this out without replacing it? Is it because the logic it handled is now rewritten before it ever ends up as C code?

Member Author

Yes, I was able to chop it out by making the Op only support 3d tensors as inputs and doing the dimshuffles and squeezes manually in the helper function that calls this Op.

Since the C implementation added the dimshuffles in all cases, it looked like an easy way to simplify the C code.
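
As a rough sketch of that division of labour (a hypothetical wrapper, not the PR's actual helper): the user-facing function pads a plain matrix up to rank 3, calls the rank-3-only core Op, and squeezes the dummy batch dimension off again.

import pytensor.tensor as pt

def batched_dot_wrapper(x, y, core_op):
    # core_op is assumed to accept only rank-3 inputs (batch, rows, cols)
    added_batch = x.ndim == 2
    if added_batch:
        # promote single matrices to a batch of one
        x = pt.shape_padleft(x)
        y = pt.shape_padleft(y)
    out = core_op(x, y)
    if added_batch:
        # drop the dummy batch dimension again
        out = out[0]
    return out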

@ricardoV94
Member Author

ricardoV94 commented Dec 9, 2023

@jessegrabowski the biggest speedups come from:

  1. Avoiding Blockwise altogether, which has no C implementation (nor Numba, although that was not the driving reason), by implementing eager vectorize rewrites or, when that is not possible, specialization rewrites into equivalent graphs, like matmul to batched_dot (see the sketch after this list).

  2. Avoiding repeated computations in the Blockwise by pushing Alloc inputs to the outputs of the Blockwise (think of doing expensive_op(alloc(x, 1000)) vs alloc(expensive_op(x), 1000)), with the extra benefit that sometimes these Allocs are not needed at all. This happens when there is a subtensor operation later on, or when they would be implicitly broadcasted by another input to an Elemwise or Blockwise that actually is "full rank".
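
As a tiny NumPy check of the equivalence behind point 1 (illustrative only, not PyTensor internals): a matmul over a leading batch dimension computes exactly what a batched_dot-style loop computes, so rewriting one into the other changes only how it runs, not the result.

import numpy as np

a = np.random.rand(4, 3, 5)
b = np.random.rand(4, 5, 2)

# batched_dot-style loop over the batch dimension
looped = np.stack([a[i] @ b[i] for i in range(a.shape[0])])

# batched matmul gives the same result in one call
assert np.allclose(looped, a @ b)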

@ricardoV94 ricardoV94 force-pushed the faster_blockwise_mvnormal branch from 44c0cfb to 8fafdf3 on December 10, 2023 10:50
Also extend eager rewrite to more Ops

The Blockwise MatrixInverse grad test became more sensitive in float32, because desired stabilization rewrites (mainly `inv_as_solve`) that target Dot of Blockwise{MatrixInverse} are now triggered in the default blockwise grad but not in the non-default non-blockwise grad
Also return matmul for respective vectorize of dot, to avoid creating redundant Blockwise Ops
@ricardoV94 ricardoV94 force-pushed the faster_blockwise_mvnormal branch from 8fafdf3 to db5a630 on December 10, 2023 10:52
It now supports an arbitrary number of batched dimensions of b, by raveling them together
The Op now always expects rank 3 inputs, and any dimshuffles are added explicitly by the helper function
Also adds better static shapes
Also prevent Alloc from constant_folding when it's used by Elemwise and Blockwise to avoid creating useless large arrays
@ricardoV94 ricardoV94 force-pushed the faster_blockwise_mvnormal branch from db5a630 to 505882e on December 10, 2023 11:04
@ricardoV94 ricardoV94 merged commit 68b41a4 into pymc-devs:main Dec 10, 2023
@ricardoV94 ricardoV94 mentioned this pull request Jan 5, 2024
10 tasks
Labels
bug (Something isn't working), compilation, enhancement (New feature or request), graph rewriting, performance
Development

Successfully merging this pull request may close these issues.

BUG: ADVI calculation is ~4x slower after 5.9.1
4 participants