Description
Benchmarking in #1323 showed that the banded (tridiagonal) case can get huge speedups by using a specialized dot product. This issue asks for a `BandedDot` `Op` that uses `xgbmv` to realize these speedups.
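For reference, the computation such an `Op` would wrap can be sketched in NumPy/SciPy. This is a minimal sketch, assuming SciPy's low-level BLAS wrappers (`scipy.linalg.get_blas_funcs`) and the BLAS band-storage layout `ab[ku + i - j, j] == A[i, j]` (the same layout `scipy.linalg.solve_banded` uses). `to_band_storage` and `banded_dot` are hypothetical helper names, not existing PyTensor API; an actual `BandedDot` `Op` would presumably also need input validation, a gradient, and support for each backend.

```python
import numpy as np
from scipy.linalg import get_blas_funcs


def to_band_storage(A, kl, ku):
    """Pack a dense (m, n) matrix with kl sub- and ku super-diagonals into
    BLAS band storage, where ab[ku + i - j, j] == A[i, j]."""
    m, n = A.shape
    ab = np.zeros((kl + ku + 1, n), dtype=A.dtype)
    for j in range(n):
        # Only the rows of column j that fall inside the band are stored.
        for i in range(max(0, j - ku), min(m, j + kl + 1)):
            ab[ku + i - j, j] = A[i, j]
    return ab


def banded_dot(ab, x, m, kl, ku):
    """Hypothetical helper: y = A @ x for an (m, n) banded A given in band storage."""
    n = ab.shape[1]
    # get_blas_funcs picks the s/d/c/z prefix from the input dtypes,
    # i.e. it resolves the "x" in xgbmv.
    gbmv = get_blas_funcs("gbmv", (ab, x))
    y = np.zeros(m, dtype=np.result_type(ab, x))
    # gbmv computes y := alpha * A @ x + beta * y; with beta = 0, y is just A @ x.
    return gbmv(m=m, n=n, kl=kl, ku=ku, alpha=1.0, a=ab, x=x, beta=0.0, y=y)


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    m, n, kl, ku = 5, 7, 1, 2
    # Zero everything outside the band to get a dense reference matrix.
    A = np.triu(np.tril(rng.normal(size=(m, n)), ku), -kl)
    x = rng.normal(size=n)
    y = banded_dot(to_band_storage(A, kl, ku), x, m, kl, ku)
    print(np.allclose(y, A @ x))  # True
```

Passing `A` in band storage is what lets `gbmv` skip the zeros: work and memory scale with `n * (kl + ku + 1)` rather than `m * n`.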
In the future it would be nice to automatically rewrite graphs into this `Op` in cases where we can detect that it applies, but I don't think that's necessary on the first pass. Just having the functionality lying around will be useful.
Note that JAX doesn't use `xgbmv` in the tridiagonal case. Instead, it has `_tridiagonal_product`, which computes the product directly with JAX primitive ops. This might be preferable because it would require no extra dispatch work, but it would not let us handle the general banded case, only the tridiagonal one. Maybe we want both?
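The direct tridiagonal approach needs only elementwise multiplies, adds, and shifted slices. A minimal NumPy sketch of the idea, written independently rather than copied from JAX's `_tridiagonal_product`; the name `tridiagonal_dot` and the `(dl, d, du)` sub/main/super-diagonal convention (borrowed from LAPACK's `?gtsv`) are assumptions:

```python
import numpy as np


def tridiagonal_dot(dl, d, du, x):
    """Compute y = A @ x for an n x n tridiagonal A given by its diagonals:
    dl (sub-diagonal, length n-1), d (main, length n), du (super, length n-1)."""
    y = d * x              # main diagonal: A[i, i] * x[i]
    y[1:] += dl * x[:-1]   # sub-diagonal: A[i, i-1] * x[i-1]
    y[:-1] += du * x[1:]   # super-diagonal: A[i, i+1] * x[i+1]
    return y
```

Written with PyTensor tensor operations instead of NumPy, the same three lines would build an ordinary graph of existing ops, so backend support should come without new dispatch code; the price is that it does not generalize to wider bands.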
At minimum, we should benchmark `xgbmv` against the direct method in the tridiagonal case.
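A rough starting point for that comparison, at the NumPy/SciPy level only: it times SciPy's `gbmv` wrapper against the direct slice-based product on random tridiagonal inputs and checks that the two agree. All names here are hypothetical, and timings from it would only be a first hint, since they include Python and f2py call overhead and exclude PyTensor's compiled backends. Sweeping `n` matters, because per-call overhead is likely to dominate at small sizes.

```python
import timeit

import numpy as np
from scipy.linalg import get_blas_funcs


def make_inputs(n, seed=0):
    """Random tridiagonal system: diagonals (dl, d, du), band storage ab, vector x."""
    rng = np.random.default_rng(seed)
    dl, d, du, x = (rng.normal(size=k) for k in (n - 1, n, n - 1, n))
    # Band storage with kl = ku = 1 (ab[1 + i - j, j] == A[i, j]):
    # row 0 holds the super-diagonal, row 1 the main, row 2 the sub-diagonal.
    ab = np.zeros((3, n))
    ab[0, 1:] = du
    ab[1, :] = d
    ab[2, :-1] = dl
    return dl, d, du, ab, x


def direct(dl, d, du, x):
    """Tridiagonal A @ x using only elementwise ops and shifted slices."""
    y = d * x
    y[1:] += dl * x[:-1]
    y[:-1] += du * x[1:]
    return y


def via_gbmv(gbmv, ab, x):
    """Tridiagonal A @ x through the BLAS xgbmv routine (kl = ku = 1)."""
    n = x.shape[0]
    return gbmv(m=n, n=n, kl=1, ku=1, alpha=1.0, a=ab, x=x, beta=0.0, y=np.zeros(n))


def benchmark(n, number=1000):
    """Return mean seconds per call as (direct, gbmv) for size n."""
    dl, d, du, ab, x = make_inputs(n)
    gbmv = get_blas_funcs("gbmv", (ab, x))  # float64 inputs -> dgbmv
    assert np.allclose(direct(dl, d, du, x), via_gbmv(gbmv, ab, x))
    t_direct = timeit.timeit(lambda: direct(dl, d, du, x), number=number)
    t_gbmv = timeit.timeit(lambda: via_gbmv(gbmv, ab, x), number=number)
    return t_direct / number, t_gbmv / number


if __name__ == "__main__":
    for n in (10, 1_000, 100_000):
        t_d, t_g = benchmark(n, number=200)
        print(f"n={n:>7}: direct {t_d * 1e6:9.2f} us/call, gbmv {t_g * 1e6:9.2f} us/call")
```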