Skip to content

Commit 453746f

Browse files
committed
Allow inplace ScalarLoop
1 parent 27071c2 commit 453746f

File tree

2 files changed

+52
-35
lines changed

2 files changed

+52
-35
lines changed

pytensor/scalar/loop.py

Lines changed: 24 additions & 15 deletions
Original file line number | Diff line number | Diff line change
@@ -218,11 +218,11 @@ def c_code_template(self):
218218
c: f"%(i{int(i)})s"
219219
for i, c in enumerate(fgraph.inputs[n_update:], start=n_update + 1)
220220
}
221-
update_subd = {
221+
out_subd = {
222222
u: f"%(o{int(i)})s" for i, u in enumerate(fgraph.outputs[:n_update])
223223
}
224224
until_subd = {u: "until" for u in fgraph.outputs[n_update:]}
225-
subd = {**carry_subd, **constant_subd, **update_subd, **until_subd}
225+
subd = {**carry_subd, **constant_subd, **until_subd}
226226

227227
for var in fgraph.variables:
228228
if var.owner is None:
@@ -246,11 +246,11 @@ def c_code_template(self):
246246
_c_code += "bool until = 1;\n\n"
247247

248248
# Copy carried inputs
249-
for i, (var, name) in enumerate(carry_subd.items()):
250-
copy_var_name = f"{name}_copy{i}"
251-
_c_code += f"{var.type.dtype_specs()[1]} {copy_var_name} = {name};\n"
252-
carry_subd[var] = copy_var_name
253-
subd[var] = copy_var_name
249+
for i, (var, name) in enumerate(carry_subd.items(), start=1):
250+
carry_var_name = f"{name}_carry{i}"
251+
_c_code += f"{var.type.dtype_specs()[1]} {carry_var_name} = {name};\n"
252+
carry_subd[var] = carry_var_name
253+
subd[var] = carry_var_name
254254

255255
# _c_code += 'printf("inputs=[");'
256256
# for i in range(1, len(fgraph.inputs)):
@@ -259,9 +259,8 @@ def c_code_template(self):
259259

260260
_c_code += "\nfor(%(n_steps_dtype)s i = 0; i < %(n_steps)s; i++){\n"
261261

262-
self.nodenames = [
263-
f"%(nodename)s_subnode{int(j)}" for j, n in enumerate(fgraph.toposort())
264-
]
262+
# Used by self.c_support_code_apply
263+
self.nodenames = nodenames = []
265264

266265
i = 0
267266
for j, node in enumerate(fgraph.toposort()):
@@ -271,9 +270,13 @@ def c_code_template(self):
271270
name = f"V%(id)s_tmp{int(i)}"
272271
subd[output] = name
273272
_c_code += f"{output.type.dtype_specs()[1]} {name};\n"
273+
274+
nodename = f"%(nodename)s_subnode{int(j)}"
275+
nodenames.append(nodename)
276+
274277
s = node.op.c_code(
275278
node,
276-
self.nodenames[j],
279+
nodename,
277280
# Any node that depended on `init` will depend on `update` instead
278281
# The initial value of `update` was set to `init` before the loop
279282
[subd[input] for input in node.inputs],
@@ -283,10 +286,12 @@ def c_code_template(self):
283286
_c_code += s
284287
_c_code += "\n"
285288

286-
# Set the carry variables to the output variables
289+
# Update the carry variables to the output variables
287290
_c_code += "\n"
288-
for init, update in zip(carry_subd.values(), update_subd.values(), strict=True):
289-
_c_code += f"{init} = {update};\n"
291+
for carry, out in zip(
292+
carry_subd.values(), fgraph.outputs[:n_update], strict=True
293+
):
294+
_c_code += f"{carry} = {subd[out]};\n"
290295

291296
# _c_code += 'printf("%%ld\\n", i);\n'
292297
# for carry in range(1, 10):
@@ -298,6 +303,10 @@ def c_code_template(self):
298303
# End of the loop
299304
_c_code += "}\n"
300305

306+
# Assign the carry variables to the outputs
307+
for out, carry in zip(out_subd.values(), carry_subd.values(), strict=True):
308+
_c_code += f"{out} = {carry};\n"
309+
301310
# Output until flag
302311
if self.is_while:
303312
_c_code += f"%(o{len(fgraph.outputs)-1})s = until;\n"
@@ -332,4 +341,4 @@ def c_code(self, node, nodename, inames, onames, sub):
332341
return res
333342

334343
def c_code_cache_version_outer(self):
335-
return (3,)
344+
return (4,)

tests/scalar/test_loop.py

Lines changed: 28 additions & 20 deletions
Original file line number | Diff line number | Diff line change
@@ -19,7 +19,7 @@
1919
)
2020
from pytensor.scalar.loop import ScalarLoop
2121
from pytensor.tensor import exp as tensor_exp
22-
from pytensor.tensor import vector
22+
from pytensor.tensor import lvector
2323
from pytensor.tensor.elemwise import Elemwise
2424

2525

@@ -260,35 +260,43 @@ def test_inner_loop(mode):
260260
)
261261

