Skip to content

Add tests for distributions moments #5087

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Nov 5, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .github/workflows/pytest.yml
Original file line number Diff line number Diff line change
Expand Up @@ -155,6 +155,7 @@ jobs:
- |
pymc/tests/test_initial_point.py
pymc/tests/test_distributions_random.py
pymc/tests/test_distributions_moments.py
pymc/tests/test_distributions_timeseries.py
- |
pymc/tests/test_parallel_sampling.py
Expand Down
73 changes: 49 additions & 24 deletions pymc/distributions/continuous.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,7 @@ def polyagamma_cdf(*args, **kwargs):
zvalue,
)
from pymc.distributions.distribution import Continuous
from pymc.distributions.shape_utils import rv_size_is_none
from pymc.math import logdiffexp, logit
from pymc.util import UNSET

Expand Down Expand Up @@ -290,6 +291,13 @@ def dist(cls, lower=0, upper=1, **kwargs):
upper = at.as_tensor_variable(floatX(upper))
return super().dist([lower, upper], **kwargs)

def get_moment(rv, size, lower, upper):
lower, upper = at.broadcast_arrays(lower, upper)
moment = (lower + upper) / 2
if not rv_size_is_none(size):
moment = at.full(size, moment)
return moment

def logcdf(value, lower, upper):
"""
Compute the log of the cumulative distribution function for Uniform distribution
Expand All @@ -315,11 +323,6 @@ def logcdf(value, lower, upper):
),
)

def get_moment(value, size, lower, upper):
lower = at.full(size, lower, dtype=aesara.config.floatX)
upper = at.full(size, upper, dtype=aesara.config.floatX)
return (lower + upper) / 2


class FlatRV(RandomVariable):
name = "flat"
Expand Down Expand Up @@ -353,8 +356,8 @@ def dist(cls, *, size=None, **kwargs):
res = super().dist([], size=size, **kwargs)
return res

def get_moment(rv, size, *rv_inputs):
return at.zeros(size, dtype=aesara.config.floatX)
def get_moment(rv, size):
return at.zeros(size)

def logp(value):
"""
Expand Down Expand Up @@ -421,8 +424,8 @@ def dist(cls, *, size=None, **kwargs):
res = super().dist([], size=size, **kwargs)
return res

def get_moment(value_var, size, *rv_inputs):
return at.ones(size, dtype=aesara.config.floatX)
def get_moment(rv, size):
return at.ones(size)

def logp(value):
"""
Expand Down Expand Up @@ -540,6 +543,12 @@ def dist(cls, mu=0, sigma=None, tau=None, sd=None, no_assert=False, **kwargs):

return super().dist([mu, sigma], **kwargs)

def get_moment(rv, size, mu, sigma):
mu, _ = at.broadcast_arrays(mu, sigma)
if not rv_size_is_none(size):
mu = at.full(size, mu)
return mu

def logcdf(value, mu, sigma):
"""
Compute the log of the cumulative distribution function for Normal distribution
Expand All @@ -560,9 +569,6 @@ def logcdf(value, mu, sigma):
0 < sigma,
)

def get_moment(value_var, size, mu, sigma):
return at.full(size, mu, dtype=aesara.config.floatX)


