Skip to content

Commit d123659

Browse files
authored
External nuts sampler (#560)
* Update NB to use new nuts_sampler kwarg and PCA example. * Re-execute. * Re-execute. * Re-execute. * Re-execute. * Remove old JAX files. * Fixes suggested by Oriol.
1 parent 790939a commit d123659

File tree

4 files changed

+651
-513
lines changed

4 files changed

+651
-513
lines changed

examples/samplers/GLM-hierarchical-jax.ipynb

Lines changed: 0 additions & 384 deletions
This file was deleted.

examples/samplers/GLM-hierarchical-jax.myst.md

Lines changed: 0 additions & 129 deletions
This file was deleted.

examples/samplers/fast_sampling_with_jax_and_numba.ipynb

Lines changed: 516 additions & 0 deletions
Large diffs are not rendered by default.
Lines changed: 135 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,135 @@
1+
---
2+
jupytext:
3+
text_representation:
4+
extension: .md
5+
format_name: myst
6+
format_version: 0.13
7+
kernelspec:
8+
display_name: pymc5recent
9+
language: python
10+
name: pymc5recent
11+
---
12+
13+
(faster_sampling_notebook)=
14+
15+
# Faster Sampling with JAX and Numba
16+
17+
:::{post} July 11, 2023
18+
:tags: hierarchical model, JAX, numba, scaling
19+
:category: reference, intermediate
20+
:author: Thomas Wiecki
21+
:::
22+
23+
+++
24+
25+
PyMC can compile its models to various execution backends through PyTensor, including:
26+
* C
27+
* JAX
28+
* Numba
29+
30+
By default, PyMC is using the C backend which then gets called by the Python-based samplers.
31+
32+
However, by compiling to other backends, we can use samplers written in other languages than Python that call the PyMC model without any Python-overhead.
33+
34+
For the JAX backend there is the NumPyro and BlackJAX NUTS sampler available. To use these samplers, you have to install `numpyro` and `blackjax`. Both of them are available through conda/mamba: `mamba install -c conda-forge numpyro blackjax`.
35+
36+
For the Numba backend, there is the [Nutpie sampler](https://github.com/pymc-devs/nutpie) writte in Rust. To use this sampler you need `nutpie` installed: `mamba install -c conda-forge nutpie`.
37+
38+
```{code-cell} ipython3
39+
import arviz as az
40+
import matplotlib.pyplot as plt
41+
import numpy as np
42+
import pymc as pm
43+
44+
rng = np.random.default_rng(seed=42)
45+
print(f"Running on PyMC v{pm.__version__}")
46+
```
47+
48+
```{code-cell} ipython3
49+
%config InlineBackend.figure_format = 'retina'
50+
az.style.use("arviz-darkgrid")
51+
```
52+
53+
We will use a simple probabilistic PCA model as our example.
54+
55+
```{code-cell} ipython3
56+
def build_toy_dataset(N, D, K, sigma=1):
57+
x_train = np.zeros((D, N))
58+
w = rng.normal(
59+
0.0,
60+
2.0,
61+
size=(D, K),
62+
)
63+
z = rng.normal(0.0, 1.0, size=(K, N))
64+
mean = np.dot(w, z)
65+
for d in range(D):
66+
for n in range(N):
67+
x_train[d, n] = rng.normal(mean[d, n], sigma)
68+
69+
print("True principal axes:")
70+
print(w)
71+
return x_train
72+
73+
74+
N = 5000 # number of data points
75+
D = 2 # data dimensionality
76+
K = 1 # latent dimensionality
77+
78+
data = build_toy_dataset(N, D, K)
79+
```
80+
81+
```{code-cell} ipython3
82+
plt.scatter(data[0, :], data[1, :], color="blue", alpha=0.1)
83+
plt.axis([-10, 10, -10, 10])
84+
plt.title("Simulated data set")
85+
```
86+
87+
```{code-cell} ipython3
88+
with pm.Model() as PPCA:
89+
w = pm.Normal("w", mu=0, sigma=2, shape=[D, K], transform=pm.distributions.transforms.Ordered())
90+
z = pm.Normal("z", mu=0, sigma=1, shape=[N, K])
91+
x = pm.Normal("x", mu=w.dot(z.T), sigma=1, shape=[D, N], observed=data)
92+
```
93+
94+
## Sampling using Python NUTS sampler
95+
96+
```{code-cell} ipython3
97+
%%time
98+
with PPCA:
99+
idata_pymc = pm.sample()
100+
```
101+
102+
## Sampling using NumPyro JAX NUTS sampler
103+
104+
```{code-cell} ipython3
105+
%%time
106+
with PPCA:
107+
idata_numpyro = pm.sample(nuts_sampler="numpyro", progressbar=False)
108+
```
109+
110+
## Sampling using BlackJAX NUTS sampler
111+
112+
```{code-cell} ipython3
113+
%%time
114+
with PPCA:
115+
idata_blackjax = pm.sample(nuts_sampler="blackjax")
116+
```
117+
118+
## Sampling using Nutpie Rust NUTS sampler
119+
120+
```{code-cell} ipython3
121+
%%time
122+
with PPCA:
123+
idata_nutpie = pm.sample(nuts_sampler="nutpie")
124+
```
125+
126+
## Authors
127+
Authored by Thomas Wiecki in July 2023
128+
129+
```{code-cell} ipython3
130+
%load_ext watermark
131+
%watermark -n -u -v -iv -w -p pytensor,arviz,pymc,numpyro,blackjax,nutpie
132+
```
133+
134+
:::{include} ../page_footer.md
135+
:::

0 commit comments

Comments
 (0)