|
| 1 | +import aesara |
1 | 2 | import numpy as np
|
2 | 3 | import pytest
|
3 | 4 |
|
4 | 5 | from aesara import tensor as at
|
5 | 6 | from scipy import special
|
6 | 7 |
|
| 8 | +import pymc as pm |
| 9 | + |
7 | 10 | from pymc.distributions import (
|
8 | 11 | AsymmetricLaplace,
|
9 | 12 | Bernoulli,
|
|
49 | 52 | ZeroInflatedBinomial,
|
50 | 53 | ZeroInflatedPoisson,
|
51 | 54 | )
|
| 55 | +from pymc.distributions.distribution import get_moment |
52 | 56 | from pymc.distributions.multivariate import MvNormal
|
53 |
| -from pymc.distributions.shape_utils import rv_size_is_none |
| 57 | +from pymc.distributions.shape_utils import rv_size_is_none, to_tuple |
54 | 58 | from pymc.initial_point import make_initial_point_fn
|
55 | 59 | from pymc.model import Model
|
56 | 60 |
|
@@ -911,9 +915,72 @@ def test_rice_moment(nu, sigma, size, expected):
|
911 | 915 | ("custom_moment", (2, 5), np.full((2, 5), 5)),
|
912 | 916 | ],
|
913 | 917 | )
|
914 |
| -def test_density_dist_moment(get_moment, size, expected): |
| 918 | +def test_density_dist_default_moment_univariate(get_moment, size, expected): |
915 | 919 | if get_moment == "custom_moment":
|
916 | 920 | get_moment = lambda rv, size, *rv_inputs: 5 * at.ones(size, dtype=rv.dtype)
|
917 | 921 | with Model() as model:
|
918 | 922 | DensityDist("x", get_moment=get_moment, size=size)
|
919 | 923 | assert_moment_is_expected(model, expected)
|
| 924 | + |
| 925 | + |
| 926 | +@pytest.mark.parametrize("size", [(), (2,), (3, 2)], ids=str) |
| 927 | +def test_density_dist_custom_moment_univariate(size): |
| 928 | + def moment(rv, size, mu): |
| 929 | + return (at.ones(size) * mu).astype(rv.dtype) |
| 930 | + |
| 931 | + mu_val = np.array(np.random.normal(loc=2, scale=1)).astype(aesara.config.floatX) |
| 932 | + with pm.Model(): |
| 933 | + mu = pm.Normal("mu") |
| 934 | + a = pm.DensityDist("a", mu, get_moment=moment, size=size) |
| 935 | + evaled_moment = get_moment(a).eval({mu: mu_val}) |
| 936 | + assert evaled_moment.shape == to_tuple(size) |
| 937 | + assert np.all(evaled_moment == mu_val) |
| 938 | + |
| 939 | + |
| 940 | +@pytest.mark.parametrize("size", [(), (2,), (3, 2)], ids=str) |
| 941 | +def test_density_dist_custom_moment_multivariate(size): |
| 942 | + def moment(rv, size, mu): |
| 943 | + return (at.ones(size)[..., None] * mu).astype(rv.dtype) |
| 944 | + |
| 945 | + mu_val = np.random.normal(loc=2, scale=1, size=5).astype(aesara.config.floatX) |
| 946 | + with pm.Model(): |
| 947 | + mu = pm.Normal("mu", size=5) |
| 948 | + a = pm.DensityDist("a", mu, get_moment=moment, ndims_params=[1], ndim_supp=1, size=size) |
| 949 | + evaled_moment = get_moment(a).eval({mu: mu_val}) |
| 950 | + assert evaled_moment.shape == to_tuple(size) + (5,) |
| 951 | + assert np.all(evaled_moment == mu_val) |
| 952 | + |
| 953 | + |
| 954 | +@pytest.mark.parametrize( |
| 955 | + "with_random, size", |
| 956 | + [ |
| 957 | + (True, ()), |
| 958 | + (True, (2,)), |
| 959 | + (True, (3, 2)), |
| 960 | + (False, ()), |
| 961 | + (False, (2,)), |
| 962 | + ], |
| 963 | +) |
| 964 | +def test_density_dist_default_moment_multivariate(with_random, size): |
| 965 | + def _random(mu, rng=None, size=None): |
| 966 | + return rng.normal(mu, scale=1, size=to_tuple(size) + mu.shape) |
| 967 | + |
| 968 | + if with_random: |
| 969 | + random = _random |
| 970 | + else: |
| 971 | + random = None |
| 972 | + |
| 973 | + mu_val = np.random.normal(loc=2, scale=1, size=5).astype(aesara.config.floatX) |
| 974 | + with pm.Model(): |
| 975 | + mu = pm.Normal("mu", size=5) |
| 976 | + a = pm.DensityDist("a", mu, random=random, ndims_params=[1], ndim_supp=1, size=size) |
| 977 | + if with_random: |
| 978 | + evaled_moment = get_moment(a).eval({mu: mu_val}) |
| 979 | + assert evaled_moment.shape == to_tuple(size) + (5,) |
| 980 | + assert np.all(evaled_moment == 0) |
| 981 | + else: |
| 982 | + with pytest.raises( |
| 983 | + TypeError, |
| 984 | + match="Cannot safely infer the size of a multivariate random variable's moment.", |
| 985 | + ): |
| 986 | + evaled_moment = get_moment(a).eval({mu: mu_val}) |
0 commit comments