diff --git a/.github/workflows/pytest.yml b/.github/workflows/pytest.yml index cad0548b6e..36a74bbdc8 100644 --- a/.github/workflows/pytest.yml +++ b/.github/workflows/pytest.yml @@ -155,6 +155,7 @@ jobs: - | pymc/tests/test_initial_point.py pymc/tests/test_distributions_random.py + pymc/tests/test_distributions_moments.py pymc/tests/test_distributions_timeseries.py - | pymc/tests/test_parallel_sampling.py diff --git a/pymc/distributions/continuous.py b/pymc/distributions/continuous.py index b1d494bec3..ac1598d759 100644 --- a/pymc/distributions/continuous.py +++ b/pymc/distributions/continuous.py @@ -85,6 +85,7 @@ def polyagamma_cdf(*args, **kwargs): zvalue, ) from pymc.distributions.distribution import Continuous +from pymc.distributions.shape_utils import rv_size_is_none from pymc.math import logdiffexp, logit from pymc.util import UNSET @@ -290,6 +291,13 @@ def dist(cls, lower=0, upper=1, **kwargs): upper = at.as_tensor_variable(floatX(upper)) return super().dist([lower, upper], **kwargs) + def get_moment(rv, size, lower, upper): + lower, upper = at.broadcast_arrays(lower, upper) + moment = (lower + upper) / 2 + if not rv_size_is_none(size): + moment = at.full(size, moment) + return moment + def logcdf(value, lower, upper): """ Compute the log of the cumulative distribution function for Uniform distribution @@ -315,11 +323,6 @@ def logcdf(value, lower, upper): ), ) - def get_moment(value, size, lower, upper): - lower = at.full(size, lower, dtype=aesara.config.floatX) - upper = at.full(size, upper, dtype=aesara.config.floatX) - return (lower + upper) / 2 - class FlatRV(RandomVariable): name = "flat" @@ -353,8 +356,8 @@ def dist(cls, *, size=None, **kwargs): res = super().dist([], size=size, **kwargs) return res - def get_moment(rv, size, *rv_inputs): - return at.zeros(size, dtype=aesara.config.floatX) + def get_moment(rv, size): + return at.zeros(size) def logp(value): """ @@ -421,8 +424,8 @@ def dist(cls, *, size=None, **kwargs): res = super().dist([], 
size=size, **kwargs) return res - def get_moment(value_var, size, *rv_inputs): - return at.ones(size, dtype=aesara.config.floatX) + def get_moment(rv, size): + return at.ones(size) def logp(value): """ @@ -540,6 +543,12 @@ def dist(cls, mu=0, sigma=None, tau=None, sd=None, no_assert=False, **kwargs): return super().dist([mu, sigma], **kwargs) + def get_moment(rv, size, mu, sigma): + mu, _ = at.broadcast_arrays(mu, sigma) + if not rv_size_is_none(size): + mu = at.full(size, mu) + return mu + def logcdf(value, mu, sigma): """ Compute the log of the cumulative distribution function for Normal distribution @@ -560,9 +569,6 @@ def logcdf(value, mu, sigma): 0 < sigma, ) - def get_moment(value_var, size, mu, sigma): - return at.full(size, mu, dtype=aesara.config.floatX) - class TruncatedNormalRV(RandomVariable): name = "truncated_normal" @@ -691,19 +697,35 @@ def dist( assert_negative_support(sigma, "sigma", "TruncatedNormal") assert_negative_support(tau, "tau", "TruncatedNormal") - # if lower is None and upper is None: - # initval = mu - # elif lower is None and upper is not None: - # initval = upper - 1.0 - # elif lower is not None and upper is None: - # initval = lower + 1.0 - # else: - # initval = (lower + upper) / 2 - lower = at.as_tensor_variable(floatX(lower)) if lower is not None else at.constant(-np.inf) upper = at.as_tensor_variable(floatX(upper)) if upper is not None else at.constant(np.inf) return super().dist([mu, sigma, lower, upper], **kwargs) + def get_moment(rv, size, mu, sigma, lower, upper): + mu, _, lower, upper = at.broadcast_arrays(mu, sigma, lower, upper) + moment = at.switch( + at.isinf(lower), + at.switch( + at.isinf(upper), + # lower = -inf, upper = inf + mu, + # lower = -inf, upper = x + upper - 1, + ), + at.switch( + at.isinf(upper), + # lower = x, upper = inf + lower + 1, + # lower = x, upper = x + (lower + upper) / 2, + ), + ) + + if not rv_size_is_none(size): + moment = at.full(size, moment) + + return moment + def logp( value, mu: 
Union[float, np.ndarray, TensorVariable], @@ -828,6 +850,12 @@ def dist(cls, sigma=None, tau=None, sd=None, *args, **kwargs): return super().dist([0.0, sigma], **kwargs) + def get_moment(rv, size, loc, sigma): + moment = loc + sigma + if not rv_size_is_none(size): + moment = at.full(size, moment) + return moment + def logcdf(value, loc, sigma): """ Compute the log of the cumulative distribution function for HalfNormal distribution @@ -850,9 +878,6 @@ def logcdf(value, loc, sigma): 0 < sigma, ) - def _distr_parameters_for_repr(self): - return ["sigma"] - class WaldRV(RandomVariable): name = "wald" diff --git a/pymc/distributions/discrete.py b/pymc/distributions/discrete.py index 0fe2296da4..08b1a8afed 100644 --- a/pymc/distributions/discrete.py +++ b/pymc/distributions/discrete.py @@ -43,6 +43,7 @@ ) from pymc.distributions.distribution import Discrete from pymc.distributions.logprob import _logcdf +from pymc.distributions.shape_utils import rv_size_is_none from pymc.math import sigmoid __all__ = [ @@ -352,6 +353,11 @@ def dist(cls, p=None, logit_p=None, *args, **kwargs): p = at.as_tensor_variable(floatX(p)) return super().dist([p], **kwargs) + def get_moment(rv, size, p): + if not rv_size_is_none(size): + p = at.full(size, p) + return at.switch(p < 0.5, 0, 1) + def logp(value, p): r""" Calculate log-probability of Bernoulli distribution at specified value. 
@@ -402,13 +408,6 @@ def logcdf(value, p): p <= 1, ) - def get_moment(value, size, p): - p = at.full(size, p) - return at.switch(p < 0.5, at.zeros_like(value), at.ones_like(value)) - - def _distr_parameters_for_repr(self): - return ["p"] - class DiscreteWeibullRV(RandomVariable): name = "discrete_weibull" diff --git a/pymc/distributions/distribution.py b/pymc/distributions/distribution.py index 449cc61da7..4d03de1587 100644 --- a/pymc/distributions/distribution.py +++ b/pymc/distributions/distribution.py @@ -22,7 +22,6 @@ from typing import Callable, Optional, Sequence import aesara -import numpy as np from aeppl.logprob import _logprob from aesara.tensor.basic import as_tensor_variable @@ -371,7 +370,7 @@ def get_moment(rv: TensorVariable) -> TensorVariable: for which the value is to be derived. """ size = rv.owner.inputs[1] - return _get_moment(rv.owner.op, rv, size, *rv.owner.inputs[3:]) + return _get_moment(rv.owner.op, rv, size, *rv.owner.inputs[3:]).astype(rv.dtype) class Discrete(Distribution): diff --git a/pymc/distributions/shape_utils.py b/pymc/distributions/shape_utils.py index c4b0cdf45d..f74b61e7a0 100644 --- a/pymc/distributions/shape_utils.py +++ b/pymc/distributions/shape_utils.py @@ -24,7 +24,7 @@ import numpy as np -from aesara.graph.basic import Variable +from aesara.graph.basic import Constant, Variable from aesara.tensor.var import TensorVariable from pymc.aesaraf import change_rv_size, pandas_to_array @@ -37,6 +37,7 @@ "get_broadcastable_dist_samples", "broadcast_distribution_samples", "broadcast_dist_samples_to", + "rv_size_is_none", ] @@ -674,3 +675,8 @@ def maybe_resize( ) return rv_out + + +def rv_size_is_none(size: Variable) -> bool: + """Check whether an rv size is None (i.e., at.Constant([]))""" + return isinstance(size, Constant) and size.data.size == 0 diff --git a/pymc/tests/test_distributions_moments.py b/pymc/tests/test_distributions_moments.py new file mode 100644 index 0000000000..0648f4b8fe --- /dev/null +++ 
b/pymc/tests/test_distributions_moments.py @@ -0,0 +1,144 @@ +import numpy as np +import pytest + +from pymc import Bernoulli, Flat, HalfFlat, Normal, TruncatedNormal, Uniform +from pymc.distributions import HalfNormal +from pymc.distributions.shape_utils import rv_size_is_none +from pymc.initial_point import make_initial_point_fn +from pymc.model import Model + + +def test_rv_size_is_none(): + rv = Normal.dist(0, 1, size=None) + assert rv_size_is_none(rv.owner.inputs[1]) + + rv = Normal.dist(0, 1, size=1) + assert not rv_size_is_none(rv.owner.inputs[1]) + + size = Bernoulli.dist(0.5) + rv = Normal.dist(0, 1, size=size) + assert not rv_size_is_none(rv.owner.inputs[1]) + + size = Normal.dist(0, 1).size + rv = Normal.dist(0, 1, size=size) + assert not rv_size_is_none(rv.owner.inputs[1]) + + +def assert_moment_is_expected(model, expected): + fn = make_initial_point_fn( + model=model, + return_transformed=False, + default_strategy="moment", + ) + result = fn(0)["x"] + expected = np.asarray(expected) + try: + random_draw = model["x"].eval() + except NotImplementedError: + random_draw = result + assert result.shape == expected.shape == random_draw.shape + assert np.allclose(result, expected) + + +@pytest.mark.parametrize( + "size, expected", + [ + (None, 0), + (5, np.zeros(5)), + ((2, 5), np.zeros((2, 5))), + ], +) +def test_flat_moment(size, expected): + with Model() as model: + Flat("x", size=size) + assert_moment_is_expected(model, expected) + + +@pytest.mark.parametrize( + "size, expected", + [ + (None, 1), + (5, np.ones(5)), + ((2, 5), np.ones((2, 5))), + ], +) +def test_halfflat_moment(size, expected): + with Model() as model: + HalfFlat("x", size=size) + assert_moment_is_expected(model, expected) + + +@pytest.mark.parametrize( + "lower, upper, size, expected", + [ + (-1, 1, None, 0), + (-1, 1, 5, np.zeros(5)), + (0, np.arange(1, 6), None, np.arange(1, 6) / 2), + (0, np.arange(1, 6), (2, 5), np.full((2, 5), np.arange(1, 6) / 2)), + ], +) +def 
test_uniform_moment(lower, upper, size, expected): + with Model() as model: + Uniform("x", lower=lower, upper=upper, size=size) + assert_moment_is_expected(model, expected) + + +@pytest.mark.parametrize( + "mu, sigma, size, expected", + [ + (0, 1, None, 0), + (0, np.ones(5), None, np.zeros(5)), + (np.arange(5), 1, None, np.arange(5)), + (np.arange(5), np.arange(1, 6), (2, 5), np.full((2, 5), np.arange(5))), + ], +) +def test_normal_moment(mu, sigma, size, expected): + with Model() as model: + Normal("x", mu=mu, sigma=sigma, size=size) + assert_moment_is_expected(model, expected) + + +@pytest.mark.parametrize( + "sigma, size, expected", + [ + (1, None, 1), + (1, 5, np.ones(5)), + (np.arange(5), None, np.arange(5)), + (np.arange(5), (2, 5), np.full((2, 5), np.arange(5))), + ], +) +def test_halfnormal_moment(sigma, size, expected): + with Model() as model: + HalfNormal("x", sigma=sigma, size=size) + assert_moment_is_expected(model, expected) + + +@pytest.mark.skip(reason="aeppl interval transform fails when both edges are None") +@pytest.mark.parametrize( + "mu, sigma, lower, upper, size, expected", + [ + (0.9, 1, -1, 1, None, 0), + (0.9, 1, -np.inf, np.inf, 5, np.full(5, 0.9)), + (np.arange(5), 1, None, 10, (2, 5), np.full((2, 5), 9)), + (1, np.ones(5), -10, np.inf, None, np.full(5, -9)), + ], +) +def test_truncatednormal_moment(mu, sigma, lower, upper, size, expected): + with Model() as model: + TruncatedNormal("x", mu=mu, sigma=sigma, lower=lower, upper=upper, size=size) + assert_moment_is_expected(model, expected) + + +@pytest.mark.parametrize( + "p, size, expected", + [ + (0.3, None, 0), + (0.9, 5, np.ones(5)), + (np.linspace(0, 1, 4), None, [0, 0, 1, 1]), + (np.linspace(0, 1, 4), (2, 4), np.full((2, 4), [0, 0, 1, 1])), + ], +) +def test_bernoulli_moment(p, size, expected): + with Model() as model: + Bernoulli("x", p=p, size=size) + assert_moment_is_expected(model, expected)