Skip to content

Commit 34b91ef

Browse files
committed
Allow inplace of Elemwise Composite with multiple outputs
1 parent 0699b48 commit 34b91ef

File tree

3 files changed

+24
-25
lines changed

3 files changed

+24
-25
lines changed

pytensor/scalar/basic.py

Lines changed: 17 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -4441,16 +4441,12 @@ def c_code_template(self):
44414441
if hasattr(self, "_c_code"):
44424442
return self._c_code
44434443

4444-
subd = dict(
4445-
chain(
4446-
((e, f"%(i{int(i)})s") for i, e in enumerate(self.fgraph.inputs)),
4447-
((e, f"%(o{int(i)})s") for i, e in enumerate(self.fgraph.outputs)),
4448-
)
4449-
)
4444+
fg = self.fgraph
4445+
subd = {e: f"%(i{int(i)})s" for i, e in enumerate(fg.inputs)}
44504446

4451-
for var in self.fgraph.variables:
4447+
for var in fg.variables:
44524448
if var.owner is None:
4453-
if var not in self.fgraph.inputs:
4449+
if var not in fg.inputs:
44544450
# This is an orphan
44554451
if isinstance(var, Constant) and isinstance(var.type, CLinkerType):
44564452
subd[var] = f"({var.type.c_literal(var.data)})"
@@ -4465,30 +4461,35 @@ def c_code_template(self):
44654461
# flag for elemwise ops to check.
44664462
self.inner_float16 = True
44674463

4468-
_c_code = "{\n"
4469-
self.nodenames = [
4470-
f"%(nodename)s_subnode{int(j)}"
4471-
for j, n in enumerate(self.fgraph.toposort())
4472-
]
4464+
self.nodenames = nodenames = [] # Used by self.c_support_code_apply
44734465

4466+
_c_code = "{\n"
44744467
i = 0
4475-
for j, node in enumerate(self.fgraph.toposort()):
4468+
for j, node in enumerate(fg.toposort()):
44764469
for output in node.outputs:
44774470
if output not in subd:
44784471
i += 1
44794472
name = f"V%(id)s_tmp{int(i)}"
44804473
subd[output] = name
44814474
_c_code += f"{output.type.dtype_specs()[1]} {name};\n"
4475+
4476+
nodename = f"%(nodename)s_subnode{int(j)}"
4477+
nodenames.append(nodename)
4478+
44824479
s = node.op.c_code(
44834480
node,
4484-
self.nodenames[j],
4481+
nodename,
44854482
[subd[input] for input in node.inputs],
44864483
[subd[output] for output in node.outputs],
44874484
dict(fail="%(fail)s", id=f"%(id)s_{int(j)}"),
44884485
)
44894486
_c_code += s
44904487
_c_code += "\n"
44914488

4489+
# Copy the temporary outputs to the real outputs
4490+
for i, output in enumerate(fg.outputs):
4491+
_c_code += f"%(o{int(i)})s = {subd[output]};\n"
4492+
44924493
_c_code += "}\n"
44934494

44944495
self._c_code = _c_code
@@ -4512,7 +4513,7 @@ def c_code(self, node, nodename, inames, onames, sub):
45124513
return self.c_code_template % d
45134514

45144515
def c_code_cache_version_outer(self) -> tuple[int, ...]:
4515-
return (5,)
4516+
return (6,)
45164517

45174518

45184519
class Compositef32:

pytensor/tensor/rewriting/elemwise.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -80,8 +80,6 @@ def candidate_input_idxs(self, node):
8080
# and ScalarLoops
8181
if isinstance(node.op.scalar_op, ScalarLoop):
8282
return []
83-
if isinstance(node.op.scalar_op, ps.Composite) and (len(node.outputs) > 1):
84-
return []
8583
else:
8684
return range(len(node.outputs))
8785

tests/tensor/rewriting/test_elemwise.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1104,7 +1104,8 @@ def test_add_mul_fusion_inplace(self):
11041104
np.random.random((5, 5)), np.random.random((5, 5)), np.random.random((5, 5))
11051105
)
11061106

1107-
def test_fusion_multiout_inplace(self):
1107+
@pytest.mark.parametrize("linker", ["cvm", "py"])
1108+
def test_fusion_multiout_inplace(self, linker):
11081109
x = vector("x")
11091110

11101111
# Create Composite where inplacing the first non-constant output would corrupt the second output
@@ -1118,17 +1119,16 @@ def test_fusion_multiout_inplace(self):
11181119
f = pytensor.function(
11191120
[In(x, mutable=True)],
11201121
outs,
1121-
mode=self.mode.including("inplace"),
1122+
mode=Mode(linker=linker, optimizer=self.rewrites.including("inplace")),
11221123
)
11231124
(composite_node,) = f.maker.fgraph.apply_nodes
11241125

1125-
# Destroy map must be None or the last toposorted output
11261126
destroy_map = composite_node.op.destroy_map
1127-
assert (destroy_map == {}) or (
1128-
destroy_map == {1: [composite_node.inputs.index(x)]}
1129-
)
1127+
assert destroy_map == {0: [0]}
11301128

1131-
res = f([0, 1, 2])
1129+
inp = np.array([0, 1, 2], dtype=config.floatX)
1130+
res = f(inp)
1131+
assert not np.allclose(inp, [0, 1, 2])
11321132
assert np.allclose(res[0], [1, 2, 3])
11331133
assert np.allclose(res[1], np.cos([1, 2, 3]) + np.array([0, 1, 2]))
11341134

0 commit comments

Comments
 (0)