From f336aa768d24ccc7c3728dd2f9267d6c1bd808fb Mon Sep 17 00:00:00 2001 From: kc611 Date: Sun, 17 Jan 2021 22:55:27 +0530 Subject: [PATCH] Fixed Dirichlet.random returning output in wrong shapes --- RELEASE-NOTES.md | 1 + pymc3/distributions/multivariate.py | 27 ++++-------------------- pymc3/tests/test_distributions_random.py | 18 ++++++++++++++++ 3 files changed, 23 insertions(+), 23 deletions(-) diff --git a/RELEASE-NOTES.md b/RELEASE-NOTES.md index 2ccf1d04af..a9fb04990a 100644 --- a/RELEASE-NOTES.md +++ b/RELEASE-NOTES.md @@ -35,6 +35,7 @@ It also brings some dreadfully awaited fixes, so be sure to go through the chang - Update the `logcdf` method of several continuous distributions to return -inf for invalid parameters and values, and raise an informative error when multiple values cannot be evaluated in a single call. (see [4393](https://github.com/pymc-devs/pymc3/pull/4393)) - Improve numerical stability in `logp` and `logcdf` methods of `ExGaussian` (see [#4407](https://github.com/pymc-devs/pymc3/pull/4407)) - Issue UserWarning when doing prior or posterior predictive sampling with models containing Potential factors (see [#4419](https://github.com/pymc-devs/pymc3/pull/4419)) +- Dirichlet distribution's `random` method is now optimized and gives outputs in correct shape (see [#4416](https://github.com/pymc-devs/pymc3/pull/4407)) ## PyMC3 3.10.0 (7 December 2020) diff --git a/pymc3/distributions/multivariate.py b/pymc3/distributions/multivariate.py index e5c73179d9..a136387912 100755 --- a/pymc3/distributions/multivariate.py +++ b/pymc3/distributions/multivariate.py @@ -471,33 +471,11 @@ def __init__(self, a, transform=transforms.stick_breaking, *args, **kwargs): super().__init__(transform=transform, *args, **kwargs) - self.size_prefix = tuple(self.shape[:-1]) self.a = a = tt.as_tensor_variable(a) self.mean = a / tt.sum(a) self.mode = tt.switch(tt.all(a > 1), (a - 1) / tt.sum(a - 1), np.nan) - def _random(self, a, size=None): - gen = stats.dirichlet.rvs - shape = tuple(np.atleast_1d(self.shape)) - if size[-len(shape) :] == shape: - real_size = size[: -len(shape)] - else: - real_size = size - if self.size_prefix: - if real_size and real_size[0] == 1: - real_size = real_size[1:] + self.size_prefix - else: - real_size = real_size + self.size_prefix - - if a.ndim == 1: - samples = gen(alpha=a, size=real_size) - else: - unrolled = a.reshape((np.prod(a.shape[:-1]), a.shape[-1])) - samples = np.array([gen(alpha=aa, size=1) for aa in unrolled]) - samples = samples.reshape(a.shape) - return samples - def random(self, point=None, size=None): """ Draw random values from Dirichlet distribution. @@ -516,7 +494,10 @@ def random(self, point=None, size=None): array """ a = draw_values([self.a], point=point, size=size)[0] - samples = generate_samples(self._random, a=a, dist_shape=self.shape, size=size) + output_shape = to_tuple(size) + to_tuple(self.shape) + a = broadcast_dist_samples_to(to_shape=output_shape, samples=[a], size=size)[0] + samples = stats.gamma.rvs(a=a, size=output_shape) + samples = samples / samples.sum(-1, keepdims=True) return samples def logp(self, value): diff --git a/pymc3/tests/test_distributions_random.py b/pymc3/tests/test_distributions_random.py index 2513cd3029..a56f3f3b7b 100644 --- a/pymc3/tests/test_distributions_random.py +++ b/pymc3/tests/test_distributions_random.py @@ -542,6 +542,24 @@ def test_probability_vector_shape(self): assert pm.Categorical.dist(p=p).random(size=4).shape == (4, 3, 7) +class TestDirichlet(SeededTest): + @pytest.mark.parametrize( + "shape, size", + [ + ((2), (1)), + ((2), (2)), + ((2, 2), (2, 100)), + ((3, 4), (3, 4)), + ((3, 4), (3, 4, 100)), + ((3, 4), (100)), + ((3, 4), (1)), + ], + ) + def test_dirichlet_random_shape(self, shape, size): + out_shape = to_tuple(size) + to_tuple(shape) + assert pm.Dirichlet.dist(a=np.ones(shape)).random(size=size).shape == out_shape + + class TestScalarParameterSamples(SeededTest): def test_bounded(self): # A bit crude...