Update distribution guide #6166

Merged
merged 2 commits into from
Sep 30, 2022
@@ -1,12 +1,12 @@
(implementing-a-distribution)=
# Implementing a Distribution
# Implementing a RandomVariable Distribution

This guide provides an overview of how to implement a distribution for PyMC version `>=4.0.0`.
It is designed for developers who wish to add a new distribution to the library.
Users will not be aware of all this complexity and should instead make use of helper methods such as `~pymc.DensityDist`.

PyMC {class}`~pymc.Distribution` builds on top of Aesara's {class}`~aesara.tensor.random.op.RandomVariable`, and implements `logp`, `logcdf` and `moment` methods as well as other initialization and validation helpers.
Most notably `shape/dims` kwargs, alternative parametrizations, and default `transforms`.
Most notably `shape/dims/observed` kwargs, alternative parametrizations, and default `transform`.

Here is a summary check-list of the steps needed to implement a new distribution.
Each section will be expanded below:
@@ -88,11 +88,11 @@ blah = BlahRV()
Some important things to keep in mind:

1. Everything inside the `rng_fn` method is pure Python code (as are the inputs) and should not make use of other `Aesara` symbolic ops. The random method should make use of the `rng` which is a NumPy {class}`~numpy.random.RandomState`, so that samples are reproducible.
1. Non-default `RandomVariable` dimensions will end up in the `rng_fn` via the `size` kwarg. The `rng_fn` will have to take this into consideration for correct output. `size` is the specification used by NumPy and SciPy and works like PyMC `shape` for univariate distributions, but is different for multivariate distributions. For multivariate distributions the __`size` excludes the `ndim_supp` support dimensions__, whereas the __`shape` of the resulting `TensorVariable` or `ndarray` includes the support dimensions__. This [discussion](https://github.com/numpy/numpy/issues/17669) may be helpful to get more context.
1. Non-default `RandomVariable` dimensions will end up in the `rng_fn` via the `size` kwarg. The `rng_fn` will have to take this into consideration for correct output. `size` is the specification used by NumPy and SciPy and works like PyMC `shape` for univariate distributions, but is different for multivariate distributions. For multivariate distributions the __`size` excludes the `ndim_supp` support dimensions__, whereas the __`shape` of the resulting `TensorVariable` or `ndarray` includes the support dimensions__. For more context check {doc}`The dimensionality notebook <pymc:dimensionality=>`.
1. `Aesara` tries to infer the output shape of the `RandomVariable` (given a user-specified size) by introspection of the `ndim_supp` and `ndim_params` attributes. However, the default method may not work for more complex distributions. In that case, custom `_supp_shape_from_params` (and less likely, `_infer_shape`) should also be implemented in the new `RandomVariable` class. One simple example is seen in the {class}`~pymc.DirichletMultinomialRV` where it was necessary to specify the `rep_param_idx` so that the `default_supp_shape_from_params` helper method can do its job. In more complex cases, it may not suffice to use this default helper. This could happen for instance if the argument values determined the support shape of the distribution, as happens in the `~pymc.distributions.multivariate._LKJCholeskyCovRV`.
1. It's okay to use the `rng_fn` `classmethods` of other Aesara and PyMC `RandomVariables` inside the new `rng_fn`. For example if you are implementing a negative HalfNormal `RandomVariable`, your `rng_fn` can simply return `- halfnormal.rng_fn(rng, scale, size)`.

*Note: In addition to `size`, the PyMC API also provides `shape` and `dims` as alternatives to define a distribution dimensionality, but this is taken care of by {class}`~pymc.Distribution`, and should not require any extra changes.*
*Note: In addition to `size`, the PyMC API also provides `shape`, `dims` and `observed` as alternatives to define a distribution dimensionality, but this is taken care of by {class}`~pymc.Distribution`, and should not require any extra changes.*
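To make the `size` versus `shape` distinction concrete, here is a small sketch that uses plain NumPy generators rather than a PyMC `RandomVariable`. The distributions and sizes are only illustrative choices and are not taken from the guide.

```python
import numpy as np

# A seeded RandomState, the kind of `rng` the guide says `rng_fn` receives
rng = np.random.RandomState(42)

# Univariate: the output shape equals the requested size
univariate = rng.normal(loc=0.0, scale=1.0, size=(3,))
print(univariate.shape)  # (3,)

# Multivariate with one support dimension of length 2:
# `size` only counts the batch dimensions, the support dimension is appended
multivariate = rng.multivariate_normal(mean=np.zeros(2), cov=np.eye(2), size=(3,))
print(multivariate.shape)  # (3, 2)

# Same pattern with a support dimension of length 4 and a 2-d size
dirichlet = rng.dirichlet(alpha=np.ones(4), size=(5, 2))
print(dirichlet.shape)  # (5, 2, 4)
```

An `rng_fn` for a multivariate distribution therefore has to return an array whose trailing dimensions are the support dimensions, with `size` describing only the leading ones.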

For a quick test that your new `RandomVariable` `Op` is working, you can call the `Op` with the necessary parameters and then call `eval()` on the returned object:

@@ -129,9 +129,11 @@ Here is how the example continues:

```python

import aesara.tensor as at
from pymc.aesaraf import floatX, intX
from pymc.distributions.continuous import PositiveContinuous
from pymc.distributions.dist_math import check_parameters
from pymc.distributions.shape_utils import rv_size_is_none


# Subclassing `PositiveContinuous` will dispatch a default `log` transformation
@@ -231,15 +233,16 @@ pm.logcdf(blah, [-0.5, 1.5]).eval()

## 3. Adding tests for the new `RandomVariable`

Tests for new `RandomVariables` are mostly located in `pymc/tests/test_distributions_random.py`.
Most tests can be accommodated by the default `BaseTestDistribution` class, which provides default tests for checking:
Tests for new `RandomVariables` are mostly located in `pymc/tests/distributions/test_*.py`.
Most tests can be accommodated by the default `BaseTestDistributionRandom` class, which provides default tests for checking:
1. Expected inputs are passed to the `rv_op` by the `dist` `classmethod`, via `check_pymc_params_match_rv_op`
1. Expected (exact) draws are being returned, via `check_pymc_draws_match_reference`
1. Shape variable inference is correct, via `check_rv_size`

```python
from pymc.tests.distributions.util import BaseTestDistributionRandom, seeded_scipy_distribution_builder

class TestBlah(BaseTestDistribution):
class TestBlah(BaseTestDistributionRandom):

    pymc_dist = pm.Blah
    # Parameters with which to test the blah pymc Distribution
@@ -266,7 +269,7 @@ For instance, if it's just the inverse, testing with `1.0` is not very informati

```python

class TestBlahAltParam2(BaseTestDistribution):
class TestBlahAltParam2(BaseTestDistributionRandom):

    pymc_dist = pm.Blah
    # param2 is equivalent to 1 / alt_param2
@@ -276,7 +279,7 @@ class TestBlahAltParam2(BaseTestDistribution):

```
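A small sketch shows why a value of `1.0` is uninformative when the alternative parameter is just the inverse. The functions `convert_alt_param2` and `buggy_convert_alt_param2` are hypothetical stand-ins for the conversion a `dist` classmethod would perform; they are not part of PyMC.

```python
def convert_alt_param2(alt_param2):
    # Correct conversion: param2 is the inverse of alt_param2
    return 1.0 / alt_param2


def buggy_convert_alt_param2(alt_param2):
    # A bug that forgets to invert the value
    return alt_param2


# With 1.0 the bug is invisible, because 1 / 1 == 1
same_at_one = convert_alt_param2(1.0) == buggy_convert_alt_param2(1.0)
print(same_at_one)  # True

# With 4.0 the two conversions disagree, so a test would catch the bug
same_at_four = convert_alt_param2(4.0) == buggy_convert_alt_param2(4.0)
print(same_at_four)  # False
```

Any test value other than a fixed point of the conversion would do; `4.0` is an arbitrary choice.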

Custom tests can also be added to the class as is done for the {class}`~pymc.tests.test_random.TestFlat`.
Custom tests can also be added to the class as is done for the {class}`~pymc.tests.distributions.test_continuous.TestFlat`.

### Note on `check_rv_size` test:

@@ -289,37 +292,36 @@ tests_to_run = ["check_rv_size"]
```

This is usually needed for Multivariate distributions.
You can see an example in {class}`~pymc.test.test_random.TestDirichlet`.
You can see an example in {class}`~pymc.tests.distributions.test_multivariate.TestDirichlet`.

### Notes on `check_pymc_draws_match_reference` test

The `check_pymc_draws_match_reference` is a very simple test for the equality of draws from the `RandomVariable` and the exact same Python function, given the same inputs and random seed.
Only a small number of draws (`size=15`) is checked. This is not supposed to be a test for the correctness of the random number generator.
The latter kind of test (if warranted) can be performed with the aid of `pymc_random` and `pymc_random_discrete` methods in the same test file, which will perform an expensive statistical comparison between the `RandomVariable.rng_fn` and a reference Python function.
The latter kind of test (if warranted) can be performed with the aid of `pymc_random` and `pymc_random_discrete` methods, which will perform an expensive statistical comparison between the `RandomVariable.rng_fn` and a reference Python function.
This kind of test only makes sense if there is a good independent generator reference (i.e., not just the same composition of NumPy / SciPy calls that is done inside `rng_fn`).

Finally, when your `rng_fn` is doing something more than just calling a NumPy or SciPy method, you will need to set up an equivalent seeded function with which to compare for the exact draws (instead of relying on `seeded_[scipy|numpy]_distribution_builder`).
You can find an example in {class}`~pymc.tests.test_distributions_random.TestWeibull`, whose `rng_fn` returns `beta * np.random.weibull(alpha, size=size)`.
You can find an example in {class}`~pymc.tests.distributions.test_continuous.TestWeibull`, whose `rng_fn` returns `beta * np.random.weibull(alpha, size=size)`.
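As a rough sketch of what such a seeded reference can look like, the snippet below compares a scaled Weibull draw against an independently seeded reference using only NumPy. The names `weibull_rng_fn` and `seeded_weibull_reference` and the seed value are hypothetical and not part of the PyMC test utilities; the actual `TestWeibull` implementation may differ.

```python
import numpy as np


def weibull_rng_fn(rng, alpha, beta, size):
    # Mirrors the composition described above: scale a standard Weibull draw by beta
    return beta * rng.weibull(alpha, size=size)


def seeded_weibull_reference(seed):
    # Hypothetical helper: returns a function that draws from a freshly seeded generator
    def reference(alpha, beta, size):
        rng = np.random.RandomState(seed)
        return beta * rng.weibull(alpha, size=size)

    return reference


seed = 20160911  # arbitrary seed for this sketch
draws = weibull_rng_fn(np.random.RandomState(seed), alpha=2.0, beta=3.0, size=15)
expected = seeded_weibull_reference(seed)(alpha=2.0, beta=3.0, size=15)

# Same inputs and same seed must give exactly the same draws
np.testing.assert_array_equal(draws, expected)
```

The comparison is exact rather than approximate because both sides consume the same random stream in the same order.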


## 4. Adding tests for the `logp` / `logcdf` methods

Tests for the `logp` and `logcdf` methods are contained in `pymc/tests/test_distributions.py`, and most make use of the `TestMatchesScipy` class, which provides `check_logp`, `check_logcdf`, and
`check_selfconsistency_discrete_logcdf` standard methods.
These will suffice for most distributions.
Tests for the `logp` and `logcdf` methods mostly make use of the helpers `check_logp`, `check_logcdf`, and
`check_selfconsistency_discrete_logcdf` implemented in `~pymc.tests.distributions.util`.

```python

import numpy as np
import pymc as pm

from pymc.tests.distributions.util import check_logp, check_logcdf, Domain
from pymc.tests.helpers import select_by_precision

R = Domain([-np.inf, -2.1, -1, -0.01, 0.0, 0.01, 1, 2.1, np.inf])
Rplus = Domain([0, 0.01, 0.1, 0.9, 0.99, 1, 1.5, 2, 100, np.inf])

...

def test_blah(self):

    self.check_logp(
def test_blah():

    check_logp(
        pymc_dist=pm.Blah,
        # Domain of the distribution values
        domain=R,
@@ -333,7 +335,7 @@ def test_blah(self):
        n_samples=100,
    )

    self.check_logcdf(
    check_logcdf(
        pymc_dist=pm.Blah,
        domain=R,
        paramdomains={"mu": R, "sigma": Rplus},
@@ -370,15 +372,17 @@ def test_blah_logcdf(self):

## 5. Adding tests for the `moment` method

Tests for the `moment` method are contained in `pymc/tests/test_distributions_moments.py`, and make use of the function `assert_moment_is_expected`
Tests for the `moment` method make use of the function `assert_moment_is_expected`,
which checks if:
1. Moments return the `expected` values
1. Moments have the expected size and shape
1. Moments have a finite logp

```python

import pytest
from pymc.distributions import Blah
from pymc.tests.distributions.util import assert_moment_is_expected

@pytest.mark.parametrize(
"param1, param2, size, expected",