Skip to content

Commit 825fb19

Browse files
committed
Added moments for gumbel, triangular and logitnormal distributions
1 parent 8d1708a commit 825fb19

File tree

2 files changed

+87
-1
lines changed

2 files changed

+87
-1
lines changed

pymc/distributions/continuous.py

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -86,7 +86,7 @@ def polyagamma_cdf(*args, **kwargs):
8686
)
8787
from pymc.distributions.distribution import Continuous
8888
from pymc.distributions.shape_utils import rv_size_is_none
89-
from pymc.math import logdiffexp, logit
89+
from pymc.math import invlogit, logdiffexp, logit
9090
from pymc.util import UNSET
9191

9292
__all__ = [
@@ -3101,6 +3101,12 @@ def dist(cls, lower=0, upper=1, c=0.5, *args, **kwargs):
31013101

31023102
return super().dist([lower, c, upper], *args, **kwargs)
31033103

3104+
def get_moment(rv, size, lower, c, upper):
3105+
mean = (lower + upper + c) / 3
3106+
if not rv_size_is_none(size):
3107+
mean = at.full(size, mean)
3108+
return mean
3109+
31043110
def logcdf(value, lower, c, upper):
31053111
"""
31063112
Compute the log of the cumulative distribution function for Triangular distribution
@@ -3198,6 +3204,12 @@ def dist(
31983204

31993205
return super().dist([mu, beta], **kwargs)
32003206

3207+
def get_moment(rv, size, mu, beta):
3208+
mean = mu + beta * np.euler_gamma
3209+
if not rv_size_is_none(size):
3210+
mean = at.full(size, mean)
3211+
return mean
3212+
32013213
def _distr_parameters_for_repr(self):
32023214
return ["mu", "beta"]
32033215

@@ -3501,6 +3513,12 @@ def dist(cls, mu=0, sigma=None, tau=None, sd=None, **kwargs):
35013513

35023514
return super().dist([mu, sigma], **kwargs)
35033515

3516+
def get_moment(rv, size, mu, sigma):
3517+
median, _ = at.broadcast_arrays(invlogit(mu), sigma)
3518+
if not rv_size_is_none(size):
3519+
median = at.full(size, median)
3520+
return median
3521+
35043522
def logp(value, mu, sigma):
35053523
"""
35063524
Calculate log-probability of LogitNormal distribution at specified value.

pymc/tests/test_distributions_moments.py

Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
Flat,
1818
Gamma,
1919
Geometric,
20+
Gumbel,
2021
HalfCauchy,
2122
HalfFlat,
2223
HalfNormal,
@@ -25,12 +26,14 @@
2526
Kumaraswamy,
2627
Laplace,
2728
Logistic,
29+
LogitNormal,
2830
LogNormal,
2931
NegativeBinomial,
3032
Normal,
3133
Pareto,
3234
Poisson,
3335
StudentT,
36+
Triangular,
3437
TruncatedNormal,
3538
Uniform,
3639
Wald,
@@ -40,6 +43,7 @@
4043
)
4144
from pymc.distributions.shape_utils import rv_size_is_none
4245
from pymc.initial_point import make_initial_point_fn
46+
from pymc.math import invlogit
4347
from pymc.model import Model
4448

4549

@@ -650,3 +654,67 @@ def test_dirichlet_moment(a, size, expected):
650654
with Model() as model:
651655
Dirichlet("x", a=a, size=size)
652656
assert_moment_is_expected(model, expected)
657+
658+
659+
@pytest.mark.parametrize(
660+
"mu, beta, size, expected",
661+
[
662+
(0, 2, None, 2 * np.euler_gamma),
663+
(1, np.arange(1, 4), None, 1 + np.arange(1, 4) * np.euler_gamma),
664+
(np.arange(5), 2, None, np.arange(5) + 2 * np.euler_gamma),
665+
(1, 2, 5, np.full(5, 1 + 2 * np.euler_gamma)),
666+
(
667+
np.arange(5),
668+
np.arange(1, 6),
669+
(2, 5),
670+
np.full((2, 5), np.arange(5) + np.arange(1, 6) * np.euler_gamma)
671+
)
672+
],
673+
)
674+
def test_gumbel_moment(mu, beta, size, expected):
675+
with Model() as model:
676+
Gumbel("x", mu=mu, beta=beta, size=size)
677+
assert_moment_is_expected(model, expected)
678+
679+
680+
@pytest.mark.parametrize(
681+
"c, lower, upper, size, expected",
682+
[
683+
(1, 0, 5, None, 2),
684+
(3, np.arange(-3, 6, 3), np.arange(3, 12, 3), None, np.array([1, 3, 5])),
685+
(np.arange(-3, 6, 3), -3, 3, None, np.array([-1, 0, 1])),
686+
(3, -3, 6, 5, np.full(5, 2)),
687+
(
688+
np.arange(-3, 6, 3),
689+
np.arange(-9, -2, 3),
690+
np.arange(3, 10, 3),
691+
(2, 3),
692+
np.full((2, 3), np.array([-3, 0, 3]))
693+
)
694+
],
695+
)
696+
def test_triangular_moment(c, lower, upper, size, expected):
697+
with Model() as model:
698+
Triangular("x", c=c, lower=lower, upper=upper, size=size)
699+
assert_moment_is_expected(model, expected)
700+
701+
702+
@pytest.mark.parametrize(
703+
"mu, sigma, size, expected",
704+
[
705+
(1, 2, None, special.expit(1)),
706+
(0, np.arange(1, 5), None, special.expit(np.zeros(4))),
707+
(np.arange(4), 1, None, special.expit(np.arange(4))),
708+
(1, 5, 4, special.expit(np.ones(4))),
709+
(
710+
np.arange(4),
711+
np.arange(1, 5),
712+
(2, 4),
713+
np.full((2, 4), special.expit(np.arange(4)))
714+
)
715+
],
716+
)
717+
def test_logitnormal_moment(mu, sigma, size, expected):
718+
with Model() as model:
719+
LogitNormal("x", mu=mu, sigma=sigma, size=size)
720+
assert_moment_is_expected(model, expected)

0 commit comments

Comments
 (0)