diff --git a/pymc3/distributions/discrete.py b/pymc3/distributions/discrete.py index 31cf813a5d..a7fa73b518 100644 --- a/pymc3/distributions/discrete.py +++ b/pymc3/distributions/discrete.py @@ -16,8 +16,15 @@ import aesara.tensor as at import numpy as np -from aesara.tensor.random.basic import bernoulli, binomial, categorical, nbinom, poisson +from aesara.tensor.random.basic import ( + BernoulliRV, + binomial, + categorical, + nbinom, + poisson, +) from scipy import stats +from scipy.special import expit from pymc3.aesaraf import floatX, intX, take_along_axis from pymc3.distributions.dist_math import ( @@ -32,7 +39,7 @@ normal_lcdf, ) from pymc3.distributions.distribution import Discrete -from pymc3.math import log1mexp, logaddexp, logsumexp, sigmoid, tround +from pymc3.math import log1mexp, log1pexp, logaddexp, logit, logsumexp, sigmoid, tround __all__ = [ "Binomial", @@ -332,6 +339,19 @@ def logcdf(self, value): ) +class BernoulliLogitRV(BernoulliRV): + name = "bernoulli_logit" + _print_name = ("BernLogit", "\\operatorname{BernLogit}") + + @classmethod + def rng_fn(cls, rng, logitp, size=None): + p = expit(logitp) + return stats.bernoulli.rvs(p, size=size, random_state=rng) + + +bernoulli_logit = BernoulliLogitRV() + + class Bernoulli(Discrete): R"""Bernoulli log-likelihood @@ -368,16 +388,29 @@ class Bernoulli(Discrete): ---------- p: float Probability of success (0 < p < 1). + logit_p: float + Alternative logit of sucess probability. """ - rv_op = bernoulli + rv_op = bernoulli_logit @classmethod def dist(cls, p=None, logit_p=None, *args, **kwargs): - p = at.as_tensor_variable(floatX(p)) - # mode = at.cast(tround(p), "int8") - return super().dist([p], **kwargs) + logit_p = cls.get_logitp(p=p, logit_p=logit_p) + logit_p = at.as_tensor_variable(floatX(logit_p)) + return super().dist([logit_p], **kwargs) - def logp(value, p): + @classmethod + def get_logitp(cls, p=None, logit_p=None): + if p is not None and logit_p is not None: + raise ValueError("Incompatible parametrization. Can't specify both p and logit_p.") + elif p is None and logit_p is None: + raise ValueError("Incompatible parametrization. Must specify either p or logit_p.") + + if logit_p is None: + logit_p = logit(p) + return logit_p + + def logp(value, logit_p): r""" Calculate log-probability of Bernoulli distribution at specified value. @@ -391,19 +424,15 @@ def logp(value, p): ------- TensorVariable """ - # if self._is_logit: - # lp = at.switch(value, self._logit_p, -self._logit_p) - # return -log1pexp(-lp) - # else: + lp = at.switch(value, -logit_p, logit_p) return bound( - at.switch(value, at.log(p), at.log(1 - p)), - value >= 0, + -log1pexp(lp), + 0 <= value, value <= 1, - p >= 0, - p <= 1, + ~at.isnan(logit_p), ) - def logcdf(value, p): + def logcdf(value, logit_p): """ Compute the log of the cumulative distribution function for Bernoulli distribution at the specified value. @@ -422,12 +451,11 @@ def logcdf(value, p): return bound( at.switch( at.lt(value, 1), - at.log1p(-p), + -log1pexp(logit_p), 0, ), 0 <= value, - 0 <= p, - p <= 1, + ~at.isnan(logit_p), ) def _distr_parameters_for_repr(self): diff --git a/pymc3/tests/test_distributions.py b/pymc3/tests/test_distributions.py index 84ed8d93c3..f20de16368 100644 --- a/pymc3/tests/test_distributions.py +++ b/pymc3/tests/test_distributions.py @@ -1221,7 +1221,6 @@ def scipy_mu_alpha_logcdf(value, mu, alpha): n_samples=10, ) - @pytest.mark.xfail(reason="Distribution not refactored yet") @pytest.mark.parametrize( "mu, p, alpha, n, expected", [ @@ -1522,21 +1521,6 @@ def test_beta_binomial_selfconsistency(self): {"alpha": Rplus, "beta": Rplus, "n": NatSmall}, ) - @pytest.mark.xfail(reason="Bernoulli logit_p not refactored yet") - def test_bernoulli_logit_p(self): - self.check_logp( - Bernoulli, - Bool, - {"logit_p": R}, - lambda value, logit_p: sp.bernoulli.logpmf(value, scipy.special.expit(logit_p)), - ) - self.check_logcdf( - Bernoulli, - Bool, - {"logit_p": R}, - lambda value, logit_p: sp.bernoulli.logcdf(value, scipy.special.expit(logit_p)), - ) - def test_bernoulli(self): self.check_logp( Bernoulli, @@ -1556,6 +1540,32 @@ def test_bernoulli(self): {"p": Unit}, ) + def test_bernoulli_logitp(self): + self.check_logp( + Bernoulli, + Bool, + {"logit_p": R}, + lambda value, logit_p: sp.bernoulli.logpmf(value, scipy.special.expit(logit_p)), + ) + self.check_logcdf( + Bernoulli, + Bool, + {"logit_p": R}, + lambda value, logit_p: sp.bernoulli.logcdf(value, scipy.special.expit(logit_p)), + ) + + @pytest.mark.parametrize( + "p, logit_p, expected", + [ + (None, None, "Must specify either p or logit_p."), + (0.5, 0.5, "Can't specify both p and logit_p."), + ], + ) + def test_bernoulli_init_fail(self, p, logit_p, expected): + with Model(): + with pytest.raises(ValueError, match=f"Incompatible parametrization. {expected}"): + Bernoulli("x", p=p, logit_p=logit_p) + @pytest.mark.xfail(reason="Distribution not refactored yet") def test_discrete_weibull(self): self.check_logp( diff --git a/pymc3/tests/test_distributions_random.py b/pymc3/tests/test_distributions_random.py index 3d677460a0..b5ecb44182 100644 --- a/pymc3/tests/test_distributions_random.py +++ b/pymc3/tests/test_distributions_random.py @@ -731,10 +731,15 @@ def test_beta_binomial(self): def _beta_bin(self, n, alpha, beta, size=None): return st.binom.rvs(n, st.beta.rvs(a=alpha, b=beta, size=size)) - @pytest.mark.skip(reason="This test is covered by Aesara") def test_bernoulli(self): pymc3_random_discrete( - pm.Bernoulli, {"p": Unit}, ref_rand=lambda size, p=None: st.bernoulli.rvs(p, size=size) + pm.Bernoulli, {"p": Unit}, ref_rand=lambda size, p: st.bernoulli.rvs(p, size=size) + ) + + pymc3_random_discrete( + pm.Bernoulli, + {"logit_p": R}, + ref_rand=lambda size, logit_p: st.bernoulli.rvs(expit(logit_p), size=size), ) @pytest.mark.skip(reason="This test is covered by Aesara") diff --git a/pymc3/tests/test_examples.py b/pymc3/tests/test_examples.py index b79f9eaacb..5155150d4b 100644 --- a/pymc3/tests/test_examples.py +++ b/pymc3/tests/test_examples.py @@ -51,7 +51,6 @@ def get_city_data(): return data.merge(unique, "inner", on="fips") -@pytest.mark.xfail(reason="Bernoulli distribution not refactored") class TestARM5_4(SeededTest): def build_model(self): data = pd.read_csv( @@ -68,7 +67,7 @@ def build_model(self): P["1"] = 1 with pm.Model() as model: - effects = pm.Normal("effects", mu=0, sigma=100, shape=len(P.columns)) + effects = pm.Normal("effects", mu=0, sigma=100, size=len(P.columns)) logit_p = at.dot(floatX(np.array(P)), effects) pm.Bernoulli("s", logit_p=logit_p, observed=floatX(data.switch.values)) return model