Skip to content

Commit 2dab9b8

Browse files
committed
Keep stack trace in random_make_inplace
1 parent 27bd9aa commit 2dab9b8

File tree

2 files changed

+7
-2
lines changed

2 files changed

+7
-2
lines changed

pytensor/tensor/random/rewriting/basic.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,10 @@ def random_make_inplace(fgraph, node):
5151
props = op._props_dict()
5252
props["inplace"] = True
5353
new_op = type(op)(**props)
54-
return new_op.make_node(*node.inputs).outputs
54+
new_outputs = new_op.make_node(*node.inputs).outputs
55+
for old_out, new_out in zip(node.outputs, new_outputs):
56+
copy_stack_trace(old_out, new_out)
57+
return new_outputs
5558

5659
return False
5760

tests/tensor/random/rewriting/test_basic.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
from pytensor.compile.mode import Mode
88
from pytensor.graph.basic import Constant
99
from pytensor.graph.fg import FunctionGraph
10-
from pytensor.graph.rewriting.basic import EquilibriumGraphRewriter
10+
from pytensor.graph.rewriting.basic import EquilibriumGraphRewriter, check_stack_trace
1111
from pytensor.graph.rewriting.db import RewriteDatabaseQuery
1212
from pytensor.tensor import constant
1313
from pytensor.tensor.elemwise import DimShuffle
@@ -96,6 +96,8 @@ def test_inplace_rewrites():
9696
)
9797
assert np.array_equal(new_out.owner.inputs[1].data, [])
9898

99+
assert check_stack_trace(f)
100+
99101

100102
def test_inplace_rewrites_extra_props():
101103
class Test(RandomVariable):

0 commit comments

Comments
 (0)