Skip to content

Commit 244fb97

Browse files
authored
Refactor get_tau_sigma and support lists of variables (#7185)
1 parent 47f6d9e commit 244fb97

File tree

2 files changed

+49
-42
lines changed

2 files changed

+49
-42
lines changed

pymc/distributions/continuous.py

Lines changed: 14 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -229,32 +229,21 @@ def get_tau_sigma(tau=None, sigma=None):
229229
-----
230230
If neither tau nor sigma is provided, returns (1., 1.)
231231
"""
232-
if tau is None:
233-
if sigma is None:
234-
sigma = 1.0
235-
tau = 1.0
236-
else:
237-
if isinstance(sigma, Variable):
238-
# Keep tau negative, if sigma was negative, so that it will fail when used
239-
tau = (sigma**-2.0) * pt.sign(sigma)
240-
else:
241-
sigma_ = np.asarray(sigma)
242-
if np.any(sigma_ <= 0):
243-
raise ValueError("sigma must be positive")
244-
tau = sigma_**-2.0
245-
232+
if tau is not None and sigma is not None:
233+
raise ValueError("Can't pass both tau and sigma")
234+
if tau is None and sigma is None:
235+
sigma = pt.as_tensor_variable(1.0)
236+
tau = pt.as_tensor_variable(1.0)
237+
elif tau is None:
238+
sigma = pt.as_tensor_variable(sigma)
239+
# Keep tau negative, if sigma was negative, so that it will
240+
# fail when used
241+
tau = (sigma**-2.0) * pt.sign(sigma)
246242
else:
247-
if sigma is not None:
248-
raise ValueError("Can't pass both tau and sigma")
249-
else:
250-
if isinstance(tau, Variable):
251-
# Keep sigma negative, if tau was negative, so that it will fail when used
252-
sigma = pt.abs(tau) ** (-0.5) * pt.sign(tau)
253-
else:
254-
tau_ = np.asarray(tau)
255-
if np.any(tau_ <= 0):
256-
raise ValueError("tau must be positive")
257-
sigma = tau_**-0.5
243+
tau = pt.as_tensor_variable(tau)
244+
# Keep tau negative, if sigma was negative, so that it will
245+
# fail when used
246+
sigma = pt.abs(tau) ** -0.5 * pt.sign(tau)
258247

259248
return tau, sigma
260249

tests/distributions/test_continuous.py

Lines changed: 35 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,6 @@
1313
# limitations under the License.
1414

1515
import functools as ft
16-
import warnings
1716

1817
import numpy as np
1918
import numpy.testing as npt
@@ -998,26 +997,45 @@ def scipy_logcdf(value, mu, sigma, lower, upper):
998997
assert np.isinf(logp[2])
999998

1000999
def test_get_tau_sigma(self):
1001-
# Fail on warnings
1002-
with warnings.catch_warnings():
1003-
warnings.simplefilter("error")
1000+
sigma = np.array(2)
1001+
tau, _ = get_tau_sigma(sigma=sigma)
1002+
npt.assert_almost_equal(tau.eval(), 1.0 / sigma**2)
10041003

1005-
sigma = np.array(2)
1006-
npt.assert_almost_equal(get_tau_sigma(sigma=sigma), [1.0 / sigma**2, sigma])
1004+
tau = np.array(2)
1005+
_, sigma = get_tau_sigma(tau=tau)
1006+
npt.assert_almost_equal(sigma.eval(), tau**-0.5)
10071007

1008-
tau = np.array(2)
1009-
npt.assert_almost_equal(get_tau_sigma(tau=tau), [tau, tau**-0.5])
1008+
tau, _ = get_tau_sigma(sigma=pt.constant(-2))
1009+
npt.assert_almost_equal(tau.eval(), -0.25)
10101010

1011-
tau, _ = get_tau_sigma(sigma=pt.constant(-2))
1012-
npt.assert_almost_equal(tau.eval(), -0.25)
1011+
_, sigma = get_tau_sigma(tau=pt.constant(-2))
1012+
npt.assert_almost_equal(sigma.eval(), -1.0 / np.sqrt(2.0))
10131013

1014-
_, sigma = get_tau_sigma(tau=pt.constant(-2))
1015-
npt.assert_almost_equal(sigma.eval(), -np.sqrt(1 / 2))
1014+
sigma = [1, 2]
1015+
tau, _ = get_tau_sigma(sigma=sigma)
1016+
npt.assert_almost_equal(tau.eval(), 1.0 / np.array(sigma) ** 2)
10161017

1017-
sigma = [1, 2]
1018-
npt.assert_almost_equal(
1019-
get_tau_sigma(sigma=sigma), [1.0 / np.array(sigma) ** 2, np.array(sigma)]
1020-
)
1018+
# Test null arguments
1019+
tau, sigma = get_tau_sigma()
1020+
npt.assert_almost_equal(tau.eval(), 1.0)
1021+
npt.assert_almost_equal(sigma.eval(), 1.0)
1022+
1023+
# Test exception upon passing both sigma and tau
1024+
msg = "Can't pass both tau and sigma"
1025+
with pytest.raises(ValueError, match=msg):
1026+
_, _ = get_tau_sigma(sigma=1.0, tau=1.0)
1027+
1028+
# These are regression test for #6988: Check that get_tau_sigma works
1029+
# for lists of tensors
1030+
sigma = [pt.constant(2), pt.constant(2)]
1031+
expect_tau = np.array([0.25, 0.25])
1032+
tau, _ = get_tau_sigma(sigma=sigma)
1033+
npt.assert_almost_equal(tau.eval(), expect_tau)
1034+
1035+
tau = [pt.constant(2), pt.constant(2)]
1036+
expect_sigma = np.array([2.0, 2.0]) ** -0.5
1037+
_, sigma = get_tau_sigma(tau=tau)
1038+
npt.assert_almost_equal(sigma.eval(), expect_sigma)
10211039

10221040
@pytest.mark.parametrize(
10231041
"value,mu,sigma,nu,logp",
@@ -2042,7 +2060,7 @@ class TestStudentTLam(BaseTestDistributionRandom):
20422060
lam, sigma = get_tau_sigma(tau=2.0)
20432061
pymc_dist_params = {"nu": 5.0, "mu": -1.0, "lam": lam}
20442062
expected_rv_op_params = {"nu": 5.0, "mu": -1.0, "lam": sigma}
2045-
reference_dist_params = {"df": 5.0, "loc": -1.0, "scale": sigma}
2063+
reference_dist_params = {"df": 5.0, "loc": -1.0, "scale": sigma.eval()}
20462064
reference_dist = seeded_scipy_distribution_builder("t")
20472065
checks_to_run = ["check_pymc_params_match_rv_op"]
20482066

0 commit comments

Comments
 (0)