-
Notifications
You must be signed in to change notification settings - Fork 131
Fix shape issues in jax tridiagonal solve #1414
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Conversation
Codecov ReportAll modified and coverable lines are covered by tests ✅
Additional details and impacted files@@ Coverage Diff @@
## main #1414 +/- ##
==========================================
+ Coverage 82.12% 82.14% +0.01%
==========================================
Files 211 211
Lines 49687 49695 +8
Branches 8813 8815 +2
==========================================
+ Hits 40807 40821 +14
+ Misses 6702 6697 -5
+ Partials 2178 2177 -1
🚀 New features to boost your workflow:
|
dl = jax.numpy.pad(dl, (1, 0)) | ||
du = jax.numpy.pad(du, (0, 1)) | ||
# if b is a vector, broadcast it to be a matrix | ||
b_is_vec = len(b.shape) == 1 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
You don't need to check this at runtime. The Solve
Op
has a property b_ndim
, so you can do:
b_is_vec = op.b_ndim
if assume_a == 'tridiagonal':
... # carry on
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The check as written will also fail in the batched case (that's why we have it at the Op level)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
awesome
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
you'd have to do b_is_vec = op.b_ndim == 1
though, no? because bool(op.b_ndim) -> True
for b_ndim > 0
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
yes exactly, my code has an error
Co-authored-by: Jesse Grabowski <48652735+jessegrabowski@users.noreply.github.com>
), | ||
(4, 5, 5), | ||
2, | ||
), |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is what ruff came up with..
Description
In the tridiagonal solve dispatch for jax, dl and du should have same shape as d, and b should have rank2. This PR:
dl
anddu
so they have the same shape asd
b
is a vector -- if so, add a dimension that makes the vector a column vector, compute the result and cast back to vectorRelated Issue
Checklist
Type of change
📚 Documentation preview 📚: https://pytensor--1414.org.readthedocs.build/en/1414/