|
13 | 13 | # limitations under the License.
|
14 | 14 |
|
15 | 15 | import functools as ft
|
16 |
| -import warnings |
17 | 16 |
|
18 | 17 | import numpy as np
|
19 | 18 | import numpy.testing as npt
|
@@ -998,26 +997,45 @@ def scipy_logcdf(value, mu, sigma, lower, upper):
|
998 | 997 | assert np.isinf(logp[2])
|
999 | 998 |
|
1000 | 999 | def test_get_tau_sigma(self):
|
1001 |
| - # Fail on warnings |
1002 |
| - with warnings.catch_warnings(): |
1003 |
| - warnings.simplefilter("error") |
| 1000 | + sigma = np.array(2) |
| 1001 | + tau, _ = get_tau_sigma(sigma=sigma) |
| 1002 | + npt.assert_almost_equal(tau.eval(), 1.0 / sigma**2) |
1004 | 1003 |
|
1005 |
| - sigma = np.array(2) |
1006 |
| - npt.assert_almost_equal(get_tau_sigma(sigma=sigma), [1.0 / sigma**2, sigma]) |
| 1004 | + tau = np.array(2) |
| 1005 | + _, sigma = get_tau_sigma(tau=tau) |
| 1006 | + npt.assert_almost_equal(sigma.eval(), tau**-0.5) |
1007 | 1007 |
|
1008 |
| - tau = np.array(2) |
1009 |
| - npt.assert_almost_equal(get_tau_sigma(tau=tau), [tau, tau**-0.5]) |
| 1008 | + tau, _ = get_tau_sigma(sigma=pt.constant(-2)) |
| 1009 | + npt.assert_almost_equal(tau.eval(), -0.25) |
1010 | 1010 |
|
1011 |
| - tau, _ = get_tau_sigma(sigma=pt.constant(-2)) |
1012 |
| - npt.assert_almost_equal(tau.eval(), -0.25) |
| 1011 | + _, sigma = get_tau_sigma(tau=pt.constant(-2)) |
| 1012 | + npt.assert_almost_equal(sigma.eval(), -1.0 / np.sqrt(2.0)) |
1013 | 1013 |
|
1014 |
| - _, sigma = get_tau_sigma(tau=pt.constant(-2)) |
1015 |
| - npt.assert_almost_equal(sigma.eval(), -np.sqrt(1 / 2)) |
| 1014 | + sigma = [1, 2] |
| 1015 | + tau, _ = get_tau_sigma(sigma=sigma) |
| 1016 | + npt.assert_almost_equal(tau.eval(), 1.0 / np.array(sigma) ** 2) |
1016 | 1017 |
|
1017 |
| - sigma = [1, 2] |
1018 |
| - npt.assert_almost_equal( |
1019 |
| - get_tau_sigma(sigma=sigma), [1.0 / np.array(sigma) ** 2, np.array(sigma)] |
1020 |
| - ) |
| 1018 | + # Test null arguments |
| 1019 | + tau, sigma = get_tau_sigma() |
| 1020 | + npt.assert_almost_equal(tau.eval(), 1.0) |
| 1021 | + npt.assert_almost_equal(sigma.eval(), 1.0) |
| 1022 | + |
| 1023 | + # Test exception upon passing both sigma and tau |
| 1024 | + msg = "Can't pass both tau and sigma" |
| 1025 | + with pytest.raises(ValueError, match=msg): |
| 1026 | + _, _ = get_tau_sigma(sigma=1.0, tau=1.0) |
| 1027 | + |
| 1028 | + # These are regression test for #6988: Check that get_tau_sigma works |
| 1029 | + # for lists of tensors |
| 1030 | + sigma = [pt.constant(2), pt.constant(2)] |
| 1031 | + expect_tau = np.array([0.25, 0.25]) |
| 1032 | + tau, _ = get_tau_sigma(sigma=sigma) |
| 1033 | + npt.assert_almost_equal(tau.eval(), expect_tau) |
| 1034 | + |
| 1035 | + tau = [pt.constant(2), pt.constant(2)] |
| 1036 | + expect_sigma = np.array([2.0, 2.0]) ** -0.5 |
| 1037 | + _, sigma = get_tau_sigma(tau=tau) |
| 1038 | + npt.assert_almost_equal(sigma.eval(), expect_sigma) |
1021 | 1039 |
|
1022 | 1040 | @pytest.mark.parametrize(
|
1023 | 1041 | "value,mu,sigma,nu,logp",
|
@@ -2042,7 +2060,7 @@ class TestStudentTLam(BaseTestDistributionRandom):
|
2042 | 2060 | lam, sigma = get_tau_sigma(tau=2.0)
|
2043 | 2061 | pymc_dist_params = {"nu": 5.0, "mu": -1.0, "lam": lam}
|
2044 | 2062 | expected_rv_op_params = {"nu": 5.0, "mu": -1.0, "lam": sigma}
|
2045 |
| - reference_dist_params = {"df": 5.0, "loc": -1.0, "scale": sigma} |
| 2063 | + reference_dist_params = {"df": 5.0, "loc": -1.0, "scale": sigma.eval()} |
2046 | 2064 | reference_dist = seeded_scipy_distribution_builder("t")
|
2047 | 2065 | checks_to_run = ["check_pymc_params_match_rv_op"]
|
2048 | 2066 |
|
|
0 commit comments