Skip to content

Commit 477fbaf

Browse files
committed
Allow inplace of Elemwise ScalarLoop
1 parent 34b91ef commit 477fbaf

File tree

4 files changed

+81
-58
lines changed

4 files changed

+81
-58
lines changed

pytensor/scalar/basic.py

Lines changed: 2 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1302,19 +1302,7 @@ def __hash__(self):
13021302
def __str__(self):
13031303
if hasattr(self, "name") and self.name:
13041304
return self.name
1305-
else:
1306-
param = [
1307-
(k, v)
1308-
for k, v in self.__dict__.items()
1309-
if k
1310-
not in ("name", "_op_use_c_code", "bool", "output_types_preference")
1311-
]
1312-
if param:
1313-
classname = self.__class__.__name__
1314-
args = ", ".join(f"{k}={v}" for k, v in param)
1315-
return f"{classname}{{{args}}}"
1316-
else:
1317-
return self.__class__.__name__
1305+
return self.__class__.__name__
13181306

13191307
def c_code_cache_version(self):
13201308
return (4,)
@@ -4102,6 +4090,7 @@ class ScalarInnerGraphOp(ScalarOp, HasInnerGraph):
41024090

41034091
def __init__(self, *args, **kwargs):
41044092
self.prepare_node_called = set()
4093+
super().__init__(*args, **kwargs)
41054094

41064095
def _cleanup_graph(self, inputs, outputs):
41074096
# TODO: We could convert to TensorVariable, optimize graph,

pytensor/scalar/loop.py

Lines changed: 30 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,7 @@ def __init__(
5555
constant: Sequence[Variable] | None = None,
5656
until: Variable | None = None,
5757
name="ScalarLoop",
58+
**kwargs,
5859
):
5960
if constant is None:
6061
constant = []
@@ -75,7 +76,7 @@ def __init__(
7576
self.nout = len(self.outputs)
7677
self.name = name
7778

78-
super().__init__()
79+
super().__init__(**kwargs)
7980

8081
def output_types(self, input_types):
8182
return self.outputs_type
@@ -115,7 +116,7 @@ def fgraph(self):
115116
self._fgraph = fgraph
116117
return self._fgraph
117118

118-
def clone(self):
119+
def clone(self, name=None, **kwargs):
119120
if self.is_while:
120121
*update, until = self.outputs
121122
else:
@@ -127,28 +128,16 @@ def clone(self):
127128
update=update,
128129
constant=constant,
129130
until=until,
130-
name=self.name,
131+
name=self.name if name is None else name,
132+
**kwargs,
131133
)
132134

133135
@property
134136
def fn(self):
135137
raise NotImplementedError
136138

137139
def make_new_inplace(self, output_types_preference=None, name=None):
138-
"""
139-
This op.__init__ fct don't have the same parameter as other scalar op.
140-
This break the insert_inplace_optimizer optimization.
141-
This fct allow fix patch this.
142-
143-
"""
144-
d = {k: getattr(self, k) for k in self.init_param}
145-
out = self.__class__(**d)
146-
if name:
147-
out.name = name
148-
else:
149-
name = out.name
150-
super(ScalarLoop, out).__init__(output_types_preference, name)
151-
return out
140+
return self.clone(output_types_preference=output_types_preference, name=name)
152141

