
Implement BandedDot Op #1416



Open: wants to merge 25 commits into main

jessegrabowski
Member

@jessegrabowski jessegrabowski commented May 23, 2025

Description

This PR adds a BandedDot Op that uses the BLAS routine gbmv to compute the matrix-vector product A @ x when A is a banded matrix.
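For readers unfamiliar with gbmv: it operates on the general band storage used by BLAS/LAPACK, where each in-band entry A[i, j] is kept at ab[ku + i - j, j] in a (kl + ku + 1) x n array. The sketch below is an illustrative NumPy reference, not the PR's code; the helper names `to_band_storage` and `banded_matvec` are hypothetical, and the loops only mimic what gbmv computes with alpha=1, beta=0.

```python
import numpy as np


def to_band_storage(A, kl, ku):
    """Pack a dense banded matrix into BLAS general band storage.

    A[i, j] (for -kl <= j - i <= ku) is stored at ab[ku + i - j, j], so row r
    of `ab` holds the diagonal with offset k = ku - r (superdiagonals first).
    """
    m, n = A.shape
    ab = np.zeros((kl + ku + 1, n), dtype=A.dtype)
    for j in range(n):
        # rows of column j that lie inside the band
        for i in range(max(0, j - ku), min(m, j + kl + 1)):
            ab[ku + i - j, j] = A[i, j]
    return ab


def banded_matvec(ab, x, m, kl, ku):
    """Reference for y = A @ x using only the band storage (alpha=1, beta=0)."""
    n = ab.shape[1]
    y = np.zeros(m, dtype=np.result_type(ab, x))
    for j in range(n):
        for i in range(max(0, j - ku), min(m, j + kl + 1)):
            y[i] += ab[ku + i - j, j] * x[j]
    return y


# Usage: a random 6x6 matrix with one subdiagonal and two superdiagonals
rng = np.random.default_rng(0)
n, kl, ku = 6, 1, 2
A = np.triu(np.tril(rng.normal(size=(n, n)), ku), -kl)  # zero outside the band
x = rng.normal(size=n)
ab = to_band_storage(A, kl, ku)
print(ab.shape)  # (4, 6): kl + ku + 1 rows instead of n
assert np.allclose(banded_matvec(ab, x, n, kl, ku), A @ x)
```

Only kl + ku + 1 rows are stored, which is where the savings come from once n is large relative to the bandwidth.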

In my testing, using gbmv for this case sped up computation significantly. Benchmarked against PyTensor's dot, however, the current implementation is much slower at every size except the largest (10_000):

------------------------------------------------------------------------------------------------- benchmark: 8 tests ------------------------------------------------------------------------------------------------
Name (time in us)                       Min                    Max                  Mean              StdDev                Median                IQR            Outliers           OPS            Rounds  Iterations
---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
test_dot_perf[10]                    1.7500 (1.0)          17.3330 (1.0)          1.9054 (1.0)        0.1292 (1.0)          1.9160 (1.0)       0.0420 (1.0)      585;1740  524,831.2234 (1.0)       38401           1
test_banded_dot_perf[10]            19.9580 (11.40)    13,765.1250 (794.16)      32.5111 (17.06)    282.5468 (>1000.0)     20.5830 (10.74)     0.3750 (8.93)        6;349   30,758.7051 (0.06)       3275           1

test_dot_perf[100]                   2.4580 (1.40)         42.5420 (2.45)         2.7856 (1.46)       0.3265 (2.53)         2.7500 (1.44)      0.0420 (1.0)      343;7436  358,988.7425 (0.68)      71429           1
test_banded_dot_perf[100]           19.8330 (11.33)    15,203.3750 (877.13)      30.9185 (16.23)    193.8617 (>1000.0)     20.9580 (10.94)     0.4160 (9.90)      51;3057   32,343.1413 (0.06)      20566           1

test_dot_perf[1000]                 15.0000 (8.57)         61.5000 (3.55)        16.6383 (8.73)       1.4182 (10.98)       17.2920 (9.03)      2.2080 (52.57)     905;126   60,102.3508 (0.11)      18377           1
test_banded_dot_perf[1000]          27.0420 (15.45)       423.8750 (24.45)       32.9042 (17.27)      5.2005 (40.25)       32.6250 (17.03)     0.6250 (14.88)    129;1334   30,391.2634 (0.06)      12501           1

test_dot_perf[10_000]            3,369.4580 (>1000.0)   5,011.3330 (289.12)   3,412.7784 (>1000.0)  119.9981 (928.81)   3,394.5625 (>1000.0)  17.2910 (411.69)       4;25      293.0164 (0.00)        198           1
test_banded_dot_perf[10_000]       109.9170 (62.81)       611.5830 (35.28)      139.2751 (73.10)     52.3002 (404.81)     116.5000 (60.80)    14.0000 (333.33)    472;678    7,180.0341 (0.01)       3386           1
---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------

I guess there's significant overhead from extracting the diagonals and looking up the BLAS function in Python? This could, and probably should, be a C Op, but I'm not sure I'll realistically have time to dig into all that anytime soon. Help wanted, at any rate.

Related Issue

Checklist

Type of change

  • New feature / enhancement
  • Bug fix
  • Documentation
  • Maintenance
  • Other (please specify):

📚 Documentation preview 📚: https://pytensor--1416.org.readthedocs.build/en/1416/

@jessegrabowski jessegrabowski added enhancement New feature or request help wanted Extra attention is needed Op implementation linalg Linear algebra labels May 23, 2025
jessegrabowski and others added 2 commits May 23, 2025 17:32
Co-authored-by: Ricardo Vieira <28983449+ricardoV94@users.noreply.github.com>
@jessegrabowski
Member Author

I added trust_input, and I now load the BLAS functions once at import time and cache them. That should remove some of the most obvious sources of Python overhead. New benchmarks (note that they're in ns now, not us):

------------------------------------------------------------------------------------------------------------------- benchmark: 8 tests -------------------------------------------------------------------------------------------------------------------
Name (time in ns)                                      Min                       Max                      Mean                  StdDev                    Median                     IQR            Outliers             OPS            Rounds  Iterations
----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
test_banded_dot_perf[10-dot]                      541.9988 (1.0)          4,292.0001 (1.0)            638.1136 (1.0)           51.0902 (1.0)            625.0011 (1.0)           41.0000 (40.91)    1506;209  1,567,119.1257 (1.0)       15636           1
test_banded_dot_perf[10-banded_dot]            17,500.0005 (32.29)      418,167.0010 (97.43)       18,191.1183 (28.51)      3,829.7598 (74.96)       18,083.0011 (28.93)        167.0014 (166.62)     70;630     54,971.8815 (0.04)      11353           1

test_banded_dot_perf[100-dot]                   1,209.0004 (2.23)        23,959.0008 (5.58)         1,340.3628 (2.10)         103.1441 (2.02)         1,333.0009 (2.13)           1.0023 (1.0)    1217;34675    746,066.6804 (0.48)      88889           1
test_banded_dot_perf[100-banded_dot]           17,542.0009 (32.37)       77,083.9997 (17.96)       18,240.8191 (28.59)      1,230.1810 (24.08)       18,000.0006 (28.80)        250.0001 (249.44)   654;2431     54,822.0996 (0.03)      19018           1

test_banded_dot_perf[1000-dot]                 13,291.9995 (24.52)       49,874.9996 (11.62)       15,195.7498 (23.81)      1,137.7872 (22.27)       15,833.0004 (25.33)      1,832.9993 (>1000.0)  2954;119     65,807.8747 (0.04)      22347           1
test_banded_dot_perf[1000-banded_dot]          24,624.9983 (45.43)       74,874.9990 (17.45)       30,233.2753 (47.38)      1,347.0049 (26.37)       30,125.0002 (48.20)        375.0010 (374.15)   874;1333     33,076.1385 (0.02)      15595           1

test_banded_dot_perf[10_000-dot]            3,394,874.9988 (>1000.0)  5,084,541.9992 (>1000.0)  3,585,834.0104 (>1000.0)  191,227.5142 (>1000.0)  3,558,604.5005 (>1000.0)  199,729.5003 (>1000.0)      16;3        278.8752 (0.00)        192           1
test_banded_dot_perf[10_000-banded_dot]       105,208.0006 (194.11)     389,250.0008 (90.69)      124,879.6041 (195.70)    35,967.3472 (704.00)     110,375.0001 (176.60)     8,343.4998 (>1000.0)   320;440      8,007.7128 (0.01)       2665           1
----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------

Comment on lines 1690 to 1699
A = np.asarray(A)
m, n = A.shape
ab = np.zeros((kl + ku + 1, n), dtype=A.dtype, order="C")

for i, k in enumerate(range(ku, -kl - 1, -1)):
    padding = (k, 0) if k >= 0 else (0, -k)
    diag = np.pad(np.diag(A, k=k), padding)
    ab[i, :] = diag

return ab
Member

I imagine this explains most of the python overhead for small cases?

Member Author

One way or another, we have to do that as part of the cost of the Op, unless we require users to have their inputs ready in that form.

Member

Yeah it's fine, I was just thinking out loud.

Member Author

This rearrangement could be done symbolically in a wrapper Op that calls the BLAS Op (which expects its inputs to already be in the correct form).

It might also be better to do smart column indexing on ab instead of using pad.
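A minimal sketch of the "smart column indexing" idea, assuming the same band layout as the np.pad version above: each diagonal is written directly into its column slice of `ab`, so no padded temporaries are created. This `to_banded_form` is an illustration only; the tuned version referred to later in the thread is not shown here.

```python
import numpy as np


def to_banded_form(A, kl, ku):
    """Pack banded A into (kl + ku + 1, n) band storage: ab[ku + i - j, j] = A[i, j].

    Each diagonal is copied straight into a column slice of `ab`, instead of
    building a padded temporary with np.pad and then copying that.
    """
    A = np.asarray(A)
    m, n = A.shape
    ab = np.zeros((kl + ku + 1, n), dtype=A.dtype)
    for i, k in enumerate(range(ku, -kl - 1, -1)):
        d = np.diagonal(A, offset=k)  # read-only view, no copy
        if k >= 0:
            # superdiagonal: A[r, r + k] belongs in column r + k
            ab[i, k : k + d.size] = d
        else:
            # subdiagonal: A[r - k, r] belongs in column r
            ab[i, : d.size] = d
    return ab


# Usage: agrees with the np.pad-based construction on a square matrix
rng = np.random.default_rng(0)
A = np.triu(np.tril(rng.normal(size=(5, 5)), 1), -2)  # ku = 1, kl = 2
padded = np.stack(
    [np.pad(np.diag(A, k), (k, 0) if k >= 0 else (0, -k)) for k in range(1, -3, -1)]
)
assert np.array_equal(to_banded_form(A, kl=2, ku=1), padded)
```

The slice bounds come from np.diagonal's length, so the same code also handles rectangular inputs.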

Member

Yeah, it's similar to Solve, in that you could possibly do it once and reuse it many times, but I think that's too much micro-optimization for now. We also don't want to autodiff through it.

Comment on lines 1702 to 1703
_dgbmv = scipy_linalg.get_blas_funcs("gbmv", dtype="float64")
_sgbmv = scipy_linalg.get_blas_funcs("gbmv", dtype="float32")
Member

This will cause import time overhead to PyTensor.

I'm okay paying the extra 3 us at runtime instead, since virtually nobody will ever use this (or use it in a case where they need those extra microseconds).

Member Author

I thought about this as well. It won't stay in the final version.

Member

@ricardoV94 ricardoV94 May 23, 2025

You can exploit prepare_node to store the function in node.tag, from which the perform method can then retrieve it. That's two attribute accesses instead of a string check plus scipy's caching...

Member

Or you can sidestep perform and use make_thunk instead

@ricardoV94
Member

I think the Op is fine, especially if we are not trying to introduce it automatically via rewrites. If we are, we may consider the backend (once we have it in Numba, I suspect it will win for smaller matrices) and/or static shapes, if we think the worst-case penalty is still too big.

@jessegrabowski
Member Author

Benchmark after tuning up the _to_banded_form function:

------------------------------------------------------------------------------------------------------------------- benchmark: 8 tests ------------------------------------------------------------------------------------------------------------------
Name (time in ns)                                      Min                       Max                      Mean                 StdDev                    Median                     IQR            Outliers             OPS            Rounds  Iterations
---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
test_banded_dot_perf[10-dot]                      499.9965 (1.0)         55,500.0006 (1.41)           665.4888 (1.0)         390.9718 (1.0)            666.0011 (1.0)           42.0005 (1.00)      31;2639  1,502,654.9287 (1.0)       32129           1
test_banded_dot_perf[10-banded_dot]             2,832.9996 (5.67)        71,957.9984 (1.82)         3,356.9474 (5.04)        782.8860 (2.00)         3,332.9998 (5.00)         332.9988 (7.93)    1874;2239    297,889.6806 (0.20)      32833           1

test_banded_dot_perf[100-dot]                   1,000.0003 (2.00)        58,208.9997 (1.47)         1,191.9862 (1.79)        396.5918 (1.01)         1,166.9981 (1.75)          41.9968 (1.0)      305;3163    838,935.8643 (0.56)      91258           1
test_banded_dot_perf[100-banded_dot]            3,332.9998 (6.67)        39,499.9988 (1.0)          3,874.8349 (5.82)        471.5917 (1.21)         3,875.0004 (5.82)          84.0009 (2.00)   1020;11972    258,075.5142 (0.17)      71008           1

test_banded_dot_perf[1000-dot]                 13,584.0019 (27.17)      118,374.9991 (3.00)        16,143.5130 (24.26)     1,984.1144 (5.07)        16,291.0001 (24.46)      2,042.0011 (48.62)    1390;171     61,944.3861 (0.04)      14202           1
test_banded_dot_perf[1000-banded_dot]           8,167.0005 (16.33)       68,749.9996 (1.74)        10,694.7895 (16.07)     1,131.4230 (2.89)        11,000.0001 (16.52)        416.9997 (9.93)    6811;7582     93,503.4764 (0.06)      32521           1

test_banded_dot_perf[10_000-dot]            3,379,415.9972 (>1000.0)  3,680,959.0019 (93.19)    3,463,207.0645 (>1000.0)  79,485.8545 (203.30)   3,434,124.9993 (>1000.0)  114,541.9992 (>1000.0)       6;0        288.7497 (0.00)         31           1
test_banded_dot_perf[10_000-banded_dot]        93,582.9994 (187.17)     294,458.0010 (7.45)       100,154.2338 (150.50)   22,660.4163 (57.96)       95,479.0012 (143.36)     2,083.4996 (49.61)       10;27      9,984.6004 (0.01)        248           1
---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------

@ricardoV94
Member

That looks much better!

@jessegrabowski
Member Author

I agree Numba will probably be better across the board. I'd really like this Op to win on the 100x100 case; that's already a pretty big matrix. 1000x1000 and 10,000x10,000 matrices don't really show up in nature too often.

@ricardoV94
Member

ricardoV94 commented May 23, 2025

100x100 takes about 1 us, so you are at the edge of Python overhead there. Calling an identity PyTensor function without trust_input costs 300-500 ns, and calling np.zeros costs about 100-200 ns. That means you would basically need to have no Python overhead whatsoever.

Edit: those are on my machine, don't know about yours
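Since these numbers are machine-dependent, a rough way to measure comparable components locally is timeit. This sketch only covers NumPy pieces (allocating a 3x100 band buffer, as for kl = ku = 1, and a dense 100x100 matvec), not a compiled PyTensor function; `per_call_ns` is a hypothetical helper.

```python
import timeit

import numpy as np


def per_call_ns(fn, number=10_000):
    """Best-of-5 average cost of one call to `fn`, in nanoseconds."""
    best = min(timeit.repeat(fn, number=number, repeat=5))
    return best / number * 1e9


rng = np.random.default_rng(0)
A = rng.normal(size=(100, 100))
x = rng.normal(size=100)

budget = {
    "np.zeros((3, 100))": per_call_ns(lambda: np.zeros((3, 100))),
    "A @ x (100x100)": per_call_ns(lambda: A @ x),
}
for name, ns in budget.items():
    print(f"{name:>20}: {ns:8.0f} ns")
```

If the allocation alone is a sizeable fraction of the dense matvec time, there is little room left for anything else done in Python.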

@ricardoV94
Member

ricardoV94 commented May 23, 2025

This is the best I think we can get out of this in Python?

    def make_thunk(self, node, storage_map, compute_map, no_recycling, impl):
        kl = self.lower_diags
        ku = self.upper_diags
        if node.outputs[0].dtype == "float64":
            gbmv = scipy_linalg.get_blas_funcs("gbmv", dtype="float64")
        else:
            gbmv = scipy_linalg.get_blas_funcs("gbmv", dtype="float32")

        ab_size = kl + ku + 1
        a_storage = storage_map[node.inputs[0]]
        b_storage = storage_map[node.inputs[1]]
        out_storage = storage_map[node.outputs[0]]
        out_computed = compute_map[node.outputs[0]] if compute_map is not None else [False]
        def thunk(
            a_storage=a_storage,
            b_storage=b_storage,
            out_storage=out_storage,
            out_computed=out_computed,
            kl=kl,
            ku=ku,
            ab_size=ab_size,
            gbmv=gbmv,
        ):
            A = a_storage[0]
            b = b_storage[0]
            m, n = A.shape

            ab = np.zeros((ab_size, n), dtype=A.dtype, order="C")
            # Row i of ab holds the diagonal with offset k (superdiagonals first)
            for i, k in enumerate(range(ku, -kl - 1, -1)):
                if k > 0:
                    ab[i, k:] = np.diag(A, k=k)
                else:
                    ab[i, :n + k] = np.diag(A, k=k)

            out_storage[0] = gbmv(m, n, kl, ku, 1, ab, b)
            out_computed[0] = True

        return thunk

A = as_tensor_variable(A)
B = as_tensor_variable(b)

out_dtype = pytensor.scalar.upcast(A.dtype, B.dtype)
Member

I suspect this is wrong for integer types

@ricardoV94
Member

That's much more palatable.

The difference between the Numba and Python gbmv is also what you should expect to see if you implemented gbmv in C, so you don't have to wonder.

@jessegrabowski
Member Author

[Image not preserved.] 🤔

@jessegrabowski
Member Author

The problem in the timings was some copies being made in both Python and Numba mode. Here are the updated timings. They're essentially the same except at the low end, where getting rid of the Python overhead gives Numba a small but consistent speed bump.

[Two images of the updated timings, not preserved.]

@jessegrabowski jessegrabowski marked this pull request as ready for review May 24, 2025 13:03
@jessegrabowski
Member Author

I'd like to call this one done for now, although there are three major things that are left to do:

  1. Enable GEMV rewrites in Numba and reuse that machinery to allow all arguments to the Numba xgemv function. Right now I'm forcing alpha=1, beta=0.
  2. Split off the code that converts a dense banded matrix into the banded matrix form into a separate Op. Then we can add rewrites to, for example, lift that conversion outside of scan. More importantly, we can:
  3. Introduce a rewrite that converts GEMV(BandedMatrix(A), x, ...) into BandedGEMV(BandedMatrix(A), x, ...). The existing BandedDot can become BandedGEMV, and then we can use all arguments.
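For reference, the full contract a BandedGEMV would expose is y = alpha * op(A) @ x + beta * y, with op(A) equal to A or A.T, and y not read when beta == 0. The NumPy sketch below spells that out on band storage; `gbmv_reference` is a hypothetical test oracle, not a PyTensor or Numba API, written for clarity rather than speed.

```python
import numpy as np


def gbmv_reference(m, n, kl, ku, alpha, ab, x, beta=0.0, y=None, trans=False):
    """NumPy oracle for BLAS ?gbmv with band storage ab[ku + i - j, j] = A[i, j].

    Returns alpha * A @ x + beta * y, or alpha * A.T @ x + beta * y if trans.
    As in BLAS, y is not read when beta == 0.
    """
    acc = np.zeros(n if trans else m, dtype=np.result_type(ab, x))
    for j in range(n):
        lo, hi = max(0, j - ku), min(m, j + kl + 1)  # band rows of column j
        col = ab[ku + np.arange(lo, hi) - j, j]        # equals A[lo:hi, j]
        if trans:
            acc[j] += col @ x[lo:hi]
        else:
            acc[lo:hi] += col * x[j]
    out = alpha * acc
    if beta != 0:
        if y is None:
            raise ValueError("y is required when beta != 0")
        out = out + beta * y
    return out


# Usage: 3x3 tridiagonal A = [[2, 1, 0], [1, 2, 1], [0, 1, 2]] (kl = ku = 1)
ab = np.array([[0.0, 1.0, 1.0], [2.0, 2.0, 2.0], [1.0, 1.0, 0.0]])
x = np.array([1.0, 2.0, 3.0])
y = np.ones(3)
out = gbmv_reference(3, 3, 1, 1, 2.0, ab, x, beta=1.0, y=y)  # 2 * [4, 8, 8] + 1
assert np.allclose(out, [9.0, 17.0, 17.0])
```

Setting alpha=1, beta=0 recovers the current BandedDot behaviour, which is the special case item 1 says is forced today.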

I want to merge this and then do these three things, because I want to do #1418 first, put the resulting function into the new _BLAS.py file added in this PR, enable the relevant rewrites, and then revisit this code.

I also need to think about how to handle the splitting out of the BandedMatrix Op, because it destroys information about how many rows the input matrix has (gemv needs to know this).
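To illustrate the row-count problem: band storage has shape (kl + ku + 1, n), so it records the number of columns but not m. In the hypothetical example below, a 4x4 tridiagonal matrix and the same matrix with an extra zero row pack to identical arrays, yet their products with x have different lengths. A split-out Op would therefore need m passed alongside `ab` somehow; how to do that is still open in this thread.

```python
import numpy as np


def to_band(A, kl, ku):
    """Band storage ab[ku + i - j, j] = A[i, j]; its shape depends only on kl, ku, n."""
    m, n = A.shape
    ab = np.zeros((kl + ku + 1, n), dtype=A.dtype)
    for i in range(m):
        for j in range(max(0, i - kl), min(n, i + ku + 1)):
            ab[ku + i - j, j] = A[i, j]
    return ab


A_square = (
    np.diag([1.0, 2.0, 3.0, 4.0])
    + np.diag([5.0, 6.0, 7.0], 1)
    + np.diag([8.0, 9.0, 10.0], -1)
)  # 4x4 tridiagonal, kl = ku = 1
A_tall = np.vstack([A_square, np.zeros((1, 4))])  # 5x4: same band, one extra row
x = np.arange(1.0, 5.0)

print(np.array_equal(to_band(A_square, 1, 1), to_band(A_tall, 1, 1)))  # True
print((A_square @ x).shape, (A_tall @ x).shape)  # (4,) (5,)
```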


codecov bot commented May 24, 2025

Codecov Report

Attention: Patch coverage is 72.72727% with 33 lines in your changes missing coverage. Please review.

Project coverage is 82.09%. Comparing base (261aaf3) to head (481814f).
Report is 9 commits behind head on main.

Files with missing lines                             Patch %   Lines
pytensor/link/numba/dispatch/linalg/dot/banded.py    46.93%    26 Missing ⚠️
pytensor/tensor/slinalg.py                           90.24%    2 Missing and 2 partials ⚠️
pytensor/link/numba/dispatch/slinalg.py              75.00%    2 Missing and 1 partial ⚠️

❌ Your patch check has failed because the patch coverage (72.72%) is below the target coverage (100.00%). You can increase the patch coverage or adjust the target coverage.


@@            Coverage Diff             @@
##             main    #1416      +/-   ##
==========================================
- Coverage   82.11%   82.09%   -0.02%     
==========================================
  Files         211      213       +2     
  Lines       49686    49843     +157     
  Branches     8813     8827      +14     
==========================================
+ Hits        40798    40920     +122     
- Misses       6710     6740      +30     
- Partials     2178     2183       +5     
Files with missing lines                             Coverage Δ
pytensor/link/numba/dispatch/basic.py                79.54% <ø> (ø)
pytensor/link/numba/dispatch/linalg/_BLAS.py         100.00% <100.00%> (ø)
pytensor/link/numba/dispatch/slinalg.py              70.10% <75.00%> (+0.34%) ⬆️
pytensor/tensor/slinalg.py                           93.00% <90.24%> (-0.18%) ⬇️
pytensor/link/numba/dispatch/linalg/dot/banded.py    46.93% <46.93%> (ø)

... and 6 files with indirect coverage changes


A = as_tensor_variable(A)
x = as_tensor_variable(x)

out_dtype = pytensor.scalar.upcast(A.dtype, x.dtype)
Member

wrong for integers/should raise. Also reject complex?

Member Author

I copied this from other make_node in slinalg (eigvalsh, eigvalsh grad, solve lyapunov stuff). What's the right way to upcast here?

Member

The right way is to predict what scipy outputs. Some Ops are lazy and just call scipy with a minimal input case to find out the output type. I don't love that.

Which makes me wonder: I guess the Numba/direct call to xbmv doesn't work with integer arrays, so we may need to cast or raise?

Member Author

What does JAX do on integer inputs?

Also it's not that onerous to just try every combination of input pairs on the scipy function, write it in a dictionary, and just look it up. Is that too crazy?

Member

> What does JAX do on integer inputs?

No idea, cast them to float or call a dot function that works on integers?

> Also it's not that onerous to just try every combination of input pairs on the scipy function, write it in a dictionary, and just look it up. Is that too crazy?

I think it's a bit crazy. You could add a function with lru_cache on the dtypes that tries it and stores the result. Most combinations will never be needed, and we don't want to do it at import time.
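A sketch of that suggestion, assuming we simply trust whatever dtype scipy returns: probe ?gbmv on a 1x1 problem the first time a dtype pair is seen and memoize the answer with functools.lru_cache. `gbmv_output_dtype` is a hypothetical name. scipy.linalg.get_blas_funcs picks the BLAS prefix from the input arrays, so the probe records scipy's choice rather than hardcoding it.

```python
from functools import lru_cache

import numpy as np
from scipy.linalg import get_blas_funcs


@lru_cache(maxsize=None)
def gbmv_output_dtype(a_dtype: str, x_dtype: str) -> np.dtype:
    """Find the dtype scipy's ?gbmv returns for these input dtypes.

    Runs a 1x1 product the first time a dtype pair is seen and caches the
    answer; nothing is computed at import time.
    """
    a = np.ones((1, 1), dtype=a_dtype)  # band storage with kl = ku = 0
    x = np.ones(1, dtype=x_dtype)
    gbmv = get_blas_funcs("gbmv", arrays=(a, x))
    return gbmv(1, 1, 0, 0, 1.0, a, x).dtype


print(gbmv_output_dtype("float32", "float64"))  # float64
```

A make_node could call this with the symbolic inputs' dtype strings; only the pairs that actually occur are ever probed.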


def make_node(self, A, x):
    A = as_tensor_variable(A)
    x = as_tensor_variable(x)
Member

Raise ValueError for non-core ndims.
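A minimal sketch of the checks requested in these threads: ValueError for wrong core ndims, and rejecting integer and complex inputs rather than upcasting them. It operates on NumPy arrays as a stand-in; in a real make_node the same tests would be made against the symbolic variables' ndim and dtype, and whether to raise or cast integers is a policy choice the thread has not settled. `check_banded_dot_inputs` is hypothetical.

```python
import numpy as np

SUPPORTED = (np.dtype("float32"), np.dtype("float64"))


def check_banded_dot_inputs(A, x):
    """Hypothetical make_node-style validation; returns the output dtype."""
    A, x = np.asarray(A), np.asarray(x)
    # Core ndims: A must be a matrix and x a vector.
    if A.ndim != 2:
        raise ValueError(f"A must have ndim=2, got ndim={A.ndim}")
    if x.ndim != 1:
        raise ValueError(f"x must have ndim=1, got ndim={x.ndim}")
    # Real floating inputs only: reject integers and complex instead of upcasting.
    for name, arr in (("A", A), ("x", x)):
        if arr.dtype not in SUPPORTED:
            raise TypeError(f"{name} must be float32 or float64, got {arr.dtype}")
    return np.result_type(A.dtype, x.dtype)


print(check_banded_dot_inputs(np.ones((3, 3), np.float32), np.ones(3)))  # float64
```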

@@ -1669,6 +1670,73 @@ def block_diag(*matrices: TensorVariable):
    return _block_diagonal_matrix(*matrices)


class BandedDot(Op):
Member

Put in blas.py?

Member

I saw your message, fine

Member Author

You mean in pytensor.tensor.blas? I can do that if you think it's better.

KU = val_to_int_ptr(ku)

ALPHA = np.array(1.0, dtype=dtype)
INCX = val_to_int_ptr(x.strides[0] // x.itemsize)
Member

@ricardoV94 ricardoV94 May 24, 2025

Please test non-unit positive and negative strides for x. In the C Gemv, for instance, we need to point to the last memory position when the stride is negative.

Member

@ricardoV94 ricardoV94 May 24, 2025

You can, but need not, also test strides for A and y. Since we're creating them ourselves for now, we know they're always correct. But once we split the Op, we will need to worry about those as well.

Member Author

lmk what you think about the way I'm testing strides now, and I can expand it if it's adequate.

Member Author

Unresolving this because the negative stride tests are failing

Member

@ricardoV94 ricardoV94 May 27, 2025

If it's like CBLAS, when you have negative strides you have to point to the end of the NumPy array (x[-1]). BLAS wants to know where the block of memory starts, even if it iterates in reverse, but with negative strides NumPy's data pointer sits at the end of that block.

Member

// gemv expects pointers to the beginning of memory arrays,
// but numpy provides a pointer to the first element,
// so when the stride is negative, we need to get the last one.
if (Sx < 0)
    x_data += (Nz0 - 1) * Sx;
if (Sy < 0)
    y_data += (Nz1 - 1) * Sy;
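The same pointer adjustment can be checked from Python. BLAS expects the address of the lowest-addressed element plus an increment counted in elements, while NumPy's data pointer is the address of x[0], which for a negatively strided view is the highest address. The hypothetical helper below computes what a BLAS call would need; its offset line mirrors the C snippet above.

```python
import numpy as np


def blas_vector_args(x):
    """Return (pointer, n, inc) describing a 1-D array the way BLAS expects.

    inc is the stride in elements. For inc < 0, BLAS still wants the lowest
    address of the block, which is the address of x[n - 1] in NumPy terms.
    """
    inc = x.strides[0] // x.itemsize
    ptr = x.ctypes.data  # address of x[0]
    if inc < 0:
        ptr += (x.shape[0] - 1) * x.strides[0]  # step back to the block start
    return ptr, x.shape[0], inc


base = np.arange(10.0)
x = base[::-2]  # logical order 9, 7, 5, 3, 1
ptr, n, inc = blas_vector_args(x)
print(n, inc)  # 5 -2
print(ptr == base[1:].ctypes.data)  # True: the lowest address used is base[1]
```

With that pointer and inc = -2, BLAS walks the block backwards and visits the elements in the same logical order as NumPy.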

@jessegrabowski
Member Author

@ricardoV94 since #1418 got resolved without adding the GEMV rewrite to numba, how should I handle expanding this Op to include rank-1 updates?

@ricardoV94
Member

ricardoV94 commented May 28, 2025

We may still link directly to BLAS for the full update; I'm not sure Numba does anything besides dispatching the matrix-vector dot part.

@ricardoV94
Member

ricardoV94 commented May 28, 2025

I would start by benchmarking directly in Numba to see whether we get a speedup from calling the fused gemv Op directly, or whether Numba already does it (for the regular gemv; it definitely doesn't do it for gbmv).
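Numba aside, the fused-versus-unfused question can be sketched with scipy's Fortran BLAS wrapper: `scipy.linalg.blas.dgemv` computes alpha * A @ x + beta * y in a single call, which can be checked and timed against doing the matvec and the update separately in NumPy. This is only an analogue of the proposed Numba benchmark; the size and constants are arbitrary, and timings will vary by machine.

```python
import timeit

import numpy as np
from scipy.linalg.blas import dgemv

rng = np.random.default_rng(0)
n = 100
A = rng.normal(size=(n, n))
x = rng.normal(size=n)
y = rng.normal(size=n)
alpha, beta = 2.0, -0.5


def fused():
    # One BLAS call computing alpha * A @ x + beta * y
    return dgemv(alpha, A, x, beta=beta, y=y)


def unfused():
    # Matrix-vector product first, then the scaling and update in NumPy
    return alpha * (A @ x) + beta * y


assert np.allclose(fused(), unfused())
for fn in (fused, unfused):
    per_call = min(timeit.repeat(fn, number=2_000, repeat=5)) / 2_000
    print(f"{fn.__name__:>8}: {per_call * 1e6:.2f} us per call")
```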

Development

Successfully merging this pull request may close these issues.

Implement linalg.BandedDot
2 participants