262262

263-
def test_elemwise_inplace():
264-
x0 = float64("x0")
265-
y0 = float64("y0")
266-
x = x0 - y0
267-
y = y0 - x0
268-
op = Elemwise(ScalarLoop(init=[x0, y0], constant=[], update=[x, y]))
269-
270-
n_steps = vector("n_steps", dtype="int64")
271-
x0v = vector("x0")
272-
y0v = vector("y0")
273-
xv, yv = op(n_steps, x0v, y0v)
263+
@pytest.mark.parametrize("mutate_arg_idx", (0, 1, 2, 3))
264+
def test_elemwise_inplace(mutate_arg_idx):
265+
x0 = int64("x0")
266+
y0 = int64("y0")
267+
c = int64("c")
268+
x = x0 - y0 + c
269+
y = y0 - x0 + c
270+
op = Elemwise(ScalarLoop(init=[x0, y0], constant=[c], update=[x, y]))
271+
272+
n_steps = lvector("n_steps")
273+
x0v = lvector("x0")
274+
y0v = lvector("y0")
275+
cv = lvector("c")
276+
xv, yv = op(n_steps, x0v, y0v, cv)
277+
278+
inputs = [
279+
In(inp, mutable=i == mutate_arg_idx)
280+
for i, inp in enumerate([n_steps, x0v, y0v, cv])
281+
]
274282

275283
fn = function(
276-
[In(n_steps, mutable=True), In(x0v, mutable=True), In(y0v, mutable=True)],
284+
inputs,
277285
[xv, yv],
278286
mode=get_default_mode().including("inplace"),
279287
)
280288
elem_op = fn.maker.fgraph.outputs[0].owner.op
281289
assert isinstance(elem_op, Elemwise) and isinstance(elem_op.scalar_op, ScalarLoop)
282290
destroy_map = elem_op.destroy_map
283-
assert destroy_map in ({0: [1], 1: [2]}, {0: [2], 1: [2]})
291+
assert destroy_map == {0: [mutate_arg_idx]}
284292

285-
n_test = np.array([1, 4, 8], dtype="int32")
286-
x0v_test = np.array([0, 0, 0], dtype=x0v.dtype)
287-
y0v_test = np.array([1, 1, 1], dtype=y0v.dtype)
293+
n_test = np.array([1, 4, 8], dtype="int64")
294+
x0v_test = np.array([0, 0, 0], dtype="int64")
295+
y0v_test = np.array([1, 1, 1], dtype="int64")
296+
cv_test = np.array([0, 0, 0], dtype="int64")
288297

289-
xv_res, yv_res = fn(n_test, x0v_test, y0v_test)
298+
xv_res, yv_res = fn(n_test, x0v_test, y0v_test, cv_test)
290299
# Check the outputs are the destroyed inputs
291-
assert xv_res is (x0v_test, y0v_test)[destroy_map[0][0] - 1]
292-
assert yv_res is (x0v_test, y0v_test)[destroy_map[1][0] - 1]
300+
assert xv_res is (n_test, x0v_test, y0v_test, cv_test)[mutate_arg_idx]
293301
np.testing.assert_allclose(xv_res, [-1, -8, -128])
294302
np.testing.assert_allclose(yv_res, [1, 8, 128])

0 commit comments

Comments
 (0)