Skip to content

Fix shape issues in jax tridiagonal solve #1414

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
May 24, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 15 additions & 1 deletion pytensor/link/jax/dispatch/slinalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ def cholesky(a, lower=lower):
def jax_funcify_Solve(op, **kwargs):
assume_a = op.assume_a
lower = op.lower
b_is_vec = op.b_ndim == 1

if assume_a == "tridiagonal":
# jax.scipy.solve does not yet support tridiagonal matrices
Expand All @@ -54,7 +55,20 @@ def solve(a, b):
dl = jax.numpy.diagonal(a, offset=-1, axis1=-2, axis2=-1)
d = jax.numpy.diagonal(a, offset=0, axis1=-2, axis2=-1)
du = jax.numpy.diagonal(a, offset=1, axis1=-2, axis2=-1)
return jax.lax.linalg.tridiagonal_solve(dl, d, du, b, lower=lower)

# jax requires dl and du to have the same shape as d
dl = jax.numpy.pad(dl, (1, 0))
du = jax.numpy.pad(du, (0, 1))

if b_is_vec:
b = jax.numpy.expand_dims(b, -1)

res = jax.lax.linalg.tridiagonal_solve(dl, d, du, b)

if b_is_vec:
return jax.numpy.squeeze(res, -1)

return res

else:
if assume_a not in ("gen", "sym", "her", "pos"):
Expand Down
48 changes: 48 additions & 0 deletions tests/link/jax/test_slinalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,54 @@ def test_jax_solve():
)


@pytest.mark.parametrize(
"A_size, b_size, b_ndim",
[
(
(5, 5),
(5,),
1,
),
(
(5, 5),
(5, 1),
2,
),
(
(5, 5),
(1, 5),
1,
),
(
(4, 5, 5),
(4, 5, 5),
2,
),
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is what ruff came up with..

],
ids=["basic_vector", "basic_matrix", "vector_broadcasted", "fully_batched"],
)
def test_jax_tridiagonal_solve(A_size: tuple, b_size: tuple, b_ndim: int):
A = pt.tensor("A", shape=A_size)
b = pt.tensor("b", shape=b_size)

out = pt.linalg.solve(A, b, assume_a="tridiagonal", b_ndim=b_ndim)

A_val = np.zeros(A_size)
N = A_size[-1]
A_val[...] = np.eye(N)
for i in range(N - 1):
A_val[..., i, i + 1] = np.random.randn()
A_val[..., i + 1, i] = np.random.randn()

b_val = np.random.randn(*b_size)

compare_jax_and_py(
[A, b],
[out],
[A_val, b_val],
)


def test_jax_SolveTriangular():
rng = np.random.default_rng(utt.fetch_seed())

Expand Down