Skip to content

Commit f36c433

Browse files
lucianopazricardoV94
authored andcommitted
Add DensityDist moment
1 parent c0c5a80 commit f36c433

File tree

2 files changed

+27
-3
lines changed

2 files changed

+27
-3
lines changed

pymc/distributions/distribution.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
import aesara
2525

2626
from aeppl.logprob import _logcdf, _logprob
27+
from aesara import tensor as at
2728
from aesara.tensor.basic import as_tensor_variable
2829
from aesara.tensor.random.op import RandomVariable
2930
from aesara.tensor.random.var import RandomStateSharedVariable
@@ -472,9 +473,9 @@ def __new__(
472473
as the first argument ``rv``. ``size`` is the random variable's size implied
473474
by the ``dims``, ``size`` and parameters supplied to the distribution. Finally,
474475
``rv_inputs`` is the sequence of the distribution parameters, in the same order
475-
as they were supplied when the DensityDist was created. If ``None``, a
476-
``NotImplemented`` error will be raised when trying to draw random samples from
477-
the distribution's prior or posterior predictive.
476+
as they were supplied when the DensityDist was created. If ``None``, a default
477+
``get_moment`` function will be assigned that will always return 0, or an array
478+
of zeros.
478479
ndim_supp : int
479480
The number of dimensions in the support of the distribution. Defaults to assuming
480481
a scalar distribution, i.e. ``ndim_supp = 0``.
@@ -614,3 +615,7 @@ def func(*args, **kwargs):
614615
raise NotImplementedError(message)
615616

616617
return func
618+
619+
620+
def default_get_moment(rv, size, *rv_inputs):
621+
return at.zeros(size, dtype=rv.dtype)

pymc/tests/test_distributions_moments.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import numpy as np
22
import pytest
33

4+
from aesara import tensor as at
45
from scipy import special
56

67
from pymc.distributions import (
@@ -13,6 +14,7 @@
1314
Cauchy,
1415
ChiSquared,
1516
Constant,
17+
DensityDist,
1618
Dirichlet,
1719
DiscreteUniform,
1820
ExGaussian,
@@ -898,3 +900,20 @@ def test_rice_moment(nu, sigma, size, expected):
898900
with Model() as model:
899901
Rice("x", nu=nu, sigma=sigma, size=size)
900902
assert_moment_is_expected(model, expected)
903+
904+
905+
@pytest.mark.parametrize(
906+
"get_moment, size, expected",
907+
[
908+
(None, None, 0.0),
909+
(None, 5, np.zeros(5)),
910+
("custom_moment", None, 5),
911+
("custom_moment", (2, 5), np.full((2, 5), 5)),
912+
],
913+
)
914+
def test_density_dist_moment(get_moment, size, expected):
915+
if get_moment == "custom_moment":
916+
get_moment = lambda rv, size, *rv_inputs: 5 * at.ones(size, dtype=rv.dtype)
917+
with Model() as model:
918+
DensityDist("x", get_moment=get_moment, size=size)
919+
assert_moment_is_expected(model, expected)

0 commit comments

Comments
 (0)