Skip to content

Commit cdae903

Browse files
committed
Fix TrueDiv gradient for integer inputs
1 parent aa616e6 commit cdae903

File tree

1 file changed

+4
-1
lines changed

1 file changed

+4
-1
lines changed

pytensor/scalar/basic.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2036,7 +2036,10 @@ def grad(self, inputs, gout):
20362036
# to the output; x/y is still a function of x
20372037
# and y; it's just a step function.
20382038
if all(a.dtype in discrete_dtypes for a in (x, y)):
2039-
return [x.zeros_like(), y.zeros_like()]
2039+
return [
2040+
x.zeros_like(dtype=config.floatX),
2041+
y.zeros_like(dtype=config.floatX),
2042+
]
20402043

20412044
first_part = gz / y
20422045

0 commit comments

Comments
 (0)