Skip to content

Commit 37fb461

Browse files
committed
Fix bug in gradient of Elemwise containing multi-output scalars
1 parent 095e3c9 commit 37fb461

File tree

1 file changed

+3
-0
lines changed

1 file changed

+3
-0
lines changed

pytensor/tensor/elemwise.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -636,6 +636,9 @@ def transform(r):
636636
return DimShuffle((), ["x"] * nd)(res)
637637

638638
new_r = Elemwise(node.op, {})(*[transform(ipt) for ipt in node.inputs])
639+
if isinstance(new_r, (list, tuple)):
640+
# Scalar Op with multiple outputs
641+
new_r = new_r[r.owner.outputs.index(r)]
639642
return new_r
640643

641644
ret = []

0 commit comments

Comments
 (0)