Commit a8cc03c
Remove until_condition_failed in ScalarLoop

This was problematic when OpenMP was used in the Elemwise outer loop. We add one extra output flag stating whether the iteration converged or not. This, however, breaks the Hyp2F1 grad in Python mode, because it goes beyond the Elemwise limit on the number of operands. To fix it, we split the grad when in Python mode.
1 parent: df2ffe4
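For context on the Elemwise operand limit mentioned in the message: in Python mode, Elemwise lowers to `numpy.frompyfunc`, whose combined input and output count is capped by NumPy's NPY_MAXARGS (32 at the time). The arithmetic below is our reading of the diff, not part of the commit:

```python
# Operands of the hyp2f1 Grad2F1Loop Elemwise node after this change:
n_steps = 1        # first input
init = 13          # 3 grads + 3 log_gs + 3 log_gs_signs + log_t, log_t_sign, sign_zk, k
constant = 5       # a, b, c, log_z, sign_z
outputs = 13 + 1   # carried states + the new convergence flag

# Before this commit the node had 32 operands, exactly at the cap.
assert n_steps + init + constant + outputs == 33  # now one over it
```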

5 files changed: +145 −128 lines

pytensor/scalar/loop.py

Lines changed: 37 additions & 54 deletions
```diff
@@ -1,8 +1,6 @@
-import warnings
 from copy import copy
 from itertools import chain
-from textwrap import dedent
-from typing import Literal, Optional, Sequence, Tuple
+from typing import Optional, Sequence, Tuple, cast
 
 from pytensor.compile import rebuild_collect_shared
 from pytensor.graph import Constant, FunctionGraph, Variable, clone
@@ -14,7 +12,33 @@ class ScalarLoop(ScalarInnerGraphOp):
     """Scalar Op that encapsulates a scalar loop operation.
 
     This Op can be used for the gradient of other Scalar Ops.
-    It is much more restricted that `Scan` in that the entire inner graph must be composed of Scalar operations.
+    It is much more restricted than `Scan` in that the entire inner graph
+    must be composed of Scalar operations, and all inputs and outputs must be ScalarVariables.
+
+    The pseudocode of the computation performed by this Op looks like the following:
+
+    ```python
+    def scalar_for_loop(fn, n_steps, init, update, constant):
+        for i in range(n_steps):
+            state = fn(*state, *constant)
+        return state
+    ```
+
+    When an until condition is present it behaves like this:
+
+    ```python
+    def scalar_while_loop(fn, n_steps, init, update, constant):
+        # If n_steps <= 0, we skip the loop altogether.
+        # This does not count as a "failure"
+        done = True
+
+        for i in range(n_steps):
+            *state, done = fn(*state, *constant)
+            if done:
+                break
+
+        return *state, done
+    ```
 
     """
 
@@ -23,7 +47,6 @@ class ScalarLoop(ScalarInnerGraphOp):
         "update",
         "constant",
         "until",
-        "until_condition_failed",
     )
 
     def __init__(
@@ -32,14 +55,8 @@ def __init__(
         update: Sequence[Variable],
         constant: Optional[Sequence[Variable]] = None,
         until: Optional[Variable] = None,
-        until_condition_failed: Literal["ignore", "warn", "raise"] = "warn",
         name="ScalarLoop",
     ):
-        if until_condition_failed not in ["ignore", "warn", "raise"]:
-            raise ValueError(
-                f"Invalid until_condition_failed: {until_condition_failed}"
-            )
-
        if constant is None:
            constant = []
        if not len(init) == len(update):
@@ -52,12 +69,13 @@ def __init__(
         self.outputs = copy(outputs)
         self.inputs = copy(inputs)
 
+        self.is_while = bool(until)
         self.inputs_type = tuple(input.type for input in inputs)
         self.outputs_type = tuple(output.type for output in outputs)
+        if self.is_while:
+            self.outputs_type = self.outputs_type + (cast(Variable, until).type,)
         self.nin = len(inputs) + 1  # n_steps is not part of the inner graph
-        self.nout = len(outputs)  # until is not output
-        self.is_while = bool(until)
-        self.until_condition_failed = until_condition_failed
+        self.nout = len(outputs) + (1 if self.is_while else 0)
         self.name = name
         self._validate_fgraph(FunctionGraph(self.inputs, self.outputs, clone=False))
         super().__init__()
@@ -135,7 +153,6 @@ def clone(self):
             update=update,
             constant=constant,
             until=until,
-            until_condition_failed=self.until_condition_failed,
             name=self.name,
         )
 
@@ -191,7 +208,6 @@ def make_node(self, n_steps, *inputs):
             update=cloned_update,
             constant=cloned_constant,
             until=cloned_until,
-            until_condition_failed=self.until_condition_failed,
             name=self.name,
         )
         node = op.make_node(n_steps, *inputs)
@@ -209,17 +225,8 @@ def perform(self, node, inputs, output_storage):
                 *carry, until = inner_fn(*carry, *constant)
                 if until:
                     break
+            carry.append(until)
 
-            if not until:  # no-break
-                if self.until_condition_failed == "raise":
-                    raise RuntimeError(
-                        f"Until condition in ScalarLoop {self.name} not reached!"
-                    )
-                elif self.until_condition_failed == "warn":
-                    warnings.warn(
-                        f"Until condition in ScalarLoop {self.name} not reached!",
-                        RuntimeWarning,
-                    )
         else:
             if n_steps < 0:
                 raise ValueError("ScalarLoop does not have a termination condition.")
@@ -324,27 +331,12 @@ def c_code_template(self):
         if self.is_while:
             _c_code += "\nif(until){break;}\n"
 
+        # End of the loop
         _c_code += "}\n"
 
-        # End of the loop
+        # Output until flag
         if self.is_while:
-            if self.until_condition_failed == "raise":
-                _c_code += dedent(
-                    f"""
-                    if (!until) {{
-                        PyErr_SetString(PyExc_RuntimeError, "Until condition in ScalarLoop {self.name} not reached!");
-                        %(fail)s
-                    }}
-                    """
-                )
-            elif self.until_condition_failed == "warn":
-                _c_code += dedent(
-                    f"""
-                    if (!until) {{
-                        PyErr_WarnEx(PyExc_RuntimeWarning, "Until condition in ScalarLoop {self.name} not reached!", 1);
-                    }}
-                    """
-                )
+            _c_code += f"%(o{len(fgraph.outputs)-1})s = until;\n"
 
         _c_code += "}\n"
 
@@ -376,13 +368,4 @@ def c_code(self, node, nodename, inames, onames, sub):
         return res
 
     def c_code_cache_version_outer(self):
-        return (1,)
-
-    def __eq__(self, other):
-        return (
-            super().__eq__(other)
-            and self.until_condition_failed == other.until_condition_failed
-        )
-
-    def __hash__(self):
-        return hash((super().__hash__(), self.until_condition_failed))
+        return (2,)
```
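To make the new output contract concrete, here is a minimal construction sketch (ours, not part of the commit). It assumes the `pytensor.scalar` API used in the diff, with the comparison spelled via `ps.ge`:

```python
import pytensor.scalar as ps
from pytensor.scalar.loop import ScalarLoop

n_steps = ps.int64("n_steps")
x = ps.float64("x")
x_next = x + 1
done = ps.ge(x_next, 10)  # the until condition

op = ScalarLoop(init=[x], update=[x_next], until=done)
# With an until condition, the Op now appends one extra output: true if the
# loop exited through the condition, false if it exhausted n_steps.
x_final, converged = op(n_steps, x)
```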

pytensor/scalar/math.py

Lines changed: 11 additions & 9 deletions
```diff
@@ -703,7 +703,6 @@ def _make_scalar_loop(n_steps, init, constant, inner_loop_fn, name, loop_op=ScalarLoop):
         constant=constant_,
         update=update_,
         until=until_,
-        until_condition_failed="warn",
         name=name,
     )
     return op(n_steps, *init, *constant)
@@ -747,9 +746,10 @@ def inner_loop_a(sum_a, log_gamma_k_plus_n_plus_1, k_plus_n, log_x):
 
     init = [sum_a0, log_gamma_k_plus_n_plus_1, k_plus_n]
     constant = [log_x]
-    sum_a, *_ = _make_scalar_loop(
+    sum_a, *_, sum_a_converges = _make_scalar_loop(
         max_iters, init, constant, inner_loop_a, name="gammainc_grad_a"
     )
+    sum_a = switch(sum_a_converges, sum_a, np.nan)
 
     # Second loop
     n = np.array(0, dtype="int32")
@@ -772,9 +772,10 @@ def inner_loop_b(sum_b, log_gamma_k_plus_n_plus_1, n, k_plus_n, log_x):
 
     init = [sum_b0, log_gamma_k_plus_n_plus_1, n, k_plus_n]
     constant = [log_x]
-    sum_b, *_ = _make_scalar_loop(
+    sum_b, *_, sum_b_converges = _make_scalar_loop(
         max_iters, init, constant, inner_loop_b, name="gammainc_grad_b"
     )
+    sum_b = switch(sum_b_converges, sum_b, np.nan)
 
     grad_approx = exp(-x) * (log_x * sum_a - sum_b)
     return grad_approx
@@ -877,9 +878,10 @@ def inner_loop_b(sum_b, log_s, s_sign, log_delta, n, k, log_x):
 
     init = [sum_b0, log_s, s_sign, log_delta, n]
     constant = [k, log_x]
-    sum_b, *_ = _make_scalar_loop(
+    sum_b, *_, sum_b_converges = _make_scalar_loop(
         max_iters, init, constant, inner_loop_b, name="gammaincc_grad_b"
     )
+    sum_b = switch(sum_b_converges, sum_b, np.nan)
     grad_approx_b = (
         gammainc(k, x) * (digamma_k - log_x) + exp(k * log_x) * sum_b / gamma_k
     )
@@ -1547,10 +1549,10 @@ def inner_loop(
 
         init = [derivative, Am2, Am1, Bm2, Bm1, dAm2, dAm1, dBm2, dBm1, n]
         constant = [f, p, q, K, dK]
-        grad, *_ = _make_scalar_loop(
+        grad, *_, grad_converges = _make_scalar_loop(
             max_iters, init, constant, inner_loop, name="betainc_grad"
         )
-        return grad
+        return switch(grad_converges, grad, np.nan)
 
     # Input validation
     nan_branch = (x < 0) | (x > 1) | (p < 0) | (q < 0)
@@ -1752,10 +1754,10 @@ def inner_loop(*args):
 
     init = [*grads, *log_gs, *log_gs_signs, log_t, log_t_sign, sign_zk, k]
     constant = [a, b, c, log_z, sign_z]
-    loop_outs = _make_scalar_loop(
+    *loop_outs, converges = _make_scalar_loop(
         max_steps, init, constant, inner_loop, name="hyp2f1_grad", loop_op=Grad2F1Loop
     )
-    return loop_outs[: len(wrt)]
+    return *loop_outs[: len(wrt)], converges
 
 
 def hyp2f1_grad(a, b, c, z, wrt: Tuple[int, ...]):
@@ -1792,7 +1794,7 @@ def is_nonpositive_integer(x):
     # We have to pass the converges flag to interrupt the loop, as the switch is not lazy
     z_is_zero = eq(z, 0)
     converges = check_2f1_converges(a, b, c, z)
-    grads = _grad_2f1_loop(
+    *grads, grad_converges = _grad_2f1_loop(
         a, b, c, z, skip_loop=z_is_zero | (~converges), wrt=wrt, dtype=dtype
     )
 
```
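The recurring `switch(..._converges, value, np.nan)` pattern above is the user-facing replacement for the removed warn/raise behavior: non-convergence now flows through the graph as NaN instead of a side effect, which stays safe when OpenMP parallelizes the outer Elemwise loop. A condensed sketch of the idiom (ours; `switch` here is pytensor's scalar select op):

```python
import numpy as np
from pytensor.scalar import switch

def guard_convergence(value, converged):
    # Keep the loop result where it converged; poison it with NaN where the
    # loop ran out of iterations, instead of warning or raising mid-loop.
    return switch(converged, value, np.nan)
```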

pytensor/tensor/rewriting/elemwise.py

Lines changed: 75 additions & 24 deletions
```diff
@@ -1219,6 +1219,30 @@ def local_careduce_fusion(fgraph, node):
     )
 
 
+def _rebuild_partial_2f1grad_loop(node, wrt):
+    a, b, c, log_z, sign_z = node.inputs[-5:]
+    z = exp(log_z) * sign_z
+
+    # Reconstruct scalar loop with relevant outputs
+    a_, b_, c_, z_ = (x.type.to_scalar_type()() for x in (a, b, c, z))
+    new_loop_op = _grad_2f1_loop(
+        a_, b_, c_, z_, skip_loop=False, wrt=wrt, dtype=a_.type.dtype
+    )[0].owner.op
+
+    # Reconstruct elemwise loop
+    new_elemwise_op = Elemwise(scalar_op=new_loop_op)
+    n_steps = node.inputs[0]
+    init_grad_vars = node.inputs[1:10]
+    other_inputs = node.inputs[10:]
+
+    init_grads = init_grad_vars[: len(wrt)]
+    init_gs = init_grad_vars[3 : 3 + len(wrt)]
+    init_gs_signs = init_grad_vars[6 : 6 + len(wrt)]
+    subset_init_grad_vars = init_grads + init_gs + init_gs_signs
+
+    return new_elemwise_op(n_steps, *subset_init_grad_vars, *other_inputs)
+
+
 @register_specialize
 @node_rewriter([Elemwise])
 def local_useless_2f1grad_loop(fgraph, node):
@@ -1240,38 +1264,65 @@ def local_useless_2f1grad_loop(fgraph, node):
     if sum(grad_var_is_used) == 3:
         return None
 
-    # Check that None of the remaining vars is used anywhere
-    if any(bool(fgraph.clients.get(v)) for v in node.outputs[3:]):
-        return None
+    *other_vars, converges = node.outputs[3:]
 
-    a, b, c, log_z, sign_z = node.inputs[-5:]
-    z = exp(log_z) * sign_z
+    # Check that None of the remaining vars (except the converge flag) is used anywhere
+    if any(bool(fgraph.clients.get(v)) for v in other_vars):
+        return None
 
-    # Reconstruct scalar loop with relevant outputs
-    a_, b_, c_, z_ = (x.type.to_scalar_type()() for x in (a, b, c, z))
     wrt = [i for i, used in enumerate(grad_var_is_used) if used]
-    new_loop_op = _grad_2f1_loop(
-        a_, b_, c_, z_, skip_loop=False, wrt=wrt, dtype=a_.type.dtype
-    )[0].owner.op
+    *new_outs, new_converges = _rebuild_partial_2f1grad_loop(node, wrt=wrt)
 
-    # Reconstruct elemwise loop
-    new_elemwise_op = Elemwise(scalar_op=new_loop_op)
-    n_steps = node.inputs[0]
-    init_grad_vars = node.inputs[1:10]
-    other_inputs = node.inputs[10:]
-
-    init_grads = init_grad_vars[: len(wrt)]
-    init_gs = init_grad_vars[3 : 3 + len(wrt)]
-    init_gs_signs = init_grad_vars[6 : 6 + len(wrt)]
-    subset_init_grad_vars = init_grads + init_gs + init_gs_signs
-
-    new_outs = new_elemwise_op(n_steps, *subset_init_grad_vars, *other_inputs)
-
-    replacements = {}
+    replacements = {converges: new_converges}
     i = 0
     for grad_var, is_used in zip(grad_vars, grad_var_is_used):
         if not is_used:
             continue
         replacements[grad_var] = new_outs[i]
         i += 1
     return replacements
+
+
+@node_rewriter([Elemwise])
+def split_2f1grad_loop(fgraph, node):
+    """
+    2f1grad loop has too many operands for Numpy frompyfunc code used by Elemwise nodes on python mode.
+
+    This rewrite splits it across 3 different operations. It is not needed if `local_useless_2f1grad_loop` was applied
+    """
+    loop_op = node.op.scalar_op
+
+    if not isinstance(loop_op, Grad2F1Loop):
+        return None
+
+    grad_related_vars = node.outputs[:-4]
+    # local_useless_2f1grad_loop was used, we should be safe
+    if len(grad_related_vars) // 3 != 3:
+        return None
+
+    grad_vars = grad_related_vars[:3]
+    *other_vars, converges = node.outputs[3:]
+
+    # Check that None of the remaining vars is used anywhere
+    if any(bool(fgraph.clients.get(v)) for v in other_vars):
+        return None
+
+    new_grad0, new_grad1, *_, new_converges01 = _rebuild_partial_2f1grad_loop(
+        node, wrt=[0, 1]
+    )
+    new_grad2, *_, new_converges2 = _rebuild_partial_2f1grad_loop(node, wrt=[2])
+
+    replacements = {
+        converges: new_converges01 & new_converges2,
+        grad_vars[0]: new_grad0,
+        grad_vars[1]: new_grad1,
+        grad_vars[2]: new_grad2,
+    }
+    return replacements
+
+
+compile.optdb["py_only"].register(  # type: ignore
+    "split_2f1grad_loop",
+    split_2f1grad_loop,
+    "fast_compile",
+)
```
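Why Python mode needs the split: Elemwise's Python-mode kernel is built with `numpy.frompyfunc`, and NumPy rejects ufuncs whose inputs plus outputs exceed NPY_MAXARGS (32 in NumPy releases contemporary with this commit; raised in later versions). A standalone sketch of the failure mode (ours, not from the commit):

```python
import numpy as np

# 11 operands (10 inputs + 1 output): constructs fine.
ok = np.frompyfunc(lambda *args: sum(args), 10, 1)

# 70 operands: rejected at construction time, before the callable ever runs.
try:
    np.frompyfunc(lambda *args: args, 60, 10)
except ValueError as err:
    print(err)  # "cannot construct ufunc with more than ... operands"
```

Splitting the gradient across two smaller loop Ops (`wrt=[0, 1]` and `wrt=[2]`) keeps each Elemwise node under the cap.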
