From 34bfe31e1e0024c3966876778a4a03239e5f61b4 Mon Sep 17 00:00:00 2001 From: amyoshino Date: Mon, 5 Feb 2024 22:22:54 -0300 Subject: [PATCH 1/3] change gamma.c function to support np.inf special cases --- pytensor/scalar/c_code/gamma.c | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/pytensor/scalar/c_code/gamma.c b/pytensor/scalar/c_code/gamma.c index 7c3a5fde4c..b625378de8 100644 --- a/pytensor/scalar/c_code/gamma.c +++ b/pytensor/scalar/c_code/gamma.c @@ -218,6 +218,11 @@ DEVICE double GammaP (double n, double x) { /* --- regularized Gamma function P */ if ((n <= 0) || (x < 0)) return NPY_NAN; /* check the function arguments */ if (x <= 0) return 0; /* treat x = 0 as a special case */ + if (isinf(n)) { + if (isinf(x)) return NPY_NAN; + return 0; + } + if (isinf(x)) return 1; if (x < n+1) return _series(n, x) *exp(n *log(x) -x -logGamma(n)); return 1 -_cfrac(n, x) *exp(n *log(x) -x -logGamma(n)); } /* GammaP() */ @@ -228,6 +233,11 @@ DEVICE double GammaQ (double n, double x) { /* --- regularized Gamma function Q */ if ((n <= 0) || (x < 0)) return NPY_NAN; /* check the function arguments */ if (x <= 0) return 1; /* treat x = 0 as a special case */ + if (isinf(n)) { + if (isinf(x)) return NPY_NAN; + return 1; + } + if (isinf(x)) return 0; if (x < n+1) return 1 -_series(n, x) *exp(n *log(x) -x -logGamma(n)); return _cfrac(n, x) *exp(n *log(x) -x -logGamma(n)); } /* GammaQ() */ From 9765b46cbc9acc3018138507245e23aa94d9e490 Mon Sep 17 00:00:00 2001 From: amyoshino Date: Wed, 7 Feb 2024 00:01:03 -0300 Subject: [PATCH 2/3] add tests and c_code_cache_version --- pytensor/scalar/math.py | 21 +++++++++++++++++++++ tests/scalar/test_math.py | 20 ++++++++++++++++++++ 2 files changed, 41 insertions(+) diff --git a/pytensor/scalar/math.py b/pytensor/scalar/math.py index 7fab4d7594..6e2b18df07 100644 --- a/pytensor/scalar/math.py +++ b/pytensor/scalar/math.py @@ -630,6 +630,13 @@ def __eq__(self, other): def __hash__(self): return hash(type(self)) + + def c_code_cache_version(self): + v = super().c_code_cache_version() + if v: + return (2, *v) + else: + return v chi2sf = Chi2SF(upgrade_to_float64, name="chi2sf") @@ -676,6 +683,13 @@ def __eq__(self, other): def __hash__(self): return hash(type(self)) + + def c_code_cache_version(self): + v = super().c_code_cache_version() + if v: + return (2, *v) + else: + return v gammainc = GammaInc(upgrade_to_float, name="gammainc") @@ -722,6 +736,13 @@ def __eq__(self, other): def __hash__(self): return hash(type(self)) + + def c_code_cache_version(self): + v = super().c_code_cache_version() + if v: + return (2, *v) + else: + return v gammaincc = GammaIncC(upgrade_to_float, name="gammaincc") diff --git a/tests/scalar/test_math.py b/tests/scalar/test_math.py index 34567d34db..3a726977e0 100644 --- a/tests/scalar/test_math.py +++ b/tests/scalar/test_math.py @@ -41,6 +41,16 @@ def test_gammainc_nan_c(): assert np.isnan(test_func(-1, -1)) +def test_gammainc_inf_c(): + x1 = pt.dscalar() + x2 = pt.dscalar() + y = gammainc(x1, x2) + test_func = make_function(CLinker().accept(FunctionGraph([x1, x2], [y]))) + assert np.isclose(test_func(np.inf, 1), sp.gammainc(np.inf, 1)) + assert np.isclose(test_func(1, np.inf), sp.gammainc(1, np.inf)) + assert np.isnan(test_func(np.inf, np.inf)) + + def test_gammaincc_python(): x1 = pt.dscalar() x2 = pt.dscalar() @@ -59,6 +69,16 @@ def test_gammaincc_nan_c(): assert np.isnan(test_func(-1, -1)) +def test_gammaincc_inf_c(): + x1 = pt.dscalar() + x2 = pt.dscalar() + y = gammaincc(x1, x2) + test_func = make_function(CLinker().accept(FunctionGraph([x1, x2], [y]))) + assert np.isclose(test_func(np.inf, 1), sp.gammaincc(np.inf, 1)) + assert np.isclose(test_func(1, np.inf), sp.gammaincc(1, np.inf)) + assert np.isnan(test_func(np.inf, np.inf)) + + def test_gammal_nan_c(): x1 = pt.dscalar() x2 = pt.dscalar() From a2b242ce35e31df7e4d8eeed667690a4b621a7db Mon Sep 17 00:00:00 2001 From: amyoshino Date: Wed, 7 Feb 2024 00:05:26 -0300 Subject: [PATCH 3/3] fixing ruff pre-commit error --- pytensor/scalar/math.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/pytensor/scalar/math.py b/pytensor/scalar/math.py index 6e2b18df07..edf03a393d 100644 --- a/pytensor/scalar/math.py +++ b/pytensor/scalar/math.py @@ -630,7 +630,7 @@ def __eq__(self, other): def __hash__(self): return hash(type(self)) - + def c_code_cache_version(self): v = super().c_code_cache_version() if v: @@ -683,7 +683,7 @@ def __eq__(self, other): def __hash__(self): return hash(type(self)) - + def c_code_cache_version(self): v = super().c_code_cache_version() if v: @@ -736,7 +736,7 @@ def __eq__(self, other): def __hash__(self): return hash(type(self)) - + def c_code_cache_version(self): v = super().c_code_cache_version() if v: