Skip to content

Commit 45e33bc

Browse files
lucianopazricardoV94
authored andcommitted
Specialize get_moment for multivariate density dists
1 parent f36c433 commit 45e33bc

File tree

3 files changed

+88
-45
lines changed

3 files changed

+88
-45
lines changed

pymc/distributions/distribution.py

Lines changed: 19 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -551,12 +551,17 @@ def random(mu, rng=None, size=None):
551551
if logcdf is None:
552552
logcdf = default_not_implemented(name, "logcdf")
553553

554+
if get_moment is None:
555+
get_moment = functools.partial(
556+
default_get_moment,
557+
rv_name=name,
558+
has_fallback=random is not None,
559+
ndim_supp=ndim_supp,
560+
)
561+
554562
if random is None:
555563
random = default_not_implemented(name, "random")
556564

557-
if get_moment is None:
558-
get_moment = default_not_implemented(name, "get_moment")
559-
560565
rv_op = type(
561566
f"DensityDist_{name}",
562567
(DensityDistRV,),
@@ -617,5 +622,14 @@ def func(*args, **kwargs):
617622
return func
618623

619624

620-
def default_get_moment(rv, size, *rv_inputs):
621-
return at.zeros(size, dtype=rv.dtype)
625+
def default_get_moment(rv, size, *rv_inputs, rv_name=None, has_fallback=False, ndim_supp=0):
626+
if ndim_supp == 0:
627+
return at.zeros(size, dtype=rv.dtype)
628+
elif has_fallback:
629+
return at.zeros_like(rv)
630+
else:
631+
raise TypeError(
632+
"Cannot safely infer the size of a multivariate random variable's moment. "
633+
f"Please provide a get_moment function when instantiating the {rv_name} "
634+
"random variable."
635+
)

pymc/tests/test_distributions_moments.py

Lines changed: 69 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,12 @@
1+
import aesara
12
import numpy as np
23
import pytest
34

45
from aesara import tensor as at
56
from scipy import special
67

8+
import pymc as pm
9+
710
from pymc.distributions import (
811
AsymmetricLaplace,
912
Bernoulli,
@@ -49,8 +52,9 @@
4952
ZeroInflatedBinomial,
5053
ZeroInflatedPoisson,
5154
)
55+
from pymc.distributions.distribution import get_moment
5256
from pymc.distributions.multivariate import MvNormal
53-
from pymc.distributions.shape_utils import rv_size_is_none
57+
from pymc.distributions.shape_utils import rv_size_is_none, to_tuple
5458
from pymc.initial_point import make_initial_point_fn
5559
from pymc.model import Model
5660

@@ -911,9 +915,72 @@ def test_rice_moment(nu, sigma, size, expected):
911915
("custom_moment", (2, 5), np.full((2, 5), 5)),
912916
],
913917
)
914-
def test_density_dist_moment(get_moment, size, expected):
918+
def test_density_dist_default_moment_univariate(get_moment, size, expected):
915919
if get_moment == "custom_moment":
916920
get_moment = lambda rv, size, *rv_inputs: 5 * at.ones(size, dtype=rv.dtype)
917921
with Model() as model:
918922
DensityDist("x", get_moment=get_moment, size=size)
919923
assert_moment_is_expected(model, expected)
924+
925+
926+
@pytest.mark.parametrize("size", [(), (2,), (3, 2)], ids=str)
927+
def test_density_dist_custom_moment_univariate(size):
928+
def moment(rv, size, mu):
929+
return (at.ones(size) * mu).astype(rv.dtype)
930+
931+
mu_val = np.array(np.random.normal(loc=2, scale=1)).astype(aesara.config.floatX)
932+
with pm.Model():
933+
mu = pm.Normal("mu")
934+
a = pm.DensityDist("a", mu, get_moment=moment, size=size)
935+
evaled_moment = get_moment(a).eval({mu: mu_val})
936+
assert evaled_moment.shape == to_tuple(size)
937+
assert np.all(evaled_moment == mu_val)
938+
939+
940+
@pytest.mark.parametrize("size", [(), (2,), (3, 2)], ids=str)
941+
def test_density_dist_custom_moment_multivariate(size):
942+
def moment(rv, size, mu):
943+
return (at.ones(size)[..., None] * mu).astype(rv.dtype)
944+
945+
mu_val = np.random.normal(loc=2, scale=1, size=5).astype(aesara.config.floatX)
946+
with pm.Model():
947+
mu = pm.Normal("mu", size=5)
948+
a = pm.DensityDist("a", mu, get_moment=moment, ndims_params=[1], ndim_supp=1, size=size)
949+
evaled_moment = get_moment(a).eval({mu: mu_val})
950+
assert evaled_moment.shape == to_tuple(size) + (5,)
951+
assert np.all(evaled_moment == mu_val)
952+
953+
954+
@pytest.mark.parametrize(
955+
"with_random, size",
956+
[
957+
(True, ()),
958+
(True, (2,)),
959+
(True, (3, 2)),
960+
(False, ()),
961+
(False, (2,)),
962+
],
963+
)
964+
def test_density_dist_default_moment_multivariate(with_random, size):
965+
def _random(mu, rng=None, size=None):
966+
return rng.normal(mu, scale=1, size=to_tuple(size) + mu.shape)
967+
968+
if with_random:
969+
random = _random
970+
else:
971+
random = None
972+
973+
mu_val = np.random.normal(loc=2, scale=1, size=5).astype(aesara.config.floatX)
974+
with pm.Model():
975+
mu = pm.Normal("mu", size=5)
976+
a = pm.DensityDist("a", mu, random=random, ndims_params=[1], ndim_supp=1, size=size)
977+
if with_random:
978+
evaled_moment = get_moment(a).eval({mu: mu_val})
979+
assert evaled_moment.shape == to_tuple(size) + (5,)
980+
assert np.all(evaled_moment == 0)
981+
else:
982+
with pytest.raises(
983+
TypeError,
984+
match="Cannot safely infer the size of a multivariate random variable's moment.",
985+
):
986+
evaled_moment = get_moment(a).eval({mu: mu_val})

pymc/tests/test_moment.py

Lines changed: 0 additions & 38 deletions
This file was deleted.

0 commit comments

Comments
 (0)