From 1e80a108350a8094de47a94a8922913b5d195214 Mon Sep 17 00:00:00 2001
From: Ricardo Vieira
Date: Wed, 24 May 2023 16:50:05 +0200
Subject: [PATCH] Fix mixed dtype bug in gammaincc_grad

---
 pytensor/scalar/math.py   |  2 +-
 tests/scalar/test_math.py | 35 +++++++++++++++++++++++++++++++++++
 2 files changed, 36 insertions(+), 1 deletion(-)

diff --git a/pytensor/scalar/math.py b/pytensor/scalar/math.py
index bffbb46769..8dacc61f6f 100644
--- a/pytensor/scalar/math.py
+++ b/pytensor/scalar/math.py
@@ -854,7 +854,7 @@ def approx_b(skip_loop):
         log_s = np.array(0.0, dtype=dtype)
         s_sign = np.array(1, dtype="int8")
         n = np.array(1, dtype="int32")
-        log_delta = log_s - 2 * log(k)
+        log_delta = log_s - 2 * log(k).astype(dtype)
 
         def inner_loop_b(sum_b, log_s, s_sign, log_delta, n, k, log_x):
             delta = exp(log_delta)
diff --git a/tests/scalar/test_math.py b/tests/scalar/test_math.py
index ed09aa8426..1998ed5fa5 100644
--- a/tests/scalar/test_math.py
+++ b/tests/scalar/test_math.py
@@ -1,11 +1,16 @@
+import itertools
+
 import numpy as np
+import pytest
 import scipy.special as sp
 
 import pytensor.tensor as at
 from pytensor import function
 from pytensor.compile.mode import Mode
+from pytensor.graph import ancestors
 from pytensor.graph.fg import FunctionGraph
 from pytensor.link.c.basic import CLinker
+from pytensor.scalar import ScalarLoop, float32, float64, int32
 from pytensor.scalar.math import (
     betainc,
     betainc_grad,
@@ -13,6 +18,7 @@
     gammaincc,
     gammal,
     gammau,
+    hyp2f1,
 )
 from tests.link.test_link import make_function
 
@@ -89,3 +95,32 @@ def test_betainc_derivative_nan():
     assert np.isnan(test_func(1, 1, 2))
     assert np.isnan(test_func(1, -1, 1))
     assert np.isnan(test_func(1, 1, -1))
+
+
+@pytest.mark.parametrize(
+    "op, scalar_loop_grads",
+    [
+        (gammainc, [0]),
+        (gammaincc, [0]),
+        (betainc, [0, 1]),
+        (hyp2f1, [0, 1, 2]),
+    ],
+)
+def test_scalarloop_grad_mixed_dtypes(op, scalar_loop_grads):
+    nin = op.nin
+    for types in itertools.product((float32, float64, int32), repeat=nin):
+        inputs = [type() for type in types]
+        out = op(*inputs)
+        wrt = [
+            inp
+            for idx, inp in enumerate(inputs)
+            if idx in scalar_loop_grads and inp.type.dtype.startswith("float")
+        ]
+        if not wrt:
+            continue
+        # The ScalarLoop in the graph will fail if the input types are different from the updates
+        grad = at.grad(out, wrt=wrt)
+        assert any(
+            (var.owner and isinstance(var.owner.op, ScalarLoop))
+            for var in ancestors(grad)
+        )
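
Reproduction sketch (illustrative only, not part of the patch; which exact dtype
combinations failed before the fix is an assumption): the gradient of gammaincc
with respect to its first argument is computed with a ScalarLoop built in
gammaincc_grad, and mixing input precisions could leave the carried state
log_delta with a dtype different from its update, so building the gradient
graph could fail.

    import pytensor.tensor as at

    # Mixed-precision inputs like those exercised by the new parametrized test
    k = at.scalar("k", dtype="float32")
    x = at.scalar("x", dtype="float64")
    out = at.gammaincc(k, x)
    # Builds the ScalarLoop gradient graph; with the .astype(dtype) cast in place,
    # the loop's initial log_delta and its update share the same dtype.
    g = at.grad(out, wrt=k)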