Reuse cholesky decomposition with cho_solve in graphs with multiple pt.solve when assume_a = "pos" #1467
Codecov Report
❌ Your patch check has failed because the patch coverage (88.88%) is below the target coverage (100.00%). You can increase the patch coverage or adjust the target coverage.
Additional details and impacted files
@@           Coverage Diff           @@
##             main    #1467   +/-   ##
=======================================
  Coverage   82.03%   82.03%
=======================================
  Files         214      214
  Lines       50398    50402    +4
  Branches     8897     8899    +2
=======================================
+ Hits        41345    41349    +4
  Misses       6848     6848
  Partials     2205     2205
Looks good. It's a bit of a shame that you can't reuse the existing tests and just add an extra parametrization case.
Just one question about the second argument to the Cholesky solve.
@@ -49,11 +52,18 @@ def solve_lu_decomposed_system(A_decomp, b, transposed=False, *, core_solve_op:
b_ndim=b_ndim,
transposed=transposed,
)
elif assume_a == "pos":
What's True in (A_decomp, True)? Whether it's upper or lower? Don't we need to know?
Yes, it's the lower flag. I was thinking it doesn't matter because we will be adding in the decomposition ourselves via rewrite, so we control which one is done. I could respect the setting on the solve Op if you think that's better.
That's fine, maybe add a comment for future devs?
And the transposed doesn't matter because it's symmetric?
Yeah exactly
Yeah exactly. I brought up that it doesn't have a flag because the nodes should still be merged right? Or do the inputs need to be the same as well? I was thinking cho_solve((A, False), b) and cho_solve((A, True), b) would be the same function (with different inputs ofc)
We could change (A, False) to (A.T, True), but then the inputs still aren't the same. The more I think about it, the more I believe we have to be respectful of the user's flags, in case only one half of A is being stored.
But the user isn't creating the chol_factor or the chol solve in these rewrites, so it doesn't ever matter? Unless I'm missing something, your first approach was correct.
We probably shouldn't even allow chol_factor with lower=True at the graph level, but always do upper and transpose if the user requested it.
It's like the solve transposed: we handle the transpositions symbolically to keep fewer variations floating around.
solve(A, b, lower=False, assume_a='pos') will only ever look at the upper triangle of A to do the computation. So the user might pass in data that is structured in a special way, taking this into account (for example, only storing half of the matrix in memory).
When we rewrite, if we choose to always use c_and_lower = (cholesky(A), True), regardless of what was requested, we are assuming that the A matrix is actually symmetric. That assumption isn't consistent with what LAPACK actually requires, so it could lead to (silent!) incorrect computation.
I don't see any downside to respecting what the user asked for in the rewrite.
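To make the concern above concrete, here is a minimal sketch using SciPy's cho_factor and cho_solve as a stand-in for the PyTensor Ops (an assumption: the variable names and the SciPy analogy are illustrative, not part of the PR). It stores only the upper triangle of a positive definite matrix and compares factoring with the matching flag against factoring with the opposite flag.

```python
import numpy as np
from scipy.linalg import cho_factor, cho_solve

rng = np.random.default_rng(0)
M = rng.normal(size=(4, 4))
A_full = M @ M.T + 4 * np.eye(4)  # symmetric positive definite reference
A_upper = np.triu(A_full)         # hypothetical user data: only the upper triangle is stored
b = rng.normal(size=4)

x_ref = np.linalg.solve(A_full, b)

# Respecting the flag: only the upper triangle is read, so the result is correct.
x_upper = cho_solve(cho_factor(A_upper, lower=False), b)

# Ignoring the flag: the lower triangle (zeros plus the diagonal) is read instead.
# The factorization still succeeds, so the wrong answer is produced silently.
x_lower = cho_solve(cho_factor(A_upper, lower=True), b)

matches_with_flag = np.allclose(x_upper, x_ref)
matches_without_flag = np.allclose(x_lower, x_ref)
print(matches_with_flag, matches_without_flag)
```

With the flag ignored, the factorization effectively sees only the diagonal of A, which is still positive definite, so no error is raised.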
We can just transpose A in that case. The issue is one of merging fewer scenarios. What happens now if the user has a Solve(A, b1, lower=True) and another Solve(A.T, b2, lower=False)?
Are we merging it here correctly? You were ignoring the transpose info coming from the rewrite that's used by the other solves.
That's what should determine the flag, not the original lower. Or the two together. Here we actually have two lowers; which one is used?
Note that if we never represented one of the forms, our scenario simplifies.
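The two-solve scenario in this comment can be sketched with SciPy (again an assumption: SciPy stands in for the PyTensor Ops, and the names are illustrative). Factoring A with lower=True and factoring A.T with lower=False read the same stored triangle and give factors that are transposes of each other, so a single factorization could serve both solves.

```python
import numpy as np
from scipy.linalg import cho_factor, cho_solve

rng = np.random.default_rng(1)
M = rng.normal(size=(3, 3))
A_full = M @ M.T + 3 * np.eye(3)  # symmetric positive definite reference
A = np.tril(A_full)               # hypothetical storage: lower triangle only
b1 = rng.normal(size=3)
b2 = rng.normal(size=3)

# Solve(A, b1, lower=True) would factor the lower triangle of A ...
c_lower, _ = cho_factor(A, lower=True)
# ... and Solve(A.T, b2, lower=False) would factor the upper triangle of A.T,
# which holds exactly the same numbers.
c_upper, _ = cho_factor(A.T, lower=False)

# The two factors are transposes of each other (L and L.T).
same_factor = np.allclose(np.tril(c_lower), np.triu(c_upper).T)

# So one factorization can be reused for both right-hand sides.
x1 = cho_solve((c_lower, True), b1)
x2 = cho_solve((c_lower, True), b2)
print(same_factor)
```

Whether the rewrite should canonicalize to one form or carry both flags is the open design question in the thread; the sketch only shows that the two forms are numerically interchangeable.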
I can re-use the tests by expanding the parametrization to include the count functions. Let me make a version that does that and you can decide if you like it.
@ricardoV94 what do you think of this
Nice
Description
We now have some rewrite machinery that reuses LU decomposition in graphs that repeatedly solve against the same A matrix with different b. This comes up in the gradients of pt.solve (or any graph with pt.solve(A, b), pt.solve(A, c)), as well as scan and Blockwise graphs.
The same trick -- rewriting pt.solve(A, b) to pt.lu_solve(pt.lu_factor(A), b) -- can be done with cholesky and cho_solve. This PR extends the rewrites to cover this case. It also renames the rewrites to have more neutral names, since we no longer always LU decompose.
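As a rough numerical illustration of the idea in the description, the sketch below uses SciPy rather than PyTensor (an assumption: this is the eager analog of what the rewrite does symbolically, not the rewrite itself). Two independent positive definite solves against the same A are replaced by one Cholesky factorization that is reused for both right-hand sides.

```python
import numpy as np
from scipy.linalg import cho_factor, cho_solve, solve

rng = np.random.default_rng(42)
M = rng.normal(size=(5, 5))
A = M @ M.T + 5 * np.eye(5)  # symmetric positive definite
b = rng.normal(size=5)
c = rng.normal(size=(5, 2))

# Before: each solve factors A from scratch.
x_b = solve(A, b, assume_a="pos")
x_c = solve(A, c, assume_a="pos")

# After: factor once, then reuse the factor for every right-hand side.
factor = cho_factor(A, lower=True)
y_b = cho_solve(factor, b)
y_c = cho_solve(factor, c)

print(np.allclose(x_b, y_b), np.allclose(x_c, y_c))
```

The factorization costs O(n^3) while each triangular solve costs O(n^2) per right-hand side, which is why sharing the factor pays off when many solves use the same A.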
📚 Documentation preview 📚: https://pytensor--1467.org.readthedocs.build/en/1467/