From 9e8601c1f21592fb96b553bcfbd378424c92acab Mon Sep 17 00:00:00 2001 From: Brandon Horsley Date: Thu, 29 Feb 2024 17:56:32 +0000 Subject: [PATCH 01/10] This commit is a partial solution to the problems presented in PyMC github issue #7130. The first part of the issue was that certain functions like inverse trigonometric functions didn't seem to work, however on a clean install for this commit and pull request, that issue disappeared and the functions work as intended. These functions were missing among others from the website documentation which is the second part of the github issue and is what this particular git commit concerns. Note I have included tround although other codes reflect this as being deprecated but still able to be used. I have also rearranged the ordering to better reflect similar functions being bunched together. --- docs/source/api/math.rst | 109 +++++++++++++++++++++++++++------------ 1 file changed, 77 insertions(+), 32 deletions(-) diff --git a/docs/source/api/math.rst b/docs/source/api/math.rst index 67b487194d..610e9555c1 100644 --- a/docs/source/api/math.rst +++ b/docs/source/api/math.rst @@ -19,8 +19,10 @@ Functions exposed in pymc namespace invlogit probit invprobit + logaddexp logsumexp + Functions exposed in pymc.math ------------------------------ @@ -28,47 +30,90 @@ Functions exposed in pymc.math .. autosummary:: :toctree: generated/ - dot - constant - flatten - zeros_like - ones_like - stack - concatenate - sum + abs prod - lt - gt - le - ge + dot eq neq - switch - clip - where - and_ - or_ - abs + ge + gt + le + lt exp log - cos + sgn + sqr + sqrt + sum + ceil + floor sin - tan - cosh sinh + arcsin + arcsinh + cos + cosh + arccos + arccosh + tan tanh - sqr - sqrt - erf - erfinv - dot + arctan + arctanh + cumprod + cumsum + matmul + and_ + broadcast_to + clip + concatenate + flatten + or_ + stack + switch + where + flatten_list + constant + max maximum + mean + min minimum - sgn - ceil - floor - matrix_inverse - sigmoid + round + tround + erf + erfc + erfcinv + erfinv + log1pexp + log1mexp + log1mexp_numpy + logaddexp logsumexp - invlogit + logdiffexp + logdiffexp_numpy logit + invlogit + probit + invprobit + sigmoid + softmax + log_softmax + logbern + full + full_like + ones + ones_like + zeros + zeros_like + kronecker + cartesian + kron_dot + kron_solve_lower + kron_solve_upper + kron_diag + flat_outer + expand_packed_triangular + batched_diag + block_diagonal + matrix_inverse + logdet From a35a40eb9fef9fca7657d8660c91d14b31235410 Mon Sep 17 00:00:00 2001 From: Brandon Horsley Date: Mon, 25 Mar 2024 10:19:24 +0000 Subject: [PATCH 02/10] Removed numpy functions and tround from math.rst --- docs/source/api/math.rst | 3 --- 1 file changed, 3 deletions(-) diff --git a/docs/source/api/math.rst b/docs/source/api/math.rst index 610e9555c1..260e58f214 100644 --- a/docs/source/api/math.rst +++ b/docs/source/api/math.rst @@ -79,18 +79,15 @@ Functions exposed in pymc.math min minimum round - tround erf erfc erfcinv erfinv log1pexp log1mexp - log1mexp_numpy logaddexp logsumexp logdiffexp - logdiffexp_numpy logit invlogit probit From a97100557b46749f63491f5b53d3ad6095d9492d Mon Sep 17 00:00:00 2001 From: Brandon Horsley Date: Mon, 25 Mar 2024 11:10:56 +0000 Subject: [PATCH 03/10] Edits to math.py. Removed tround, replaced round with pytensor.tensor import, and added FutureWarning to log1mexp_numpy and logdiffexp_numpy. --- pymc/math.py | 26 ++++++++++++-------------- 1 file changed, 12 insertions(+), 14 deletions(-) diff --git a/pymc/math.py b/pymc/math.py index 7fe8d1e5e5..cf8684b1f4 100644 --- a/pymc/math.py +++ b/pymc/math.py @@ -73,6 +73,7 @@ ones_like, or_, prod, + round, sgn, sigmoid, sin, @@ -178,6 +179,7 @@ "expand_packed_triangular", "batched_diag", "block_diagonal", + "round", ] @@ -272,20 +274,6 @@ def kron_diag(*diags): return reduce(flat_outer, diags) -def round(*args, **kwargs): - """ - Temporary function to silence round warning in PyTensor. Please remove - when the warning disappears. - """ - kwargs["mode"] = "half_to_even" - return pt.round(*args, **kwargs) - - -def tround(*args, **kwargs): - warnings.warn("tround is deprecated. Use round instead.") - return round(*args, **kwargs) - - def logdiffexp(a, b): """log(exp(a) - exp(b))""" return a + pt.log1mexp(b - a) @@ -293,6 +281,11 @@ def logdiffexp(a, b): def logdiffexp_numpy(a, b): """log(exp(a) - exp(b))""" + warnings.warn( + "pymc.math.logdiffexp_numpy is being deprecated. Use logdiffexp instead.", + FutureWarning, + stacklevel=2, + ) return a + log1mexp_numpy(b - a, negative_input=True) @@ -341,6 +334,11 @@ def log1mexp_numpy(x, *, negative_input=False): For details, see https://cran.r-project.org/web/packages/Rmpfr/vignettes/log1mexp-note.pdf """ + warnings.warn( + "pymc.math.log1mexp_numpy is being deprecated. Use log1mexp instead.", + FutureWarning, + stacklevel=2, + ) x = np.asarray(x, dtype="float") if not negative_input: From 0f4300562a6869321c69d81e102391eb5544827b Mon Sep 17 00:00:00 2001 From: Brandon Horsley <56071064+brandonhorsley@users.noreply.github.com> Date: Mon, 25 Mar 2024 11:20:23 +0000 Subject: [PATCH 04/10] Update pymc/math.py Co-authored-by: Ricardo Vieira <28983449+ricardoV94@users.noreply.github.com> --- pymc/math.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pymc/math.py b/pymc/math.py index cf8684b1f4..48c0134321 100644 --- a/pymc/math.py +++ b/pymc/math.py @@ -335,7 +335,7 @@ def log1mexp_numpy(x, *, negative_input=False): https://cran.r-project.org/web/packages/Rmpfr/vignettes/log1mexp-note.pdf """ warnings.warn( - "pymc.math.log1mexp_numpy is being deprecated. Use log1mexp instead.", + "pymc.math.log1mexp_numpy is being deprecated.", FutureWarning, stacklevel=2, ) From efa45433dc775cec067851d10362fbe6cc07b4f4 Mon Sep 17 00:00:00 2001 From: Brandon Horsley <56071064+brandonhorsley@users.noreply.github.com> Date: Mon, 25 Mar 2024 11:20:43 +0000 Subject: [PATCH 05/10] Update pymc/math.py Co-authored-by: Ricardo Vieira <28983449+ricardoV94@users.noreply.github.com> --- pymc/math.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pymc/math.py b/pymc/math.py index 48c0134321..b85ffe63ce 100644 --- a/pymc/math.py +++ b/pymc/math.py @@ -282,7 +282,7 @@ def logdiffexp(a, b): def logdiffexp_numpy(a, b): """log(exp(a) - exp(b))""" warnings.warn( - "pymc.math.logdiffexp_numpy is being deprecated. Use logdiffexp instead.", + "pymc.math.logdiffexp_numpy is being deprecated.", FutureWarning, stacklevel=2, ) From 8166fe075608a9ef0b12223184c7eca4ed07d397 Mon Sep 17 00:00:00 2001 From: Brandon Horsley Date: Mon, 25 Mar 2024 12:44:33 +0000 Subject: [PATCH 06/10] Further FutureWarnings applied to _numpy functions Searched PyMC codebase for occurrences of _numpy functions and added explicit warnings. Most of these were for log1mexp_numpy which had quite a few involvements in test codes. Also note that on test_math.py the _numpy functions are explicitly imported and so when deprecated those imports will have to be removed too. --- tests/distributions/test_continuous.py | 1 + tests/test_math.py | 12 +++++++++--- 2 files changed, 10 insertions(+), 3 deletions(-) diff --git a/tests/distributions/test_continuous.py b/tests/distributions/test_continuous.py index 901618dc28..dac5086fa4 100644 --- a/tests/distributions/test_continuous.py +++ b/tests/distributions/test_continuous.py @@ -418,6 +418,7 @@ def scipy_log_pdf(value, a, b): return np.log(a) + np.log(b) + (a - 1) * np.log(value) + (b - 1) * np.log(1 - value**a) def scipy_log_cdf(value, a, b): + warnings.warn("pymc.math.log1mexp_numpy is being deprecated.", FutureWarning) return pm.math.log1mexp_numpy(b * np.log1p(-(value**a)), negative_input=True) check_logp( diff --git a/tests/test_math.py b/tests/test_math.py index 544bf4ce93..347290a289 100644 --- a/tests/test_math.py +++ b/tests/test_math.py @@ -29,10 +29,10 @@ kron_solve_lower, kronecker, log1mexp, - log1mexp_numpy, + log1mexp_numpy, # to be deprecated logdet, logdiffexp, - logdiffexp_numpy, + logdiffexp_numpy, # to be deprecated probit, ) from pymc.pytensorf import floatX @@ -148,6 +148,8 @@ def test_log1mexp(): with warnings.catch_warnings(): warnings.filterwarnings("ignore", "divide by zero encountered in log", RuntimeWarning) warnings.filterwarnings("ignore", "invalid value encountered in log", RuntimeWarning) + + warnings.warn("pymc.math.log1mexp_numpy is being deprecated.", FutureWarning) actual_ = log1mexp_numpy(-vals, negative_input=True) npt.assert_allclose(actual_, expected) # Check that input was not changed in place @@ -158,10 +160,12 @@ def test_log1mexp_numpy_no_warning(): """Assert RuntimeWarning is not raised for very small numbers""" with warnings.catch_warnings(): warnings.simplefilter("error") + warnings.warn("pymc.math.log1mexp_numpy is being deprecated.", FutureWarning) log1mexp_numpy(-1e-25, negative_input=True) def test_log1mexp_numpy_integer_input(): + warnings.warn("pymc.math.log1mexp_numpy is being deprecated.", FutureWarning) assert np.isclose(log1mexp_numpy(-2, negative_input=True), pt.log1mexp(-2).eval()) @@ -170,10 +174,12 @@ def test_log1mexp_deprecation_warnings(): FutureWarning, match="pymc.math.log1mexp_numpy will expect a negative input", ): + warnings.warn("pymc.math.log1mexp_numpy is being deprecated.", FutureWarning) res_pos = log1mexp_numpy(2) with warnings.catch_warnings(): warnings.simplefilter("error") + warnings.warn("pymc.math.log1mexp_numpy is being deprecated.", FutureWarning) res_neg = log1mexp_numpy(-2, negative_input=True) with pytest.warns( @@ -196,7 +202,7 @@ def test_logdiffexp(): with warnings.catch_warnings(): warnings.filterwarnings("ignore", "divide by zero encountered in log", RuntimeWarning) b = np.log([0, 1, 2, 3]) - + warnings.warn("pymc.math.logdiffexp_numpy is being deprecated.", FutureWarning) assert np.allclose(logdiffexp_numpy(a, b), 0) assert np.allclose(logdiffexp(a, b).eval(), 0) From 4045bc2ea2abbb0357e728ff4d54757004ccc688 Mon Sep 17 00:00:00 2001 From: Brandon Horsley Date: Mon, 25 Mar 2024 13:34:03 +0000 Subject: [PATCH 07/10] Undid test_math edits, test_continuous change This commit undoes the changes to test_math.py by removing the FutureWarning warnings for the _numpy functions. Note that when these tests are removed the _numpy functions are imported explicitly in this file. Replaced the log1mexp_numpy use in test_continuous.py by just substituting in the maths from the function to avoid calling the function. --- tests/distributions/test_continuous.py | 10 ++++++++-- tests/test_math.py | 10 ++-------- 2 files changed, 10 insertions(+), 10 deletions(-) diff --git a/tests/distributions/test_continuous.py b/tests/distributions/test_continuous.py index dac5086fa4..e576395bb8 100644 --- a/tests/distributions/test_continuous.py +++ b/tests/distributions/test_continuous.py @@ -418,8 +418,14 @@ def scipy_log_pdf(value, a, b): return np.log(a) + np.log(b) + (a - 1) * np.log(value) + (b - 1) * np.log(1 - value**a) def scipy_log_cdf(value, a, b): - warnings.warn("pymc.math.log1mexp_numpy is being deprecated.", FutureWarning) - return pm.math.log1mexp_numpy(b * np.log1p(-(value**a)), negative_input=True) + x = b * np.log1p(-(value**a)) + x = np.asarray(x, dtype="float") + out = np.empty_like(x) + mask = x < -0.6931471805599453 # log(1/2) + out[mask] = np.log1p(-np.exp(x[mask])) + mask = ~mask + out[mask] = np.log(-np.expm1(x[mask])) + return out check_logp( pm.Kumaraswamy, diff --git a/tests/test_math.py b/tests/test_math.py index 347290a289..ef61bda441 100644 --- a/tests/test_math.py +++ b/tests/test_math.py @@ -29,10 +29,10 @@ kron_solve_lower, kronecker, log1mexp, - log1mexp_numpy, # to be deprecated + log1mexp_numpy, logdet, logdiffexp, - logdiffexp_numpy, # to be deprecated + logdiffexp_numpy, probit, ) from pymc.pytensorf import floatX @@ -148,8 +148,6 @@ def test_log1mexp(): with warnings.catch_warnings(): warnings.filterwarnings("ignore", "divide by zero encountered in log", RuntimeWarning) warnings.filterwarnings("ignore", "invalid value encountered in log", RuntimeWarning) - - warnings.warn("pymc.math.log1mexp_numpy is being deprecated.", FutureWarning) actual_ = log1mexp_numpy(-vals, negative_input=True) npt.assert_allclose(actual_, expected) # Check that input was not changed in place @@ -160,12 +158,10 @@ def test_log1mexp_numpy_no_warning(): """Assert RuntimeWarning is not raised for very small numbers""" with warnings.catch_warnings(): warnings.simplefilter("error") - warnings.warn("pymc.math.log1mexp_numpy is being deprecated.", FutureWarning) log1mexp_numpy(-1e-25, negative_input=True) def test_log1mexp_numpy_integer_input(): - warnings.warn("pymc.math.log1mexp_numpy is being deprecated.", FutureWarning) assert np.isclose(log1mexp_numpy(-2, negative_input=True), pt.log1mexp(-2).eval()) @@ -174,12 +170,10 @@ def test_log1mexp_deprecation_warnings(): FutureWarning, match="pymc.math.log1mexp_numpy will expect a negative input", ): - warnings.warn("pymc.math.log1mexp_numpy is being deprecated.", FutureWarning) res_pos = log1mexp_numpy(2) with warnings.catch_warnings(): warnings.simplefilter("error") - warnings.warn("pymc.math.log1mexp_numpy is being deprecated.", FutureWarning) res_neg = log1mexp_numpy(-2, negative_input=True) with pytest.warns( From b14e0009cb3cb65d3f76fe3e65bde3ef8e17a6d6 Mon Sep 17 00:00:00 2001 From: Brandon Horsley Date: Mon, 25 Mar 2024 14:30:02 +0000 Subject: [PATCH 08/10] Removed a FutureWarning I forgot to remove --- tests/test_math.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/test_math.py b/tests/test_math.py index ef61bda441..cbdde46a66 100644 --- a/tests/test_math.py +++ b/tests/test_math.py @@ -196,7 +196,6 @@ def test_logdiffexp(): with warnings.catch_warnings(): warnings.filterwarnings("ignore", "divide by zero encountered in log", RuntimeWarning) b = np.log([0, 1, 2, 3]) - warnings.warn("pymc.math.logdiffexp_numpy is being deprecated.", FutureWarning) assert np.allclose(logdiffexp_numpy(a, b), 0) assert np.allclose(logdiffexp(a, b).eval(), 0) From 7f3eb36291cc5caabb7c8d55a65082c3c609e41f Mon Sep 17 00:00:00 2001 From: Ricardo Vieira Date: Mon, 25 Mar 2024 16:02:31 +0100 Subject: [PATCH 09/10] refactor helper function --- tests/distributions/test_continuous.py | 12 ++++-------- 1 file changed, 4 insertions(+), 8 deletions(-) diff --git a/tests/distributions/test_continuous.py b/tests/distributions/test_continuous.py index e576395bb8..91d6cdbaf6 100644 --- a/tests/distributions/test_continuous.py +++ b/tests/distributions/test_continuous.py @@ -417,15 +417,11 @@ def test_kumaraswamy(self): def scipy_log_pdf(value, a, b): return np.log(a) + np.log(b) + (a - 1) * np.log(value) + (b - 1) * np.log(1 - value**a) + def log1mexp(x): + return np.log1p(-np.exp(x)) if x < np.log(0.5) else np.log(-np.expm1(x)) + def scipy_log_cdf(value, a, b): - x = b * np.log1p(-(value**a)) - x = np.asarray(x, dtype="float") - out = np.empty_like(x) - mask = x < -0.6931471805599453 # log(1/2) - out[mask] = np.log1p(-np.exp(x[mask])) - mask = ~mask - out[mask] = np.log(-np.expm1(x[mask])) - return out + return log1mexp(b * np.log1p(-(value**a))) check_logp( pm.Kumaraswamy, From 43ec2ed5f1af55463fab515f5847f33ff30ee71a Mon Sep 17 00:00:00 2001 From: Ricardo Vieira Date: Mon, 25 Mar 2024 16:35:22 +0100 Subject: [PATCH 10/10] Tweak tests --- tests/test_math.py | 40 +++++++++++++++++++++------------------- 1 file changed, 21 insertions(+), 19 deletions(-) diff --git a/tests/test_math.py b/tests/test_math.py index cbdde46a66..40c3b70db5 100644 --- a/tests/test_math.py +++ b/tests/test_math.py @@ -145,45 +145,46 @@ def test_log1mexp(): ) actual = pt.log1mexp(-vals).eval() npt.assert_allclose(actual, expected) + with warnings.catch_warnings(): warnings.filterwarnings("ignore", "divide by zero encountered in log", RuntimeWarning) warnings.filterwarnings("ignore", "invalid value encountered in log", RuntimeWarning) - actual_ = log1mexp_numpy(-vals, negative_input=True) + with pytest.warns(FutureWarning, match="deprecated"): + actual_ = log1mexp_numpy(-vals, negative_input=True) npt.assert_allclose(actual_, expected) # Check that input was not changed in place npt.assert_allclose(vals, vals_) +@pytest.mark.filterwarnings("error") def test_log1mexp_numpy_no_warning(): """Assert RuntimeWarning is not raised for very small numbers""" - with warnings.catch_warnings(): - warnings.simplefilter("error") + with pytest.warns(FutureWarning, match="deprecated"): log1mexp_numpy(-1e-25, negative_input=True) def test_log1mexp_numpy_integer_input(): - assert np.isclose(log1mexp_numpy(-2, negative_input=True), pt.log1mexp(-2).eval()) + with pytest.warns(FutureWarning, match="deprecated"): + assert np.isclose(log1mexp_numpy(-2, negative_input=True), pt.log1mexp(-2).eval()) +@pytest.mark.filterwarnings("error") def test_log1mexp_deprecation_warnings(): - with pytest.warns( - FutureWarning, - match="pymc.math.log1mexp_numpy will expect a negative input", - ): - res_pos = log1mexp_numpy(2) + with pytest.warns(FutureWarning, match="deprecated"): + with pytest.warns( + FutureWarning, + match="pymc.math.log1mexp_numpy will expect a negative input", + ): + res_pos = log1mexp_numpy(2) - with warnings.catch_warnings(): - warnings.simplefilter("error") res_neg = log1mexp_numpy(-2, negative_input=True) - with pytest.warns( - FutureWarning, - match="pymc.math.log1mexp will expect a negative input", - ): - res_pos_at = log1mexp(2).eval() + with pytest.warns( + FutureWarning, + match="pymc.math.log1mexp will expect a negative input", + ): + res_pos_at = log1mexp(2).eval() - with warnings.catch_warnings(): - warnings.simplefilter("error") res_neg_at = log1mexp(-2, negative_input=True).eval() assert np.isclose(res_pos, res_neg) @@ -196,7 +197,8 @@ def test_logdiffexp(): with warnings.catch_warnings(): warnings.filterwarnings("ignore", "divide by zero encountered in log", RuntimeWarning) b = np.log([0, 1, 2, 3]) - assert np.allclose(logdiffexp_numpy(a, b), 0) + with pytest.warns(FutureWarning, match="deprecated"): + assert np.allclose(logdiffexp_numpy(a, b), 0) assert np.allclose(logdiffexp(a, b).eval(), 0)