Skip to content

Commit 96f753b

Browse files
Remove incorrect solve usage in psd_solve_with_chol rewrite (#575)
* Use `solve_triangular` instead of in `psd_solve_with_chol` * Add unittest for `psd_solve_with_chol` * Specify `mode=FAST_RUN` in test * Relax `test_psd_solve_with_chol` `atol` and `rtol` for half-precision tests
1 parent 65967fe commit 96f753b

File tree

2 files changed

+29
-2
lines changed

2 files changed

+29
-2
lines changed

pytensor/tensor/rewriting/linalg.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -215,8 +215,8 @@ def psd_solve_with_chol(fgraph, node):
215215
# N.B. this can be further reduced to a yet-unwritten cho_solve Op
216216
# __if__ no other Op makes use of the L matrix during the
217217
# stabilization
218-
Li_b = solve(L, b, assume_a="sym", lower=True, b_ndim=2)
219-
x = solve(_T(L), Li_b, assume_a="sym", lower=False, b_ndim=2)
218+
Li_b = solve_triangular(L, b, lower=True, b_ndim=2)
219+
x = solve_triangular(_T(L), Li_b, lower=False, b_ndim=2)
220220
return [x]
221221

222222

tests/tensor/rewriting/test_linalg.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -241,6 +241,33 @@ def test_local_det_chol():
241241
assert not any(isinstance(node, Det) for node in nodes)
242242

243243

244+
def test_psd_solve_with_chol():
245+
X = matrix("X")
246+
X.tag.psd = True
247+
X_inv = pt.linalg.solve(X, pt.identity_like(X))
248+
249+
f = function([X], X_inv, mode="FAST_RUN")
250+
251+
nodes = f.maker.fgraph.apply_nodes
252+
253+
assert not any(isinstance(node.op, Solve) for node in nodes)
254+
assert any(isinstance(node.op, Cholesky) for node in nodes)
255+
assert any(isinstance(node.op, SolveTriangular) for node in nodes)
256+
257+
# Numeric test
258+
rng = np.random.default_rng(sum(map(ord, "test_psd_solve_with_chol")))
259+
260+
L = rng.normal(size=(5, 5)).astype(config.floatX)
261+
X_psd = L @ L.T
262+
X_psd_inv = f(X_psd)
263+
assert_allclose(
264+
X_psd_inv,
265+
np.linalg.inv(X_psd),
266+
atol=1e-4 if config.floatX == "float32" else 1e-8,
267+
rtol=1e-4 if config.floatX == "float32" else 1e-8,
268+
)
269+
270+
244271
class TestBatchedVectorBSolveToMatrixBSolve:
245272
rewrite_name = "batched_vector_b_solve_to_matrix_b_solve"
246273

0 commit comments

Comments
 (0)