Skip to content

Commit db1c161

Browse files
committed
Keep stack trace in random_make_inplace
1 parent e3d2750 commit db1c161

File tree

2 files changed

+6
-2
lines changed

2 files changed

+6
-2
lines changed

pytensor/tensor/random/rewriting/basic.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,10 @@ def random_make_inplace(fgraph, node):
5050
props = op._props_dict()
5151
props["inplace"] = True
5252
new_op = type(op)(**props)
53-
return new_op.make_node(*node.inputs).outputs
53+
new_outputs = new_op.make_node(*node.inputs).outputs
54+
for old_out, new_out in zip(node.outputs, new_outputs):
55+
copy_stack_trace(old_out, new_out)
56+
return new_outputs
5457

5558
return False
5659

tests/tensor/random/rewriting/test_basic.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
from pytensor.compile.mode import Mode
1010
from pytensor.graph.basic import Constant, Variable, ancestors
1111
from pytensor.graph.fg import FunctionGraph
12-
from pytensor.graph.rewriting.basic import EquilibriumGraphRewriter
12+
from pytensor.graph.rewriting.basic import EquilibriumGraphRewriter, check_stack_trace
1313
from pytensor.graph.rewriting.db import RewriteDatabaseQuery
1414
from pytensor.tensor import constant
1515
from pytensor.tensor.elemwise import DimShuffle
@@ -143,6 +143,7 @@ def test_inplace_rewrites(rv_op):
143143
for a, b in zip(new_op.dist_params(new_node), op.dist_params(node))
144144
)
145145
assert np.array_equal(new_op.size_param(new_node).data, op.size_param(node).data)
146+
assert check_stack_trace(f)
146147

147148

148149
@config.change_flags(compute_test_value="raise")

0 commit comments

Comments
 (0)