Skip to content

Fix shape issues in jax tridiagonal solve #1414

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 5 commits into
base: main
Choose a base branch
from

Conversation

benmaier
Copy link

@benmaier benmaier commented May 22, 2025

Description

In the tridiagonal solve dispatch for jax, dl and du should have same shape as d, and b should have rank2. This PR:

  • padded dl and du so they have the same shape as d
  • check whether b is a vector -- if so, add a dimension that makes the vector a column vector, compute the result and cast back to vector

Related Issue

Checklist

Type of change

  • New feature / enhancement
  • Bug fix
  • Documentation
  • Maintenance
  • Other (please specify):

📚 Documentation preview 📚: https://pytensor--1414.org.readthedocs.build/en/1414/

@jessegrabowski jessegrabowski changed the title Fix shape issues in jax tridiagonal solve (dl and du should have same shape as d and b should have rank2); close #1413 Fix shape issues in jax tridiagonal solve May 23, 2025
Copy link

codecov bot commented May 23, 2025

Codecov Report

All modified and coverable lines are covered by tests ✅

Project coverage is 82.14%. Comparing base (4829455) to head (d0239a3).
Report is 8 commits behind head on main.

Additional details and impacted files

Impacted file tree graph

@@            Coverage Diff             @@
##             main    #1414      +/-   ##
==========================================
+ Coverage   82.12%   82.14%   +0.01%     
==========================================
  Files         211      211              
  Lines       49687    49695       +8     
  Branches     8813     8815       +2     
==========================================
+ Hits        40807    40821      +14     
+ Misses       6702     6697       -5     
+ Partials     2178     2177       -1     
Files with missing lines Coverage Δ
pytensor/link/jax/dispatch/slinalg.py 93.97% <100.00%> (+8.64%) ⬆️
🚀 New features to boost your workflow:
  • ❄️ Test Analytics: Detect flaky tests, report on failures, and find test suite problems.

dl = jax.numpy.pad(dl, (1, 0))
du = jax.numpy.pad(du, (0, 1))
# if b is a vector, broadcast it to be a matrix
b_is_vec = len(b.shape) == 1
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You don't need to check this at runtime. The Solve Op has a property b_ndim, so you can do:

b_is_vec = op.b_ndim

if assume_a == 'tridiagonal':
    ... # carry on

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The check as written will also fail in the batched case (that's why we have it at the Op level)

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

awesome

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

you'd have to do b_is_vec = op.b_ndim == 1 though, no? because bool(op.b_ndim) -> True for b_ndim > 0

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yes exactly, my code has an error

benmaier and others added 2 commits May 23, 2025 15:26
Co-authored-by: Jesse Grabowski <48652735+jessegrabowski@users.noreply.github.com>
),
(4, 5, 5),
2,
),
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is what ruff came up with..

@benmaier benmaier requested a review from jessegrabowski May 23, 2025 14:27
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

BUG: Error in shape for JAX tridiagonal solve inputs
3 participants