From a95e1e181a835ae7b7bab5f5527d25b16c139118 Mon Sep 17 00:00:00 2001
From: "Ben F. Maier"
Date: Fri, 23 May 2025 00:09:06 +0200
Subject: [PATCH 1/5] fix shape issues in jax tridiagonal solve; close #1413

---
 pytensor/link/jax/dispatch/slinalg.py | 16 +++++++++++++++-
 1 file changed, 15 insertions(+), 1 deletion(-)

diff --git a/pytensor/link/jax/dispatch/slinalg.py b/pytensor/link/jax/dispatch/slinalg.py
index 3d6af00011..f91d03d1ab 100644
--- a/pytensor/link/jax/dispatch/slinalg.py
+++ b/pytensor/link/jax/dispatch/slinalg.py
@@ -54,7 +54,21 @@ def solve(a, b):
             dl = jax.numpy.diagonal(a, offset=-1, axis1=-2, axis2=-1)
             d = jax.numpy.diagonal(a, offset=0, axis1=-2, axis2=-1)
             du = jax.numpy.diagonal(a, offset=1, axis1=-2, axis2=-1)
-            return jax.lax.linalg.tridiagonal_solve(dl, d, du, b, lower=lower)
+            # jax requires dl and du to have the same shape as d
+            dl = jax.numpy.pad(dl, (1, 0))
+            du = jax.numpy.pad(du, (0, 1))
+            # if b is a vector, broadcast it to be a matrix
+            b_is_vec = len(b.shape) == 1
+            if b_is_vec:
+                b = jax.numpy.expand_dims(b, -1)
+
+            res = jax.lax.linalg.tridiagonal_solve(dl, d, du, b)
+
+            if b_is_vec:
+                # if b is a vector, return a vector
+                return res.flatten()
+            else:
+                return res
 
     else:
         if assume_a not in ("gen", "sym", "her", "pos"):

From d0239a3b16d7b07aa1b4297bac83404060f8dee4 Mon Sep 17 00:00:00 2001
From: "Ben F. Maier"
Date: Fri, 23 May 2025 00:28:57 +0200
Subject: [PATCH 2/5] added tests for tridiagonal solve

---
 tests/link/jax/test_slinalg.py | 32 ++++++++++++++++++++++++++++++++
 1 file changed, 32 insertions(+)

diff --git a/tests/link/jax/test_slinalg.py b/tests/link/jax/test_slinalg.py
index ca944221aa..27e985433d 100644
--- a/tests/link/jax/test_slinalg.py
+++ b/tests/link/jax/test_slinalg.py
@@ -122,6 +122,38 @@ def test_jax_solve():
     )
 
 
+def test_jax_tridiagonal_solve():
+    N = 10
+    A = pt.matrix("A", shape=(N, N))
+    b = pt.vector("b", shape=(N,))
+
+    out = pt.linalg.solve(A, b, assume_a="tridiagonal")
+
+    A_val = np.eye(N)
+    for i in range(N - 1):
+        A_val[i, i + 1] = np.random.randn()
+        A_val[i + 1, i] = np.random.randn()
+
+    b_val = np.random.randn(N)
+
+    compare_jax_and_py(
+        [A, b],
+        [out],
+        [A_val, b_val],
+    )
+
+    b_ = pt.matrix("b", shape=(N, 2))
+
+    out = pt.linalg.solve(A, b_, assume_a="tridiagonal")
+    b_val = np.random.randn(N, 2)
+
+    compare_jax_and_py(
+        [A, b_],
+        [out],
+        [A_val, b_val],
+    )
+
+
 def test_jax_SolveTriangular():
     rng = np.random.default_rng(utt.fetch_seed())
 

From a0935507c8f808069259c913f988436548a56c51 Mon Sep 17 00:00:00 2001
From: "Benjamin F. Maier"
Date: Fri, 23 May 2025 15:26:34 +0200
Subject: [PATCH 3/5] Update pytensor/link/jax/dispatch/slinalg.py

Co-authored-by: Jesse Grabowski <48652735+jessegrabowski@users.noreply.github.com>
---
 pytensor/link/jax/dispatch/slinalg.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/pytensor/link/jax/dispatch/slinalg.py b/pytensor/link/jax/dispatch/slinalg.py
index f91d03d1ab..4df28dc4c7 100644
--- a/pytensor/link/jax/dispatch/slinalg.py
+++ b/pytensor/link/jax/dispatch/slinalg.py
@@ -57,7 +57,7 @@ def solve(a, b):
             # jax requires dl and du to have the same shape as d
             dl = jax.numpy.pad(dl, (1, 0))
             du = jax.numpy.pad(du, (0, 1))
-            # if b is a vector, broadcast it to be a matrix
+            # jax also requires b to be a matrix; reshape it to be a column vector if necessary
             b_is_vec = len(b.shape) == 1
             if b_is_vec:
                 b = jax.numpy.expand_dims(b, -1)
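For reference, a minimal standalone sketch (not part of the patch series) of the jax.lax.linalg.tridiagonal_solve calling convention that patches 1-3 adapt to: the three diagonals must all have length n (dl padded with a leading zero, du with a trailing zero), and b must be a matrix. The example matrix and right-hand side below are made up for illustration.

import jax
import jax.numpy as jnp

a = jnp.array([[2.0, 1.0, 0.0],
               [1.0, 2.0, 1.0],
               [0.0, 1.0, 2.0]])
b = jnp.array([1.0, 2.0, 3.0])

dl = jnp.pad(jnp.diagonal(a, offset=-1), (1, 0))  # subdiagonal, padded to length n
d = jnp.diagonal(a, offset=0)                     # main diagonal, length n
du = jnp.pad(jnp.diagonal(a, offset=1), (0, 1))   # superdiagonal, padded to length n
x = jax.lax.linalg.tridiagonal_solve(dl, d, du, b[:, None])  # b reshaped to (n, 1)
print(x[:, 0])  # solution of a @ x = b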
Maier" Date: Fri, 23 May 2025 16:21:22 +0200 Subject: [PATCH 4/5] incorporate changes as asked for --- pytensor/link/jax/dispatch/slinalg.py | 12 ++--- tests/link/jax/test_slinalg.py | 69 +++++++++++++++++++-------- 2 files changed, 55 insertions(+), 26 deletions(-) diff --git a/pytensor/link/jax/dispatch/slinalg.py b/pytensor/link/jax/dispatch/slinalg.py index 4df28dc4c7..855052b124 100644 --- a/pytensor/link/jax/dispatch/slinalg.py +++ b/pytensor/link/jax/dispatch/slinalg.py @@ -46,6 +46,7 @@ def cholesky(a, lower=lower): def jax_funcify_Solve(op, **kwargs): assume_a = op.assume_a lower = op.lower + b_is_vec = op.b_ndim == 1 if assume_a == "tridiagonal": # jax.scipy.solve does not yet support tridiagonal matrices @@ -54,21 +55,20 @@ def solve(a, b): dl = jax.numpy.diagonal(a, offset=-1, axis1=-2, axis2=-1) d = jax.numpy.diagonal(a, offset=0, axis1=-2, axis2=-1) du = jax.numpy.diagonal(a, offset=1, axis1=-2, axis2=-1) + # jax requires dl and du to have the same shape as d dl = jax.numpy.pad(dl, (1, 0)) du = jax.numpy.pad(du, (0, 1)) - # jax also requires b to be a matrix; reshape it to be a column vector if necessary - b_is_vec = len(b.shape) == 1 + if b_is_vec: b = jax.numpy.expand_dims(b, -1) res = jax.lax.linalg.tridiagonal_solve(dl, d, du, b) if b_is_vec: - # if b is a vector, return a vector - return res.flatten() - else: - return res + return jax.numpy.squeeze(res, -1) + + return res else: if assume_a not in ("gen", "sym", "her", "pos"): diff --git a/tests/link/jax/test_slinalg.py b/tests/link/jax/test_slinalg.py index 27e985433d..7f446a6b6d 100644 --- a/tests/link/jax/test_slinalg.py +++ b/tests/link/jax/test_slinalg.py @@ -122,19 +122,59 @@ def test_jax_solve(): ) -def test_jax_tridiagonal_solve(): - N = 10 - A = pt.matrix("A", shape=(N, N)) - b = pt.vector("b", shape=(N,)) +@pytest.mark.parametrize( + "A_size, b_size, b_ndim", + [ + ( + ( + 5, + 5, + ), + (5,), + 1, + ), + ( + ( + 5, + 5, + ), + (5, 1), + 2, + ), + ( + ( + 5, + 5, + ), + (1, 5), + 1, + ), + ( + ( + 4, + 5, + 5, + ), + (4, 5, 5), + 2, + ), + ], + ids=["basic_vector", "basic_matrix", "vector_broadcasted", "fully_batched"], +) +def test_jax_tridiagonal_solve(A_size: tuple, b_size: tuple, b_ndim: int): + A = pt.tensor("A", shape=A_size) + b = pt.tensor("b", shape=b_size) - out = pt.linalg.solve(A, b, assume_a="tridiagonal") + out = pt.linalg.solve(A, b, assume_a="tridiagonal", b_ndim=b_ndim) - A_val = np.eye(N) + A_val = np.zeros(A_size) + N = A_size[-1] + A_val[...] = np.eye(N) for i in range(N - 1): - A_val[i, i + 1] = np.random.randn() - A_val[i + 1, i] = np.random.randn() + A_val[..., i, i + 1] = np.random.randn() + A_val[..., i + 1, i] = np.random.randn() - b_val = np.random.randn(N) + b_val = np.random.randn(*b_size) compare_jax_and_py( [A, b], @@ -142,17 +182,6 @@ def test_jax_tridiagonal_solve(): [A_val, b_val], ) - b_ = pt.matrix("b", shape=(N, 2)) - - out = pt.linalg.solve(A, b_, assume_a="tridiagonal") - b_val = np.random.randn(N, 2) - - compare_jax_and_py( - [A, b_], - [out], - [A_val, b_val], - ) - def test_jax_SolveTriangular(): rng = np.random.default_rng(utt.fetch_seed()) From 4ee85c230494193c961dbd1b7072848191ea2167 Mon Sep 17 00:00:00 2001 From: "Ben F. 
Maier" Date: Fri, 23 May 2025 16:25:18 +0200 Subject: [PATCH 5/5] overwrite ruff --- tests/link/jax/test_slinalg.py | 21 ++++----------------- 1 file changed, 4 insertions(+), 17 deletions(-) diff --git a/tests/link/jax/test_slinalg.py b/tests/link/jax/test_slinalg.py index 7f446a6b6d..b2b722f8ba 100644 --- a/tests/link/jax/test_slinalg.py +++ b/tests/link/jax/test_slinalg.py @@ -126,35 +126,22 @@ def test_jax_solve(): "A_size, b_size, b_ndim", [ ( - ( - 5, - 5, - ), + (5, 5), (5,), 1, ), ( - ( - 5, - 5, - ), + (5, 5), (5, 1), 2, ), ( - ( - 5, - 5, - ), + (5, 5), (1, 5), 1, ), ( - ( - 4, - 5, - 5, - ), + (4, 5, 5), (4, 5, 5), 2, ),