Skip to content

Commit 2ba8937

Browse files
benmaierjessegrabowski
authored andcommitted
incorporate changes as asked for
1 parent f146af6 commit 2ba8937

File tree

2 files changed

+55
-26
lines changed

2 files changed

+55
-26
lines changed

pytensor/link/jax/dispatch/slinalg.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@ def cholesky(a, lower=lower):
4646
def jax_funcify_Solve(op, **kwargs):
4747
assume_a = op.assume_a
4848
lower = op.lower
49+
b_is_vec = op.b_ndim == 1
4950

5051
if assume_a == "tridiagonal":
5152
# jax.scipy.solve does not yet support tridiagonal matrices
@@ -54,21 +55,20 @@ def solve(a, b):
5455
dl = jax.numpy.diagonal(a, offset=-1, axis1=-2, axis2=-1)
5556
d = jax.numpy.diagonal(a, offset=0, axis1=-2, axis2=-1)
5657
du = jax.numpy.diagonal(a, offset=1, axis1=-2, axis2=-1)
58+
5759
# jax requires dl and du to have the same shape as d
5860
dl = jax.numpy.pad(dl, (1, 0))
5961
du = jax.numpy.pad(du, (0, 1))
60-
# jax also requires b to be a matrix; reshape it to be a column vector if necessary
61-
b_is_vec = len(b.shape) == 1
62+
6263
if b_is_vec:
6364
b = jax.numpy.expand_dims(b, -1)
6465

6566
res = jax.lax.linalg.tridiagonal_solve(dl, d, du, b)
6667

6768
if b_is_vec:
68-
# if b is a vector, return a vector
69-
return res.flatten()
70-
else:
71-
return res
69+
return jax.numpy.squeeze(res, -1)
70+
71+
return res
7272

7373
else:
7474
if assume_a not in ("gen", "sym", "her", "pos"):

tests/link/jax/test_slinalg.py

Lines changed: 49 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -122,37 +122,66 @@ def test_jax_solve():
122122
)
123123

124124

125-
def test_jax_tridiagonal_solve():
126-
N = 10
127-
A = pt.matrix("A", shape=(N, N))
128-
b = pt.vector("b", shape=(N,))
125+
@pytest.mark.parametrize(
126+
"A_size, b_size, b_ndim",
127+
[
128+
(
129+
(
130+
5,
131+
5,
132+
),
133+
(5,),
134+
1,
135+
),
136+
(
137+
(
138+
5,
139+
5,
140+
),
141+
(5, 1),
142+
2,
143+
),
144+
(
145+
(
146+
5,
147+
5,
148+
),
149+
(1, 5),
150+
1,
151+
),
152+
(
153+
(
154+
4,
155+
5,
156+
5,
157+
),
158+
(4, 5, 5),
159+
2,
160+
),
161+
],
162+
ids=["basic_vector", "basic_matrix", "vector_broadcasted", "fully_batched"],
163+
)
164+
def test_jax_tridiagonal_solve(A_size: tuple, b_size: tuple, b_ndim: int):
165+
A = pt.tensor("A", shape=A_size)
166+
b = pt.tensor("b", shape=b_size)
129167

130-
out = pt.linalg.solve(A, b, assume_a="tridiagonal")
168+
out = pt.linalg.solve(A, b, assume_a="tridiagonal", b_ndim=b_ndim)
131169

132-
A_val = np.eye(N)
170+
A_val = np.zeros(A_size)
171+
N = A_size[-1]
172+
A_val[...] = np.eye(N)
133173
for i in range(N - 1):
134-
A_val[i, i + 1] = np.random.randn()
135-
A_val[i + 1, i] = np.random.randn()
174+
A_val[..., i, i + 1] = np.random.randn()
175+
A_val[..., i + 1, i] = np.random.randn()
136176

137-
b_val = np.random.randn(N)
177+
b_val = np.random.randn(*b_size)
138178

139179
compare_jax_and_py(
140180
[A, b],
141181
[out],
142182
[A_val, b_val],
143183
)
144184

145-
b_ = pt.matrix("b", shape=(N, 2))
146-
147-
out = pt.linalg.solve(A, b_, assume_a="tridiagonal")
148-
b_val = np.random.randn(N, 2)
149-
150-
compare_jax_and_py(
151-
[A, b_],
152-
[out],
153-
[A_val, b_val],
154-
)
155-
156185

157186
def test_jax_SolveTriangular():
158187
rng = np.random.default_rng(utt.fetch_seed())

0 commit comments

Comments
 (0)