153142
def make_node(self, n_steps, *inputs):
154143
assert len(inputs) == self.nin - 1
@@ -229,11 +218,11 @@ def c_code_template(self):
229218
c: f"%(i{int(i)})s"
230219
for i, c in enumerate(fgraph.inputs[n_update:], start=n_update + 1)
231220
}
232-
update_subd = {
221+
out_subd = {
233222
u: f"%(o{int(i)})s" for i, u in enumerate(fgraph.outputs[:n_update])
234223
}
235224
until_subd = {u: "until" for u in fgraph.outputs[n_update:]}
236-
subd = {**carry_subd, **constant_subd, **update_subd, **until_subd}
225+
subd = {**carry_subd, **constant_subd, **until_subd}
237226

238227
for var in fgraph.variables:
239228
if var.owner is None:
@@ -257,11 +246,11 @@ def c_code_template(self):
257246
_c_code += "bool until = 1;\n\n"
258247

259248
# Copy carried inputs
260-
for i, (var, name) in enumerate(carry_subd.items()):
261-
copy_var_name = f"{name}_copy{i}"
262-
_c_code += f"{var.type.dtype_specs()[1]} {copy_var_name} = {name};\n"
263-
carry_subd[var] = copy_var_name
264-
subd[var] = copy_var_name
249+
for i, (var, name) in enumerate(carry_subd.items(), start=1):
250+
carry_var_name = f"{name}_carry{i}"
251+
_c_code += f"{var.type.dtype_specs()[1]} {carry_var_name} = {name};\n"
252+
carry_subd[var] = carry_var_name
253+
subd[var] = carry_var_name
265254

266255
# _c_code += 'printf("inputs=[");'
267256
# for i in range(1, len(fgraph.inputs)):
@@ -270,9 +259,8 @@ def c_code_template(self):
270259

271260
_c_code += "\nfor(%(n_steps_dtype)s i = 0; i < %(n_steps)s; i++){\n"
272261

273-
self.nodenames = [
274-
f"%(nodename)s_subnode{int(j)}" for j, n in enumerate(fgraph.toposort())
275-
]
262+
# Used by self.c_support_code_apply
263+
self.nodenames = nodenames = []
276264

277265
i = 0
278266
for j, node in enumerate(fgraph.toposort()):
@@ -282,9 +270,13 @@ def c_code_template(self):
282270
name = f"V%(id)s_tmp{int(i)}"
283271
subd[output] = name
284272
_c_code += f"{output.type.dtype_specs()[1]} {name};\n"
273+
274+
nodename = f"%(nodename)s_subnode{int(j)}"
275+
nodenames.append(nodename)
276+
285277
s = node.op.c_code(
286278
node,
287-
self.nodenames[j],
279+
nodename,
288280
# Any node that depended on `init` will depend on `update` instead
289281
# The initial value of `update` was set to `init` before the loop
290282
[subd[input] for input in node.inputs],
@@ -294,10 +286,12 @@ def c_code_template(self):
294286
_c_code += s
295287
_c_code += "\n"
296288

297-
# Set the carry variables to the output variables
289+
# Update the carry variables to the output variables
298290
_c_code += "\n"
299-
for init, update in zip(carry_subd.values(), update_subd.values(), strict=True):
300-
_c_code += f"{init} = {update};\n"
291+
for carry, out in zip(
292+
carry_subd.values(), fgraph.outputs[:n_update], strict=True
293+
):
294+
_c_code += f"{carry} = {subd[out]};\n"
301295

302296
# _c_code += 'printf("%%ld\\n", i);\n'
303297
# for carry in range(1, 10):
@@ -309,6 +303,10 @@ def c_code_template(self):
309303
# End of the loop
310304
_c_code += "}\n"
311305

306+
# Assign the carry variables to the outputs
307+
for out, carry in zip(out_subd.values(), carry_subd.values(), strict=True):
308+
_c_code += f"{out} = {carry};\n"
309+
312310
# Output until flag
313311
if self.is_while:
314312
_c_code += f"%(o{len(fgraph.outputs)-1})s = until;\n"
@@ -343,4 +341,4 @@ def c_code(self, node, nodename, inames, onames, sub):
343341
return res
344342

345343
def c_code_cache_version_outer(self):
346-
return (3,)
344+
return (4,)

pytensor/tensor/rewriting/elemwise.py

Lines changed: 2 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,6 @@
2424
)
2525
from pytensor.graph.rewriting.db import SequenceDB
2626
from pytensor.graph.utils import InconsistencyError, MethodNotDefined
27-
from pytensor.scalar.loop import ScalarLoop
2827
from pytensor.scalar.math import Grad2F1Loop, _grad_2f1_loop
2928
from pytensor.tensor.basic import (
3029
MakeVector,
@@ -74,15 +73,6 @@ def print_profile(cls, stream, prof, level=0):
7473
for n in sorted(ndim):
7574
print(blanc, n, ndim[n], file=stream)
7675

77-
def candidate_input_idxs(self, node):
78-
# TODO: Implement specialized InplaceCompositeOptimizer with logic
79-
# needed to correctly assign inplace for multi-output Composites
80-
# and ScalarLoops
81-
if isinstance(node.op.scalar_op, ScalarLoop):
82-
return []
83-
else:
84-
return range(len(node.outputs))
85-
8676
def apply(self, fgraph):
8777
r"""
8878
@@ -173,7 +163,7 @@ def apply(self, fgraph):
173163

174164
baseline = op.inplace_pattern
175165
candidate_outputs = [
176-
i for i in self.candidate_input_idxs(node) if i not in baseline
166+
i for i in range(len(node.outputs)) if i not in baseline
177167
]
178168
# node inputs that are Constant, already destroyed,
179169
# or fgraph protected inputs and fgraph outputs can't be used as
@@ -190,7 +180,7 @@ def apply(self, fgraph):
190180
]
191181
else:
192182
baseline = []
193-
candidate_outputs = self.candidate_input_idxs(node)
183+
candidate_outputs = range(len(node.outputs))
194184
# node inputs that are Constant, already destroyed,
195185
# fgraph protected inputs and fgraph outputs can't be used as inplace
196186
# target.

tests/scalar/test_loop.py

Lines changed: 47 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,8 @@
33
import numpy as np
44
import pytest
55

6-
from pytensor import Mode, function
6+
from pytensor import In, Mode, function
7+
from pytensor.compile import get_default_mode
78
from pytensor.scalar import (
89
Composite,
910
as_scalar,
@@ -18,6 +19,8 @@
1819
)
1920
from pytensor.scalar.loop import ScalarLoop
2021
from pytensor.tensor import exp as tensor_exp
22+
from pytensor.tensor import lvector
23+
from pytensor.tensor.elemwise import Elemwise
2124

2225

2326
mode = pytest.mark.parametrize(
@@ -255,3 +258,46 @@ def test_inner_loop(mode):
255258
out16,
256259
3**2 + 2.5,
257260
)
261+
262+
263+
@pytest.mark.parametrize("mutate_arg_idx", (0, 1, 2, 3))
264+
def test_elemwise_inplace(mutate_arg_idx):
265+
x0 = int64("x0")
266+
y0 = int64("y0")
267+
c = int64("c")
268+
x = x0 - y0 + c
269+
y = y0 - x0 + c
270+
op = Elemwise(ScalarLoop(init=[x0, y0], constant=[c], update=[x, y]))
271+
272+
n_steps = lvector("n_steps")
273+
x0v = lvector("x0")
274+
y0v = lvector("y0")
275+
cv = lvector("c")
276+
xv, yv = op(n_steps, x0v, y0v, cv)
277+
278+
inputs = [
279+
In(inp, mutable=i == mutate_arg_idx)
280+
for i, inp in enumerate([n_steps, x0v, y0v, cv])
281+
]
282+
283+
fn = function(
284+
inputs,
285+
[xv, yv],
286+
mode=get_default_mode().including("inplace"),
287+
)
288+
fn.dprint()
289+
elem_op = fn.maker.fgraph.outputs[0].owner.op
290+
assert isinstance(elem_op, Elemwise) and isinstance(elem_op.scalar_op, ScalarLoop)
291+
destroy_map = elem_op.destroy_map
292+
assert destroy_map == {0: [mutate_arg_idx]}
293+
294+
n_test = np.array([1, 4, 8], dtype="int64")
295+
x0v_test = np.array([0, 0, 0], dtype="int64")
296+
y0v_test = np.array([1, 1, 1], dtype="int64")
297+
cv_test = np.array([0, 0, 0], dtype="int64")
298+
299+
xv_res, yv_res = fn(n_test, x0v_test, y0v_test, cv_test)
300+
# Check the outputs are the destroyed inputs
301+
assert xv_res is (n_test, x0v_test, y0v_test, cv_test)[mutate_arg_idx]
302+
np.testing.assert_allclose(xv_res, [-1, -8, -128])
303+
np.testing.assert_allclose(yv_res, [1, 8, 128])

0 commit comments

Comments (0)