Skip to content

Commit 4c64eb9

Browse files
committed
Don't use check_parameters in get_tau_sigma.
The default use of `check_parameters` indicates that an expression can be replaced by -inf, if the constraints aren't met. Instead, if `can_be_replaced_by_ninf=False`, sampling would fail for negative tau/sigma. To avoid this, the conversion now returns the right value for positive tau or sigma, but negative images if the inputs were negative. The methods that then validate the paramters (such as logp, logcdf, random), can later catch this.
1 parent 3f2a1da commit 4c64eb9

File tree

2 files changed

+23
-18
lines changed

2 files changed

+23
-18
lines changed

pymc/distributions/continuous.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -234,24 +234,26 @@ def get_tau_sigma(tau=None, sigma=None):
234234
tau = 1.0
235235
else:
236236
if isinstance(sigma, Variable):
237-
sigma_ = check_parameters(sigma, sigma > 0, msg="sigma > 0")
237+
# Keep tau negative, if sigma was negative, so that it will fail when used
238+
tau = (sigma**-2.0) * pt.sgn(sigma)
238239
else:
239240
sigma_ = np.asarray(sigma)
240241
if np.any(sigma_ <= 0):
241242
raise ValueError("sigma must be positive")
242-
tau = sigma_**-2.0
243+
tau = sigma_**-2.0
243244

244245
else:
245246
if sigma is not None:
246247
raise ValueError("Can't pass both tau and sigma")
247248
else:
248249
if isinstance(tau, Variable):
249-
tau_ = check_parameters(tau, tau > 0, msg="tau > 0")
250+
# Keep sigma negative, if tau was negative, so that it will fail when used
251+
sigma = pt.abs(tau) ** (-0.5) * pt.sgn(tau)
250252
else:
251253
tau_ = np.asarray(tau)
252254
if np.any(tau_ <= 0):
253255
raise ValueError("tau must be positive")
254-
sigma = tau_**-0.5
256+
sigma = tau_**-0.5
255257

256258
return floatX(tau), floatX(sigma)
257259

tests/distributions/test_continuous.py

Lines changed: 17 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
# limitations under the License.
1414

1515
import functools as ft
16+
import warnings
1617

1718
import numpy as np
1819
import numpy.testing as npt
@@ -890,24 +891,26 @@ def scipy_logp(value, mu, sigma, lower, upper):
890891
assert np.isinf(logp[2])
891892

892893
def test_get_tau_sigma(self):
893-
sigma = np.array(2)
894-
npt.assert_almost_equal(get_tau_sigma(sigma=sigma), [1.0 / sigma**2, sigma])
894+
# Fail on warnings
895+
with warnings.catch_warnings():
896+
warnings.simplefilter("error")
895897

896-
tau = np.array(2)
897-
npt.assert_almost_equal(get_tau_sigma(tau=tau), [tau, tau**-0.5])
898+
sigma = np.array(2)
899+
npt.assert_almost_equal(get_tau_sigma(sigma=sigma), [1.0 / sigma**2, sigma])
898900

899-
tau, _ = get_tau_sigma(sigma=pt.constant(-2))
900-
with pytest.raises(ParameterValueError):
901-
tau.eval()
901+
tau = np.array(2)
902+
npt.assert_almost_equal(get_tau_sigma(tau=tau), [tau, tau**-0.5])
902903

903-
_, sigma = get_tau_sigma(tau=pt.constant(-2))
904-
with pytest.raises(ParameterValueError):
905-
sigma.eval()
904+
tau, _ = get_tau_sigma(sigma=pt.constant(-2))
905+
npt.assert_almost_equal(tau.eval(), -0.25)
906906

907-
sigma = [1, 2]
908-
npt.assert_almost_equal(
909-
get_tau_sigma(sigma=sigma), [1.0 / np.array(sigma) ** 2, np.array(sigma)]
910-
)
907+
_, sigma = get_tau_sigma(tau=pt.constant(-2))
908+
npt.assert_almost_equal(sigma.eval(), -np.sqrt(1 / 2))
909+
910+
sigma = [1, 2]
911+
npt.assert_almost_equal(
912+
get_tau_sigma(sigma=sigma), [1.0 / np.array(sigma) ** 2, np.array(sigma)]
913+
)
911914

912915
@pytest.mark.parametrize(
913916
"value,mu,sigma,nu,logp",

0 commit comments

Comments
 (0)