class TruncatedNormalRV(RandomVariable):
name = "truncated_normal"
Expand Down Expand Up @@ -691,19 +697,35 @@ def dist(
assert_negative_support(sigma, "sigma", "TruncatedNormal")
assert_negative_support(tau, "tau", "TruncatedNormal")

# if lower is None and upper is None:
# initval = mu
# elif lower is None and upper is not None:
# initval = upper - 1.0
# elif lower is not None and upper is None:
# initval = lower + 1.0
# else:
# initval = (lower + upper) / 2

lower = at.as_tensor_variable(floatX(lower)) if lower is not None else at.constant(-np.inf)
upper = at.as_tensor_variable(floatX(upper)) if upper is not None else at.constant(np.inf)
return super().dist([mu, sigma, lower, upper], **kwargs)

def get_moment(rv, size, mu, sigma, lower, upper):
mu, _, lower, upper = at.broadcast_arrays(mu, sigma, lower, upper)
moment = at.switch(
at.isinf(lower),
at.switch(
at.isinf(upper),
# lower = -inf, upper = inf
mu,
# lower = -inf, upper = x
upper - 1,
),
at.switch(
at.isinf(upper),
# lower = x, upper = inf
lower + 1,
# lower = x, upper = x
(lower + upper) / 2,
),
)

if not rv_size_is_none(size):
moment = at.full(size, moment)

return moment

def logp(
value,
mu: Union[float, np.ndarray, TensorVariable],
Expand Down Expand Up @@ -828,6 +850,12 @@ def dist(cls, sigma=None, tau=None, sd=None, *args, **kwargs):

return super().dist([0.0, sigma], **kwargs)

def get_moment(rv, size, loc, sigma):
moment = loc + sigma
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I have a doubt here.

Current implementation: loc + sigma
Wikipedia version: loc + \sqrt{\frac{2}{\pi}}sigma

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think you are right, we should probably be using that. Do you want to open a PR to fix it?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sure, happily :)

Copy link
Member

@ricardoV94 ricardoV94 Nov 8, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

A bit more involved, we should probably also use the truncated normal mean as the moment. We might even be able to simplify the switch statement, if the normal logcdf behaves well with +- infinity values: https://en.wikipedia.org/wiki/Truncated_normal_distribution

if not rv_size_is_none(size):
moment = at.full(size, moment)
return moment

def logcdf(value, loc, sigma):
"""
Compute the log of the cumulative distribution function for HalfNormal distribution
Expand All @@ -850,9 +878,6 @@ def logcdf(value, loc, sigma):
0 < sigma,
)

def _distr_parameters_for_repr(self):
return ["sigma"]


class WaldRV(RandomVariable):
name = "wald"
Expand Down
13 changes: 6 additions & 7 deletions pymc/distributions/discrete.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@
)
from pymc.distributions.distribution import Discrete
from pymc.distributions.logprob import _logcdf
from pymc.distributions.shape_utils import rv_size_is_none
from pymc.math import sigmoid

__all__ = [
Expand Down Expand Up @@ -352,6 +353,11 @@ def dist(cls, p=None, logit_p=None, *args, **kwargs):
p = at.as_tensor_variable(floatX(p))
return super().dist([p], **kwargs)

def get_moment(rv, size, p):
if not rv_size_is_none(size):
p = at.full(size, p)
return at.switch(p < 0.5, 0, 1)

def logp(value, p):
r"""
Calculate log-probability of Bernoulli distribution at specified value.
Expand Down Expand Up @@ -402,13 +408,6 @@ def logcdf(value, p):
p <= 1,
)

def get_moment(value, size, p):
p = at.full(size, p)
return at.switch(p < 0.5, at.zeros_like(value), at.ones_like(value))

def _distr_parameters_for_repr(self):
return ["p"]


class DiscreteWeibullRV(RandomVariable):
name = "discrete_weibull"
Expand Down
3 changes: 1 addition & 2 deletions pymc/distributions/distribution.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@
from typing import Callable, Optional, Sequence

import aesara
import numpy as np

