Skip to content

Commit b243827

Browse files
Add test_all_distributions_have_moments
Co-authored-by: Michael Osthege <michael.osthege@outlook.com>
1 parent 0c90e82 commit b243827

File tree

1 file changed

+62
-2
lines changed

1 file changed

+62
-2
lines changed

pymc/tests/test_distributions_moments.py

Lines changed: 62 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,6 @@
88

99
import pymc as pm
1010

11-
from pymc import Simulator
1211
from pymc.distributions import (
1312
AsymmetricLaplace,
1413
Bernoulli,
@@ -50,6 +49,7 @@
5049
Poisson,
5150
PolyaGamma,
5251
Rice,
52+
Simulator,
5353
SkewNormal,
5454
StudentT,
5555
Triangular,
@@ -62,13 +62,73 @@
6262
ZeroInflatedNegativeBinomial,
6363
ZeroInflatedPoisson,
6464
)
65-
from pymc.distributions.distribution import get_moment
65+
from pymc.distributions.distribution import _get_moment, get_moment
6666
from pymc.distributions.multivariate import MvNormal
6767
from pymc.distributions.shape_utils import rv_size_is_none, to_tuple
6868
from pymc.initial_point import make_initial_point_fn
6969
from pymc.model import Model
7070

7171

72+
def test_all_distributions_have_moments():
73+
import pymc.distributions as dist_module
74+
75+
from pymc.distributions.distribution import DistributionMeta
76+
77+
dists = (getattr(dist_module, dist) for dist in dist_module.__all__)
78+
dists = (dist for dist in dists if isinstance(dist, DistributionMeta))
79+
missing_moments = {
80+
dist for dist in dists if type(getattr(dist, "rv_op", None)) not in _get_moment.registry
81+
}
82+
83+
# Ignore super classes
84+
missing_moments -= {
85+
dist_module.Distribution,
86+
dist_module.Discrete,
87+
dist_module.Continuous,
88+
dist_module.NoDistribution,
89+
dist_module.DensityDist,
90+
dist_module.simulator.Simulator,
91+
}
92+
93+
# Distributions that have not been refactored for V4 yet
94+
not_implemented = {
95+
dist_module.multivariate.LKJCorr,
96+
dist_module.mixture.Mixture,
97+
dist_module.mixture.MixtureSameFamily,
98+
dist_module.mixture.NormalMixture,
99+
dist_module.timeseries.AR,
100+
dist_module.timeseries.AR1,
101+
dist_module.timeseries.GARCH11,
102+
dist_module.timeseries.GaussianRandomWalk,
103+
dist_module.timeseries.MvGaussianRandomWalk,
104+
dist_module.timeseries.MvStudentTRandomWalk,
105+
}
106+
107+
# Distributions that have been refactored but don't yet have moments
108+
not_implemented |= {
109+
dist_module.discrete.DiscreteWeibull,
110+
dist_module.multivariate.CAR,
111+
dist_module.multivariate.DirichletMultinomial,
112+
dist_module.multivariate.KroneckerNormal,
113+
dist_module.multivariate.Wishart,
114+
}
115+
116+
unexpected_implemented = not_implemented - missing_moments
117+
if unexpected_implemented:
118+
raise Exception(
119+
f"Distributions {unexpected_implemented} have a `get_moment` implemented. "
120+
"This test must be updated to expect this."
121+
)
122+
123+
unexpected_not_implemented = missing_moments - not_implemented
124+
if unexpected_not_implemented:
125+
raise NotImplementedError(
126+
f"Unexpected by this test, distributions {unexpected_not_implemented} do "
127+
"not have a `get_moment` implementation. Either add a moment or filter "
128+
"these distributions in this test."
129+
)
130+
131+
72132
def test_rv_size_is_none():
73133
rv = Normal.dist(0, 1, size=None)
74134
assert rv_size_is_none(rv.owner.inputs[1])

0 commit comments

Comments
 (0)