Reuse cholesky decomposition with cho_solve in graphs with multiple pt.solve when assume_a = "pos" #1467

Merged
merged 4 commits into from
Jun 13, 2025

Conversation

jessegrabowski
Member

@jessegrabowski jessegrabowski commented Jun 12, 2025

Description

We now have some rewrite machinery that reuses LU decomposition in graphs that repeatedly solve against the same A matrix with different b. This comes up in the gradients of pt.solve (or any graph with pt.solve(A, b), pt.solve(A, c)), as well as scan and Blockwise graphs.
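
The numerical idea behind that reuse can be sketched outside of PyTensor with SciPy's `lu_factor` and `lu_solve`. This is only an illustration of factoring once and solving many times, not the rewrite itself; the matrix and right-hand sides are made-up test data.

```python
import numpy as np
from scipy.linalg import lu_factor, lu_solve

rng = np.random.default_rng(0)
# Made-up general square matrix, shifted along the diagonal to keep it non-singular.
A = rng.normal(size=(4, 4)) + 4 * np.eye(4)
b = rng.normal(size=4)
c = rng.normal(size=4)

# Factor once: this is the O(n^3) step.
lu_and_piv = lu_factor(A)

# Reuse the factorization for each right-hand side: O(n^2) per solve.
x_b = lu_solve(lu_and_piv, b)
x_c = lu_solve(lu_and_piv, c)

# The same factorization also serves the transposed system A.T @ y = b
# (trans=1), so a solve against A.T does not need a second decomposition.
y_b = lu_solve(lu_and_piv, b, trans=1)
```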

The same trick -- rewriting pt.solve(A, b) to pt.lu_solve(pt.lu_factor(A), b) -- can be done with cholesky and cho_solve. This PR extends the rewrites to cover this case. It also renames the rewrites to have more neutral names, since we no longer always LU decompose.
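
A rough SciPy analogue of the Cholesky variant, assuming a symmetric positive definite A (built here from made-up data), looks like this. It shows one `cho_factor` call shared by two `cho_solve` calls and is not the PyTensor rewrite code.

```python
import numpy as np
from scipy.linalg import cho_factor, cho_solve

rng = np.random.default_rng(0)
M = rng.normal(size=(4, 4))
A = M @ M.T + 4 * np.eye(4)  # symmetric positive definite by construction
b1 = rng.normal(size=4)
b2 = rng.normal(size=4)

# Factor once; the result is the pair (factor, lower_flag) that cho_solve expects.
c_and_lower = cho_factor(A, lower=True)

# Both solves share the same Cholesky factor.
x1 = cho_solve(c_and_lower, b1)
x2 = cho_solve(c_and_lower, b2)
```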

Related Issue

Checklist

Type of change

  • New feature / enhancement
  • Bug fix
  • Documentation
  • Maintenance
  • Other (please specify):

📚 Documentation preview 📚: https://pytensor--1467.org.readthedocs.build/en/1467/

codecov bot commented Jun 13, 2025

Codecov Report

Attention: Patch coverage is 88.88889% with 2 lines in your changes missing coverage. Please review.

Project coverage is 82.03%. Comparing base (5f5be92) to head (dd6e0c3).

Files with missing lines Patch % Lines
pytensor/tensor/_linalg/solve/rewriting.py 88.88% 0 Missing and 2 partials ⚠️

❌ Your patch check has failed because the patch coverage (88.88%) is below the target coverage (100.00%). You can increase the patch coverage or adjust the target coverage.

Additional details and impacted files

Impacted file tree graph

@@           Coverage Diff           @@
##             main    #1467   +/-   ##
=======================================
  Coverage   82.03%   82.03%           
=======================================
  Files         214      214           
  Lines       50398    50402    +4     
  Branches     8897     8899    +2     
=======================================
+ Hits        41345    41349    +4     
  Misses       6848     6848           
  Partials     2205     2205           
Files with missing lines Coverage Δ
pytensor/tensor/_linalg/solve/rewriting.py 93.70% <88.88%> (+0.20%) ⬆️

Member

@ricardoV94 ricardoV94 left a comment

Looks good, a bit of a shame that you can't reuse the existing tests and just add an extra parametrization case.

Just one question about the second argument to cholesky solve

@@ -49,11 +52,18 @@ def solve_lu_decomposed_system(A_decomp, b, transposed=False, *, core_solve_op:
b_ndim=b_ndim,
transposed=transposed,
)
elif assume_a == "pos":
Member

What's True in (A_decomp, True)? Whether it's upper or lower? Don't we need to know?

Member Author

Yes, it's the lower flag. I was thinking it doesn't matter because we will be adding in the decomposition ourselves via rewrite, so we control which one is done. I could respect the setting on the solve Op if you think that's better.

Member

That's fine, maybe add a comment for future devs?

Member

And the transposed doesn't matter because it's symmetric?

Member Author

Yeah exactly

Member Author

@jessegrabowski jessegrabowski Jun 13, 2025

Yeah exactly. I brought up that it doesn't have a flag because the nodes should still be merged right? Or do the inputs need to be the same as well? I was thinking cho_solve((A, False), b) and cho_solve((A, True), b) would be the same function (with different inputs ofc)

We could change (A, False) to (A.T, True), but then the inputs still aren't the same. The more I think about it, the more I believe we have to be respectful of the user's flags, in case only one half of A is being stored.

Member

But the user isn't creating chol_factor nor the chol solve in these rewrites so it doesn't matter ever? Unless I'm missing something your first approach was correct.

Member

We probably shouldn't even allow chol_factor, lower=True at the graph level, but always do upper and transpose if the user requested.

It's like the solve transposed, we handle the transpositions symbolically to keep less variations floating around.
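
One reading of this suggestion, sketched with SciPy rather than PyTensor: a request for `lower=True` on A reads the same entries as `lower=False` on `A.T`, so a single canonical form can still honor a matrix that stores only one triangle. The matrix below is made-up data, and `A_lower_only` is a hypothetical name for the half-stored input.

```python
import numpy as np
from scipy.linalg import cho_factor, cho_solve

# Made-up symmetric positive definite matrix (strictly diagonally dominant).
S = np.array([[4.0, 1.0, 0.5],
              [1.0, 3.0, 0.2],
              [0.5, 0.2, 2.0]])
b = np.array([1.0, 2.0, 3.0])

# Hypothetical user input that only stores the lower triangle of S.
A_lower_only = np.tril(S)

# What the user asked for: read the lower triangle of A.
x_requested = cho_solve(cho_factor(A_lower_only, lower=True), b)

# Canonical form: always factor with lower=False, transposing A symbolically.
# The upper triangle of A.T holds the same entries as the lower triangle of A.
x_canonical = cho_solve(cho_factor(A_lower_only.T, lower=False), b)

x_full = np.linalg.solve(S, b)  # reference solution from the full matrix
```

Both routes agree with the reference solution in this sketch, because each factorization only touches the triangle that actually holds data.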

Member Author

solve(A, b, lower=False, assume_a='pos') will only ever look at the upper triangle of A to do the computation. So the user might pass in data that is structured in a special way, taking this into account (for example -- only storing half of the matrix in memory).

When we rewrite, if we choose to always use c_and_lower = (cholesky(A), True), regardless of what was requested, we are assuming that the A matrix is actually symmetrical. That assumption isn't consistent with what LAPACK actually requires, so it could lead to (silent!) incorrect computation.

I don't see any downside to respecting what the user asked for in the rewrite.
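
A small SciPy sketch of the failure mode described above, using made-up data in which only the upper triangle of a symmetric positive definite matrix is stored. It illustrates the LAPACK behavior under discussion, not the PyTensor rewrite.

```python
import numpy as np
from scipy.linalg import cho_factor, cho_solve

# Made-up symmetric positive definite matrix (strictly diagonally dominant).
S = np.array([[4.0, 1.0, 0.5],
              [1.0, 3.0, 0.2],
              [0.5, 0.2, 2.0]])
b = np.array([1.0, 2.0, 3.0])

# Only the upper triangle is stored; the strict lower triangle is zero.
A_half = np.triu(S)

x_true = np.linalg.solve(S, b)

# Respecting the user's flag: only the upper triangle is read, so this is correct.
x_upper = cho_solve(cho_factor(A_half, lower=False), b)

# Ignoring the flag and forcing lower=True: the factorization sees just the
# diagonal plus zeros, which is still positive definite, so no error is raised
# and the result is silently b / diag(S) instead of the true solution.
x_lower = cho_solve(cho_factor(A_half, lower=True), b)
```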

Member

@ricardoV94 ricardoV94 Jun 13, 2025

We can just transpose A in that case. The issue is one of having fewer scenarios to merge. What happens now if the user has a Solve(A, b1, lower=True) and another Solve(A.T, b2, lower=False)?

Are we merging it here correctly? You were ignoring the transpose info coming from the rewrite that's used by the other solves.

That's what should determine the flag, not the original lower. Or the two together. Here we actually have two lowers, which one is used?

Note that if we never represented one of the forms our scenario simplifies.

@jessegrabowski
Member Author

I can re-use the tests by expanding the parametrization to include the count functions. Let me make a version that does that and you can decide if you like it.

@jessegrabowski
Member Author

@ricardoV94 what do you think of this

@ricardoV94
Member

Nice

@ricardoV94 ricardoV94 added the enhancement New feature or request label Jun 13, 2025
@jessegrabowski jessegrabowski merged commit 862c416 into pymc-devs:main Jun 13, 2025
71 checks passed
Labels
enhancement (New feature or request), graph rewriting, linalg (Linear algebra), performance
Development

Successfully merging this pull request may close these issues.

Solve to ChoSolve optimizations
2 participants