Commit db673f0

Fix mixed dtype bug in gammaincc_grad
1 parent 53b00ea commit db673f0

2 files changed: +36 additions, −1 deletion

pytensor/scalar/math.py

Lines changed: 1 addition & 1 deletion
@@ -854,7 +854,7 @@ def approx_b(skip_loop):
         log_s = np.array(0.0, dtype=dtype)
         s_sign = np.array(1, dtype="int8")
         n = np.array(1, dtype="int32")
-        log_delta = log_s - 2 * log(k)
+        log_delta = log_s - 2 * log(k).astype(dtype)
 
         def inner_loop_b(sum_b, log_s, s_sign, log_delta, n, k, log_x):
             delta = exp(log_delta)
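
For context, a minimal sketch (not part of the commit) of the kind of mixed-dtype gradient the change above targets. The gradient of gammaincc with respect to the shape parameter k is evaluated with a ScalarLoop, and before the fix the initial log_delta could pick up a dtype different from the loop's own update when the inputs had mixed precisions. The exact failing dtype combination depends on how gammaincc_grad upcasts internally, so treat this purely as an illustration:

    import pytensor.tensor as at
    from pytensor import function

    # Mixed input precisions: float32 shape parameter, float64 argument.
    k = at.scalar("k", dtype="float32")
    x = at.scalar("x", dtype="float64")

    # The gradient w.r.t. k is the one that routes through the ScalarLoop
    # patched above.
    dk = at.grad(at.gammaincc(k, x), wrt=k)
    f = function([k, x], dk)
    print(f(2.0, 0.5))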

tests/scalar/test_math.py

Lines changed: 35 additions & 0 deletions
@@ -1,18 +1,24 @@
+import itertools
+
 import numpy as np
+import pytest
 import scipy.special as sp
 
 import pytensor.tensor as at
 from pytensor import function
 from pytensor.compile.mode import Mode
+from pytensor.graph import ancestors
 from pytensor.graph.fg import FunctionGraph
 from pytensor.link.c.basic import CLinker
+from pytensor.scalar import ScalarLoop, float32, float64, int32
 from pytensor.scalar.math import (
     betainc,
     betainc_grad,
     gammainc,
     gammaincc,
     gammal,
     gammau,
+    hyp2f1,
 )
 from tests.link.test_link import make_function
 
@@ -89,3 +95,32 @@ def test_betainc_derivative_nan():
     assert np.isnan(test_func(1, 1, 2))
     assert np.isnan(test_func(1, -1, 1))
     assert np.isnan(test_func(1, 1, -1))
+
+
+@pytest.mark.parametrize(
+    "op, scalar_loop_grads",
+    [
+        (gammainc, [0]),
+        (gammaincc, [0]),
+        (betainc, [0, 1]),
+        (hyp2f1, [0, 1, 2]),
+    ],
+)
+def test_scalarloop_grad_mixed_dtypes(op, scalar_loop_grads):
+    nin = op.nin
+    for types in itertools.product((float32, float64, int32), repeat=nin):
+        inputs = [type() for type in types]
+        out = op(*inputs)
+        wrt = [
+            inp
+            for idx, inp in enumerate(inputs)
+            if idx in scalar_loop_grads and inp.type.dtype.startswith("float")
+        ]
+        if not wrt:
+            continue
+        # The ScalarLoop in the graph will fail if the input types are different from the updates
+        grad = at.grad(out, wrt=wrt)
+        assert any(
+            (var.owner and isinstance(var.owner.op, ScalarLoop))
+            for var in ancestors(grad)
+        )
