
Implement linalg.BandedDot #1415

Open
@jessegrabowski

Description

Benchmarking in #1323 showed that the banded (tridiagonal) case can get large speedups from a specialized dot product. This issue asks for a BandedDot Op that uses xgbmv (the BLAS general banded matrix-vector product: sgbmv, dgbmv, cgbmv, zgbmv) to realize these speedups.

In the future it would be nice to rewrite into this Op automatically in cases where we can detect that it applies, but I don't think that's necessary on a first pass. Just having the functionality available will be useful.

Note that JAX doesn't use xgbmv in the tridiagonal case. Instead, it has _tridiagonal_product, which computes the product directly with JAX primitive Ops. This might be preferable because it would require no extra dispatch work, but it would only handle the tridiagonal case, not the general banded case. Maybe we want both?
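For comparison, a minimal NumPy sketch of the direct approach in the tridiagonal case. This is not JAX's _tridiagonal_product; it only illustrates the same idea of summing three shifted elementwise products. The name tridiagonal_dot and the (dl, d, du) argument layout are assumptions for illustration.

```python
import numpy as np


def tridiagonal_dot(dl, d, du, x):
    """Compute A @ x for a tridiagonal A given by its three diagonals (hypothetical helper).

    dl: sub-diagonal,   shape (n-1,), dl[i] == A[i+1, i]
    d:  main diagonal,  shape (n,),   d[i]  == A[i, i]
    du: super-diagonal, shape (n-1,), du[i] == A[i, i+1]
    """
    y = d * x               # A[i, i] * x[i]
    y[:-1] += du * x[1:]    # + A[i, i+1] * x[i+1]
    y[1:] += dl * x[:-1]    # + A[i, i-1] * x[i-1]
    return y


rng = np.random.default_rng(0)
n = 6
dl, d, du = rng.normal(size=n - 1), rng.normal(size=n), rng.normal(size=n - 1)
x = rng.normal(size=n)
A = np.diag(d) + np.diag(du, 1) + np.diag(dl, -1)
print(np.allclose(tridiagonal_dot(dl, d, du, x), A @ x))
```

Since this uses only elementwise multiplies, adds, and slicing, it maps onto existing primitives in any backend, which is the "no extra dispatch work" advantage. As written, though, it covers only kl = ku = 1.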

At minimum, we should benchmark xgbmv against the direct method in the tridiagonal case.
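A rough benchmarking sketch, assuming NumPy and SciPy's scipy.linalg.blas.dgbmv stand in for the two candidate implementations, with a dense matmul as a baseline. It times plain Python calls, so per-call Python overhead is included and the numbers are only a proxy for timing the compiled PyTensor (or JAX) graphs. All function names here are hypothetical.

```python
import timeit

import numpy as np
from scipy.linalg.blas import dgbmv


def direct_tridiagonal_dot(dl, d, du, x):
    # y[i] = dl[i-1]*x[i-1] + d[i]*x[i] + du[i]*x[i+1]
    y = d * x
    y[:-1] += du * x[1:]
    y[1:] += dl * x[:-1]
    return y


def gbmv_tridiagonal_dot(ab, x):
    # ab is (3, n) band storage: row 0 = super-, row 1 = main, row 2 = sub-diagonal
    n = x.shape[0]
    return dgbmv(n, n, 1, 1, 1.0, ab, x)


def benchmark(n, number=200, seed=0):
    """Return mean seconds per call for each method on a random n x n tridiagonal matrix."""
    rng = np.random.default_rng(seed)
    dl, d, du = rng.normal(size=n - 1), rng.normal(size=n), rng.normal(size=n - 1)
    x = rng.normal(size=n)

    # Band storage for kl = ku = 1: ab[1 + i - j, j] == A[i, j]
    ab = np.zeros((3, n))
    ab[0, 1:] = du   # A[j-1, j]
    ab[1, :] = d     # A[j, j]
    ab[2, :-1] = dl  # A[j+1, j]

    dense = np.diag(d) + np.diag(du, 1) + np.diag(dl, -1)
    expected = dense @ x
    # Check that all three methods agree before timing them
    assert np.allclose(direct_tridiagonal_dot(dl, d, du, x), expected)
    assert np.allclose(gbmv_tridiagonal_dot(ab, x), expected)

    timings = {
        "direct": timeit.timeit(lambda: direct_tridiagonal_dot(dl, d, du, x), number=number),
        "gbmv": timeit.timeit(lambda: gbmv_tridiagonal_dot(ab, x), number=number),
        "dense": timeit.timeit(lambda: dense @ x, number=number),
    }
    return {name: t / number for name, t in timings.items()}


for n in (10, 100, 1_000):
    print(n, benchmark(n))
```

A fair comparison would also sweep larger n and repeat each measurement, since small sizes are dominated by call overhead rather than arithmetic.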

Metadata

Assignees: No one assigned
Type: No type
Projects: No projects
Milestone: No milestone
Relationships: None yet
Development: No branches or pull requests
