Skip to content

Commit 148477c

Browse files
committed
Respect check_finite in LU decomposition rewrites
1 parent 6fb515d commit 148477c

File tree

3 files changed

+53
-9
lines changed

3 files changed

+53
-9
lines changed

pytensor/tensor/_linalg/solve/rewriting.py

Lines changed: 19 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -14,16 +14,22 @@
1414
from pytensor.tensor.variable import TensorVariable
1515

1616

17-
def decompose_A(A, assume_a):
17+
def decompose_A(A, assume_a, check_finite):
1818
if assume_a == "gen":
19-
return lu_factor(A, check_finite=False)
19+
return lu_factor(A, check_finite=check_finite)
2020
else:
2121
raise NotImplementedError
2222

2323

24-
def solve_lu_decomposed_system(A_decomp, b, b_ndim, assume_a, transposed=False):
25-
if assume_a == "gen":
26-
return lu_solve(A_decomp, b, b_ndim=b_ndim, trans=transposed)
24+
def solve_lu_decomposed_system(A_decomp, b, transposed=False, *, core_solve_op: Solve):
25+
if core_solve_op.assume_a == "gen":
26+
return lu_solve(
27+
A_decomp,
28+
b,
29+
trans=transposed,
30+
b_ndim=core_solve_op.b_ndim,
31+
check_finite=core_solve_op.check_finite,
32+
)
2733
else:
2834
raise NotImplementedError
2935

@@ -102,14 +108,19 @@ def find_solve_clients(var, assume_a):
102108
):
103109
return None
104110

105-
A_decomp = decompose_A(A, assume_a=assume_a)
111+
# If any Op had check_finite=True, we also do it for the LU decomposition
112+
check_finite_decomp = False
113+
for client, _ in A_solve_clients_and_transpose:
114+
if client.op.core_op.check_finite:
115+
check_finite_decomp = True
116+
break
117+
A_decomp = decompose_A(A, assume_a=assume_a, check_finite=check_finite_decomp)
106118

107119
replacements = {}
108120
for client, transposed in A_solve_clients_and_transpose:
109121
_, b = client.inputs
110-
b_ndim = client.op.core_op.b_ndim
111122
new_x = solve_lu_decomposed_system(
112-
A_decomp, b, b_ndim=b_ndim, assume_a=assume_a, transposed=transposed
123+
A_decomp, b, transposed=transposed, core_solve_op=client.op.core_op
113124
)
114125
[old_x] = client.outputs
115126
new_x = atleast_Nd(new_x, n=old_x.type.ndim).astype(old_x.type.dtype)

pytensor/tensor/type.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -793,7 +793,7 @@ def tensor(
793793
try:
794794
# Help catching errors with the new tensor API
795795
# Many single letter strings are valid sctypes
796-
if str(name) == "floatX" or (len(str(name)) > 1 and np.dtype(name).type):
796+
if str(name) == "floatX" or (len(str(name)) > 2 and np.dtype(name).type):
797797
raise ValueError(
798798
f"The first and only positional argument of tensor is now `name`. Got {name}.\n"
799799
"This name looks like a dtype, which you should pass as a keyword argument only."

tests/tensor/linalg/test_rewriting.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -161,3 +161,36 @@ def test_lu_decomposition_reused_scan(transposed):
161161
resx1 = fn_opt(A_test, x0_test)
162162
rtol = 1e-7 if config.floatX == "float64" else 1e-6
163163
np.testing.assert_allclose(resx0, resx1, rtol=rtol)
164+
165+
166+
def test_lu_decomposition_reused_preserves_check_finite():
167+
# Check that the LU decomposition rewrite preserves the check_finite flag
168+
rewrite_name = reuse_lu_decomposition_multiple_solves.__name__
169+
170+
A = tensor("A", shape=(2, 2))
171+
b1 = tensor("b1", shape=(2,))
172+
b2 = tensor("b2", shape=(2,))
173+
174+
x1 = solve(A, b1, assume_a="gen", check_finite=True)
175+
x2 = solve(A, b2, assume_a="gen", check_finite=False)
176+
fn_opt = function(
177+
[A, b1, b2], [x1, x2], mode=get_default_mode().including(rewrite_name)
178+
)
179+
opt_nodes = fn_opt.maker.fgraph.apply_nodes
180+
assert count_vanilla_solve_nodes(opt_nodes) == 0
181+
assert count_lu_decom_nodes(opt_nodes) == 1
182+
assert count_lu_solve_nodes(opt_nodes) == 2
183+
184+
# We should get an error if A or b1 is non finite
185+
A_valid = np.array([[1, 0], [0, 1]], dtype=A.type.dtype)
186+
b1_valid = np.array([1, 1], dtype=b1.type.dtype)
187+
b2_valid = np.array([1, 1], dtype=b2.type.dtype)
188+
189+
assert fn_opt(A_valid, b1_valid, b2_valid) # Fine
190+
assert fn_opt(
191+
A_valid, b1_valid, b2_valid * np.nan
192+
) # Should not raise (also fine on most LAPACK implementations?)
193+
with pytest.raises(ValueError, match="array must not contain infs or NaNs"):
194+
assert fn_opt(A_valid, b1_valid * np.nan, b2_valid)
195+
with pytest.raises(ValueError, match="array must not contain infs or NaNs"):
196+
assert fn_opt(A_valid * np.nan, b1_valid, b2_valid)

0 commit comments

Comments
 (0)