You signed in with another tab or window. Reload to refresh your session.You signed out in another tab or window. Reload to refresh your session.You switched accounts on another tab or window. Reload to refresh your session.Dismiss alert
Copy file name to clipboardExpand all lines: docs/source/contributing/developer_guide_implementing_distribution.md
+61-13Lines changed: 61 additions & 13 deletions
Original file line number
Diff line number
Diff line change
@@ -4,15 +4,16 @@ This guide provides an overview on how to implement a distribution for version 4
4
4
It is designed for developers who wish to add a new distribution to the library.
5
5
Users will not be aware of all this complexity and should instead make use of helper methods such as (TODO).
6
6
7
-
PyMC {class}`~pymc.distributions.Distribution` build on top of Aesara's {class}`~aesara.tensor.random.op.RandomVariable`, and implement `logp` and `logcdf` methods as well as other initialization and validation helpers, most notably `shape/dims`, alternative parametrizations, and default `transforms`.
7
+
PyMC {class}`~pymc.distributions.Distribution` builds on top of Aesara's {class}`~aesara.tensor.random.op.RandomVariable`, and implements `logp`, `logcdf` and `get_moment` methods as well as other initialization and validation helpers.
8
+
Most notably `shape/dims` kwargs, alternative parametrizations, and default `transforms`.
8
9
9
10
Here is a summary check-list of the steps needed to implement a new distribution.
10
11
Each section will be expanded below:
11
12
12
13
1. Creating a new `RandomVariable``Op`
13
14
1. Implementing the corresponding `Distribution` class
14
15
1. Adding tests for the new `RandomVariable`
15
-
1. Adding tests for the `logp` / `logcdf` methods
16
+
1. Adding tests for `logp` / `logcdf` and `get_moment` methods
16
17
1. Documenting the new `Distribution`.
17
18
18
19
This guide does not attempt to explain the rationale behind the `Distributions` current implementation, and details are provided only insofar as they help to implement new "standard" distributions.
@@ -118,7 +119,7 @@ After implementing the new `RandomVariable` `Op`, it's time to make use of it in
118
119
PyMC 4.x works in a very {term}`functional <Functional Programming>` way, and the `distribution` classes are there mostly to facilitate porting the `PyMC3` v3.x code to the new `PyMC` v4.x version, add PyMC API features and keep related methods organized together.
119
120
In practice, they take care of:
120
121
121
-
1. Linking ({term}`Dispatching`) a rv_op class with the corresponding logp and logcdf methods.
122
+
1. Linking ({term}`Dispatching`) a rv_op class with the corresponding `get_moment`, `logp` and `logcdf` methods.
122
123
1. Defining a standard transformation (for continuous distributions) that converts a bounded variable domain (e.g., positive line) to an unbounded domain (i.e., the real line), which many samplers prefer.
123
124
1. Validating the parametrization of a distribution and converting non-symbolic inputs (i.e., numeric literals or numpy arrays) to symbolic variables.
124
125
1. Converting multiple alternative parametrizations to the standard parametrization that the `RandomVariable` is defined in terms of.
@@ -153,6 +154,14 @@ class Blah(PositiveContinuous):
153
154
# the rv_op needs in order to be instantiated
154
155
returnsuper().dist([param1, param2], **kwargs)
155
156
157
+
# get_moment returns a symbolic expression for the stable moment from which to start sampling
158
+
# the variable, given the implicit `rv`, `size` and `param1` ... `paramN`
159
+
defget_moment(rv, size, param1, param2):
160
+
moment, _ = at.broadcast_arrays(param1, param2)
161
+
ifnot rv_size_is_none(size):
162
+
moment = at.full(size, moment)
163
+
return moment
164
+
156
165
# Logp returns a symbolic expression for the logp evaluation of the variable
157
166
# given the `value` of the variable and the parameters `param1` ... `paramN`
158
167
deflogp(value, param1, param2):
@@ -189,27 +198,34 @@ Some notes:
189
198
overriding `__new__`.
190
199
1. As mentioned above, `PyMC` v4.x works in a very {term}`functional <Functional Programming>` way, and all the information that is needed in the `logp` and `logcdf` methods is expected to be "carried" via the `RandomVariable` inputs. You may pass numerical arguments that are not strictly needed for the `rng_fn` method but are used in the `logp` and `logcdf` methods. Just keep in mind whether this affects the correct shape inference behavior of the `RandomVariable`. If specialized non-numeric information is needed you might need to define your custom`_logp` and `_logcdf` {term}`Dispatching` functions, but this should be done as a last resort.
191
200
1. The `logcdf` method is not a requirement, but it's a nice plus!
201
+
1. Currently only one moment is supported in the `get_moment` method, and probably the "higher-order" one is the most useful (that is `mean` > `median` > `mode`)... You might need to truncate the moment if you are dealing with a discrete distribution.
202
+
1. When creating the `get_moment` method, we have to be careful with `size != None` and broadcast properly when some parameters that are not used in the moment may nevertheless inform about the shape of the distribution. E.g. `pm.Normal.dist(mu=0, sigma=np.arange(1, 6))` returns a moment of `[mu, mu, mu, mu, mu]`.
192
203
193
204
For a quick check that things are working you can try the following:
194
205
195
206
```python
196
207
197
208
import pymc as pm
209
+
from pymc.distributions.distribution import get_moment
198
210
199
-
# pm.blah = pm.Uniform in this example
200
-
blah = pm.Blah.dist([0, 0], [1, 2])
211
+
# pm.blah = pm.Normal in this example
212
+
blah = pm.blah.dist(mu=0, sigma=1)
201
213
202
214
# Test that the returned blah_op is still working fine
203
215
blah.eval()
204
-
# array([0.62778803, 1.95165513])
216
+
# array(-1.01397228)
217
+
218
+
# Test the get_moment method
219
+
get_moment(blah).eval()
220
+
# array(0.)
205
221
206
-
# Test the logp
207
-
pm.logp(blah, [1.5, 1.5]).eval()
208
-
# array([ -inf, -0.69314718])
222
+
# Test the logp method
223
+
pm.logp(blah, [-0.5, 1.5]).eval()
224
+
# array([-1.04393853, -2.04393853])
209
225
210
-
# Test the logcdf
211
-
pm.logcdf(blah, [1.5, 1.5]).eval()
212
-
# array([ 0. , -0.28768207])
226
+
# Test the logcdf method
227
+
pm.logcdf(blah, [-0.5, 1.5]).eval()
228
+
# array([-1.17591177, -0.06914345])
213
229
```
214
230
215
231
## 3. Adding tests for the new `RandomVariable`
@@ -351,8 +367,40 @@ def test_blah_logcdf(self):
351
367
352
368
```
353
369
370
+
## 5. Adding tests for the `get_moment` method
371
+
372
+
Tests for the `get_moment` method are contained in `pymc/tests/test_distributions_moments.py`, and make use of the function `assert_moment_is_expected`
1. In the case where you have to manually broadcast the parameters with each other it's important to add test conditions that would fail if you were not to do that. A straightforward way to do this is to make the used parameter a scalar, the unused one(s) a vector (one at a time) and size `None`.
401
+
1. In other words, make sure to test different combinations of size and broadcasting to cover these cases.
354
402
355
-
## 5. Documenting the new `Distribution`
403
+
## 6. Documenting the new `Distribution`
356
404
357
405
New distributions should have a rich docstring, following the same format as that of previously implemented distributions.
0 commit comments