-
Notifications
You must be signed in to change notification settings - Fork 132
Fix lu_solve with batch inputs #1394
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
3098cc7
to
b8a1ede
Compare
b8a1ede
to
de031b7
Compare
Codecov ReportAll modified and coverable lines are covered by tests ✅
Additional details and impacted files@@ Coverage Diff @@
## main #1394 +/- ##
=======================================
Coverage 82.02% 82.02%
=======================================
Files 207 207
Lines 49294 49301 +7
Branches 8746 8747 +1
=======================================
+ Hits 40433 40440 +7
Misses 6695 6695
Partials 2166 2166
🚀 New features to boost your workflow:
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Pull Request Overview
This PR addresses a bug in lu_solve when handling batch inputs and cleans up test function calls. The changes include replacing all instances of “pytensor.function” with “function” in the tests and refactoring lu_solve into a private _lu_solve that is then vectorized via a new public lu_solve.
Reviewed Changes
Copilot reviewed 2 out of 2 changed files in this pull request and generated 2 comments.
File | Description |
---|---|
tests/tensor/test_slinalg.py | Updated function calls and added a new test for lu_solve with batch dimensions. |
pytensor/tensor/slinalg.py | Refactored lu_solve to support vectorized batch operations and updated imports. |
b = pt.tensor("b", shape=(1, 4, 5)) | ||
lu_and_pivots = lu_factor(A) | ||
x = lu_solve(lu_and_pivots, b, b_ndim=1) | ||
assert x.type.shape in {(3, 4, None), (3, 4, 5)} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
[nitpick] The assertion on x.type.shape uses a set containing None, which may be brittle or ambiguous. Consider clarifying the expected shape explicitly to improve test robustness.
assert x.type.shape in {(3, 4, None), (3, 4, 5)} | |
assert x.type.shape == (3, 4, 5) or (x.type.shape[:2] == (3, 4) and x.type.shape[2] is None) |
Copilot uses AI. Check for mistakes.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This was on purpose, right now it returns None, but if we improve static type it should return 5 and won't make the test fail
@@ -508,8 +507,8 @@ def test_infer_shape(self): | |||
A = matrix() | |||
b = matrix() | |||
self._compile_and_check( | |||
[A, b], # pytensor.function inputs | |||
[self.op_class(b_ndim=2)(A, b)], # pytensor.function outputs | |||
[A, b], # function inputs |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
These comments are silly anyway, just use a keyword argument?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
That was a find replace spill over to comments
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The comments are still stupid >:(
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
But I didn't add them :D
Closes #1376
📚 Documentation preview 📚: https://pytensor--1394.org.readthedocs.build/en/1394/