from aeppl.logprob import _logprob
from aesara.tensor.basic import as_tensor_variable
Expand Down Expand Up @@ -371,7 +370,7 @@ def get_moment(rv: TensorVariable) -> TensorVariable:
for which the value is to be derived.
"""
size = rv.owner.inputs[1]
return _get_moment(rv.owner.op, rv, size, *rv.owner.inputs[3:])
return _get_moment(rv.owner.op, rv, size, *rv.owner.inputs[3:]).astype(rv.dtype)


class Discrete(Distribution):
Expand Down
8 changes: 7 additions & 1 deletion pymc/distributions/shape_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@

import numpy as np

from aesara.graph.basic import Variable
from aesara.graph.basic import Constant, Variable
from aesara.tensor.var import TensorVariable

from pymc.aesaraf import change_rv_size, pandas_to_array
Expand All @@ -37,6 +37,7 @@
"get_broadcastable_dist_samples",
"broadcast_distribution_samples",
"broadcast_dist_samples_to",
"rv_size_is_none",
]


Expand Down Expand Up @@ -674,3 +675,8 @@ def maybe_resize(
)

return rv_out


def rv_size_is_none(size: Variable) -> bool:
"""Check wether an rv size is None (ie., at.Constant([]))"""
return isinstance(size, Constant) and size.data.size == 0
144 changes: 144 additions & 0 deletions pymc/tests/test_distributions_moments.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,144 @@
import numpy as np
import pytest

from pymc import Bernoulli, Flat, HalfFlat, Normal, TruncatedNormal, Uniform
from pymc.distributions import HalfNormal
from pymc.distributions.shape_utils import rv_size_is_none
from pymc.initial_point import make_initial_point_fn
from pymc.model import Model


def test_rv_size_is_none():
rv = Normal.dist(0, 1, size=None)
assert rv_size_is_none(rv.owner.inputs[1])

rv = Normal.dist(0, 1, size=1)
assert not rv_size_is_none(rv.owner.inputs[1])

size = Bernoulli.dist(0.5)
rv = Normal.dist(0, 1, size=size)
assert not rv_size_is_none(rv.owner.inputs[1])

size = Normal.dist(0, 1).size
rv = Normal.dist(0, 1, size=size)
assert not rv_size_is_none(rv.owner.inputs[1])


def assert_moment_is_expected(model, expected):
fn = make_initial_point_fn(
model=model,
return_transformed=False,
default_strategy="moment",
)
result = fn(0)["x"]
expected = np.asarray(expected)
try:
random_draw = model["x"].eval()
except NotImplementedError:
random_draw = result
assert result.shape == expected.shape == random_draw.shape
assert np.allclose(result, expected)


@pytest.mark.parametrize(
"size, expected",
[
(None, 0),
(5, np.zeros(5)),
((2, 5), np.zeros((2, 5))),
],
)
def test_flat_moment(size, expected):
with Model() as model:
Flat("x", size=size)
assert_moment_is_expected(model, expected)


@pytest.mark.parametrize(
"size, expected",
[
(None, 1),
(5, np.ones(5)),
((2, 5), np.ones((2, 5))),
],
)
def test_halfflat_moment(size, expected):
with Model() as model:
HalfFlat("x", size=size)
assert_moment_is_expected(model, expected)


@pytest.mark.parametrize(
"lower, upper, size, expected",
[
(-1, 1, None, 0),
(-1, 1, 5, np.zeros(5)),
(0, np.arange(1, 6), None, np.arange(1, 6) / 2),
(0, np.arange(1, 6), (2, 5), np.full((2, 5), np.arange(1, 6) / 2)),
],
)
def test_uniform_moment(lower, upper, size, expected):
with Model() as model:
Uniform("x", lower=lower, upper=upper, size=size)
assert_moment_is_expected(model, expected)


@pytest.mark.parametrize(
"mu, sigma, size, expected",
[
(0, 1, None, 0),
(0, np.ones(5), None, np.zeros(5)),
(np.arange(5), 1, None, np.arange(5)),
(np.arange(5), np.arange(1, 6), (2, 5), np.full((2, 5), np.arange(5))),
],
)
def test_normal_moment(mu, sigma, size, expected):
with Model() as model:
Normal("x", mu=mu, sigma=sigma, size=size)
assert_moment_is_expected(model, expected)


@pytest.mark.parametrize(
"sigma, size, expected",
[
(1, None, 1),
(1, 5, np.ones(5)),
(np.arange(5), None, np.arange(5)),
(np.arange(5), (2, 5), np.full((2, 5), np.arange(5))),
],
)
def test_halfnormal_moment(sigma, size, expected):
with Model() as model:
HalfNormal("x", sigma=sigma, size=size)
assert_moment_is_expected(model, expected)


@pytest.mark.skip(reason="aeppl interval transform fails when both edges are None")
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@pytest.mark.parametrize(
"mu, sigma, lower, upper, size, expected",
[
(0.9, 1, -1, 1, None, 0),
(0.9, 1, -np.inf, np.inf, 5, np.full(5, 0.9)),
(np.arange(5), 1, None, 10, (2, 5), np.full((2, 5), 9)),
(1, np.ones(5), -10, np.inf, None, np.full((2, 5), -9)),
],
)
def test_truncatednormal_moment(mu, sigma, lower, upper, size, expected):
with Model() as model:
TruncatedNormal("x", mu=mu, sigma=sigma, lower=lower, upper=upper, size=size)
assert_moment_is_expected(model, expected)


@pytest.mark.parametrize(
"p, size, expected",
[
(0.3, None, 0),
(0.9, 5, np.ones(5)),
(np.linspace(0, 1, 4), None, [0, 0, 1, 1]),
(np.linspace(0, 1, 4), (2, 4), np.full((2, 4), [0, 0, 1, 1])),
],
)
def test_bernoulli_moment(p, size, expected):
with Model() as model:
Bernoulli("x", p=p, size=size)
assert_moment_is_expected(model, expected)