Skip to content

Commit 0c90e82

Browse files
amyoshinoAdriano Yoshinomichaelosthege
authored
Document get_moment (#5244)
Closes #5210 Co-authored-by: Adriano Yoshino <adrianoyoshino@adrianos-mbp.lan> Co-authored-by: Michael Osthege <michael.osthege@outlook.com>
1 parent f369137 commit 0c90e82

File tree

1 file changed

+61
-13
lines changed

1 file changed

+61
-13
lines changed

docs/source/contributing/developer_guide_implementing_distribution.md

Lines changed: 61 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -4,15 +4,16 @@ This guide provides an overview on how to implement a distribution for version 4
44
It is designed for developers who wish to add a new distribution to the library.
55
Users will not be aware of all this complexity and should instead make use of helper methods such as (TODO).
66

7-
PyMC {class}`~pymc.distributions.Distribution` build on top of Aesara's {class}`~aesara.tensor.random.op.RandomVariable`, and implement `logp` and `logcdf` methods as well as other initialization and validation helpers, most notably `shape/dims`, alternative parametrizations, and default `transforms`.
7+
PyMC {class}`~pymc.distributions.Distribution` builds on top of Aesara's {class}`~aesara.tensor.random.op.RandomVariable`, and implements `logp`, `logcdf` and `get_moment` methods as well as other initialization and validation helpers.
8+
Most notably `shape/dims` kwargs, alternative parametrizations, and default `transforms`.
89

910
Here is a summary check-list of the steps needed to implement a new distribution.
1011
Each section will be expanded below:
1112

1213
1. Creating a new `RandomVariable` `Op`
1314
1. Implementing the corresponding `Distribution` class
1415
1. Adding tests for the new `RandomVariable`
15-
1. Adding tests for the `logp` / `logcdf` methods
16+
1. Adding tests for `logp` / `logcdf` and `get_moment` methods
1617
1. Documenting the new `Distribution`.
1718

1819
This guide does not attempt to explain the rationale behind the `Distributions` current implementation, and details are provided only insofar as they help to implement new "standard" distributions.
@@ -118,7 +119,7 @@ After implementing the new `RandomVariable` `Op`, it's time to make use of it in
118119
PyMC 4.x works in a very {term}`functional <Functional Programming>` way, and the `distribution` classes are there mostly to facilitate porting the `PyMC3` v3.x code to the new `PyMC` v4.x version, add PyMC API features and keep related methods organized together.
119120
In practice, they take care of:
120121

121-
1. Linking ({term}`Dispatching`) a rv_op class with the corresponding logp and logcdf methods.
122+
1. Linking ({term}`Dispatching`) a rv_op class with the corresponding `get_moment`, `logp` and `logcdf` methods.
122123
1. Defining a standard transformation (for continuous distributions) that converts a bounded variable domain (e.g., positive line) to an unbounded domain (i.e., the real line), which many samplers prefer.
123124
1. Validating the parametrization of a distribution and converting non-symbolic inputs (i.e., numeric literals or numpy arrays) to symbolic variables.
124125
1. Converting multiple alternative parametrizations to the standard parametrization that the `RandomVariable` is defined in terms of.
@@ -153,6 +154,14 @@ class Blah(PositiveContinuous):
153154
# the rv_op needs in order to be instantiated
154155
return super().dist([param1, param2], **kwargs)
155156

157+
# get_moment returns a symbolic expression for the stable moment from which to start sampling
158+
# the variable, given the implicit `rv`, `size` and `param1` ... `paramN`
159+
def get_moment(rv, size, param1, param2):
160+
moment, _ = at.broadcast_arrays(param1, param2)
161+
if not rv_size_is_none(size):
162+
moment = at.full(size, moment)
163+
return moment
164+
156165
# Logp returns a symbolic expression for the logp evaluation of the variable
157166
# given the `value` of the variable and the parameters `param1` ... `paramN`
158167
def logp(value, param1, param2):
@@ -189,27 +198,34 @@ Some notes:
189198
overriding `__new__`.
190199
1. As mentioned above, `PyMC` v4.x works in a very {term}`functional <Functional Programming>` way, and all the information that is needed in the `logp` and `logcdf` methods is expected to be "carried" via the `RandomVariable` inputs. You may pass numerical arguments that are not strictly needed for the `rng_fn` method but are used in the `logp` and `logcdf` methods. Just keep in mind whether this affects the correct shape inference behavior of the `RandomVariable`. If specialized non-numeric information is needed you might need to define your custom`_logp` and `_logcdf` {term}`Dispatching` functions, but this should be done as a last resort.
191200
1. The `logcdf` method is not a requirement, but it's a nice plus!
201+
1. Currently only one moment is supported in the `get_moment` method, and probably the "higher-order" one is the most useful (that is `mean` > `median` > `mode`)... You might need to truncate the moment if you are dealing with a discrete distribution.
202+
1. When creating the `get_moment` method, we have to be careful with `size != None` and broadcast properly when some parameters that are not used in the moment may nevertheless inform about the shape of the distribution. E.g. `pm.Normal.dist(mu=0, sigma=np.arange(1, 6))` returns a moment of `[mu, mu, mu, mu, mu]`.
192203

193204
For a quick check that things are working you can try the following:
194205

195206
```python
196207

197208
import pymc as pm
209+
from pymc.distributions.distribution import get_moment
198210

199-
# pm.blah = pm.Uniform in this example
200-
blah = pm.Blah.dist([0, 0], [1, 2])
211+
# pm.blah = pm.Normal in this example
212+
blah = pm.blah.dist(mu = 0, sigma = 1)
201213

202214
# Test that the returned blah_op is still working fine
203215
blah.eval()
204-
# array([0.62778803, 1.95165513])
216+
# array(-1.01397228)
217+
218+
# Test the get_moment method
219+
get_moment(blah).eval()
220+
# array(0.)
205221

206-
# Test the logp
207-
pm.logp(blah, [1.5, 1.5]).eval()
208-
# array([ -inf, -0.69314718])
222+
# Test the logp method
223+
pm.logp(blah, [-0.5, 1.5]).eval()
224+
# array([-1.04393853, -2.04393853])
209225

210-
# Test the logcdf
211-
pm.logcdf(blah, [1.5, 1.5]).eval()
212-
# array([ 0. , -0.28768207])
226+
# Test the logcdf method
227+
pm.logcdf(blah, [-0.5, 1.5]).eval()
228+
# array([-1.17591177, -0.06914345])
213229
```
214230

215231
## 3. Adding tests for the new `RandomVariable`
@@ -351,8 +367,40 @@ def test_blah_logcdf(self):
351367

352368
```
353369

370+
## 5. Adding tests for the `get_moment` method
371+
372+
Tests for the `get_moment` method are contained in `pymc/tests/test_distributions_moments.py`, and make use of the function `assert_moment_is_expected`
373+
which checks if:
374+
1. Moments return the `expected` values
375+
1. Moments have the expected size and shape
376+
377+
```python
378+
379+
import pytest
380+
from pymc.distributions import Blah
381+
382+
@pytest.mark.parametrize(
383+
"param1, param2, size, expected",
384+
[
385+
(0, 1, None, 0),
386+
(0, np.ones(5), None, np.zeros(5)),
387+
(np.arange(5), 1, None, np.arange(5)),
388+
(np.arange(5), np.arange(1, 6), (2, 5), np.full((2, 5), np.arange(5))),
389+
],
390+
)
391+
def test_blah_moment(param1, param2, size, expected):
392+
with Model() as model:
393+
Blah("x", param1=param1, param2=param2, size=size)
394+
assert_moment_is_expected(model, expected)
395+
396+
```
397+
398+
Here are some details worth keeping in mind:
399+
400+
1. In the case where you have to manually broadcast the parameters with each other it's important to add test conditions that would fail if you were not to do that. A straightforward way to do this is to make the used parameter a scalar, the unused one(s) a vector (one at a time) and size `None`.
401+
1. In other words, make sure to test different combinations of size and broadcasting to cover these cases.
354402

355-
## 5. Documenting the new `Distribution`
403+
## 6. Documenting the new `Distribution`
356404

357405
New distributions should have a rich docstring, following the same format as that of previously implemented distributions.
358406
It generally looks something like this:

0 commit comments

Comments
 (0)