|
| 1 | +--- |
| 2 | +jupytext: |
| 3 | + text_representation: |
| 4 | + extension: .md |
| 5 | + format_name: myst |
| 6 | + format_version: 0.13 |
| 7 | +kernelspec: |
| 8 | + display_name: pymc5recent |
| 9 | + language: python |
| 10 | + name: pymc5recent |
| 11 | +--- |
| 12 | + |
| 13 | +(faster_sampling_notebook)= |
| 14 | + |
| 15 | +# Faster Sampling with JAX and Numba |
| 16 | + |
| 17 | +:::{post} July 11, 2023 |
| 18 | +:tags: hierarchical model, JAX, numba, scaling |
| 19 | +:category: reference, intermediate |
| 20 | +:author: Thomas Wiecki |
| 21 | +::: |
| 22 | + |
| 23 | ++++ |
| 24 | + |
| 25 | +PyMC can compile its models to various execution backends through PyTensor, including: |
| 26 | +* C |
| 27 | +* JAX |
| 28 | +* Numba |
| 29 | + |
| 30 | +By default, PyMC uses the C backend, which is then called by the Python-based samplers. |
| 31 | + |
| 32 | +However, by compiling to other backends, we can use samplers written in languages other than Python that call the PyMC model without any Python overhead. |
| 33 | + |
| 34 | +For the JAX backend, the NumPyro and BlackJAX NUTS samplers are available. To use these samplers, you have to install `numpyro` and `blackjax`. Both of them are available through conda/mamba: `mamba install -c conda-forge numpyro blackjax`. |
| 35 | + |
| 36 | +For the Numba backend, there is the [Nutpie sampler](https://github.com/pymc-devs/nutpie), written in Rust. To use this sampler, you need `nutpie` installed: `mamba install -c conda-forge nutpie`. |
| 37 | + |
| 38 | +```{code-cell} ipython3 |
| 39 | +import arviz as az |
| 40 | +import matplotlib.pyplot as plt |
| 41 | +import numpy as np |
| 42 | +import pymc as pm |
| 43 | +
|
| 44 | +rng = np.random.default_rng(seed=42) |
| 45 | +print(f"Running on PyMC v{pm.__version__}") |
| 46 | +``` |
| 47 | + |
| 48 | +```{code-cell} ipython3 |
| 49 | +%config InlineBackend.figure_format = 'retina' |
| 50 | +az.style.use("arviz-darkgrid") |
| 51 | +``` |
| 52 | + |
| 53 | +We will use a simple probabilistic PCA model as our example. |
| 54 | + |
| 55 | +```{code-cell} ipython3 |
| 56 | +def build_toy_dataset(N, D, K, sigma=1): |
| 57 | + x_train = np.zeros((D, N)) |
| 58 | + w = rng.normal( |
| 59 | + 0.0, |
| 60 | + 2.0, |
| 61 | + size=(D, K), |
| 62 | + ) |
| 63 | + z = rng.normal(0.0, 1.0, size=(K, N)) |
| 64 | + mean = np.dot(w, z) |
| 65 | + for d in range(D): |
| 66 | + for n in range(N): |
| 67 | + x_train[d, n] = rng.normal(mean[d, n], sigma) |
| 68 | +
|
| 69 | + print("True principal axes:") |
| 70 | + print(w) |
| 71 | + return x_train |
| 72 | +
|
| 73 | +
|
| 74 | +N = 5000 # number of data points |
| 75 | +D = 2 # data dimensionality |
| 76 | +K = 1 # latent dimensionality |
| 77 | +
|
| 78 | +data = build_toy_dataset(N, D, K) |
| 79 | +``` |
| 80 | + |
| 81 | +```{code-cell} ipython3 |
| 82 | +plt.scatter(data[0, :], data[1, :], color="blue", alpha=0.1) |
| 83 | +plt.axis([-10, 10, -10, 10]) |
| 84 | +plt.title("Simulated data set") |
| 85 | +``` |
| 86 | + |
| 87 | +```{code-cell} ipython3 |
| 88 | +with pm.Model() as PPCA: |
| 89 | + w = pm.Normal("w", mu=0, sigma=2, shape=[D, K], transform=pm.distributions.transforms.Ordered()) |
| 90 | + z = pm.Normal("z", mu=0, sigma=1, shape=[N, K]) |
| 91 | + x = pm.Normal("x", mu=w.dot(z.T), sigma=1, shape=[D, N], observed=data) |
| 92 | +``` |
| 93 | + |
| 94 | +## Sampling using Python NUTS sampler |
| 95 | + |
| 96 | +```{code-cell} ipython3 |
| 97 | +%%time |
| 98 | +with PPCA: |
| 99 | + idata_pymc = pm.sample() |
| 100 | +``` |
| 101 | + |
| 102 | +## Sampling using NumPyro JAX NUTS sampler |
| 103 | + |
| 104 | +```{code-cell} ipython3 |
| 105 | +%%time |
| 106 | +with PPCA: |
| 107 | + idata_numpyro = pm.sample(nuts_sampler="numpyro", progressbar=False) |
| 108 | +``` |
| 109 | + |
| 110 | +## Sampling using BlackJAX NUTS sampler |
| 111 | + |
| 112 | +```{code-cell} ipython3 |
| 113 | +%%time |
| 114 | +with PPCA: |
| 115 | + idata_blackjax = pm.sample(nuts_sampler="blackjax") |
| 116 | +``` |
| 117 | + |
| 118 | +## Sampling using Nutpie Rust NUTS sampler |
| 119 | + |
| 120 | +```{code-cell} ipython3 |
| 121 | +%%time |
| 122 | +with PPCA: |
| 123 | + idata_nutpie = pm.sample(nuts_sampler="nutpie") |
| 124 | +``` |
| 125 | + |
| 126 | +## Authors |
| 127 | +Authored by Thomas Wiecki in July 2023 |
| 128 | + |
| 129 | +```{code-cell} ipython3 |
| 130 | +%load_ext watermark |
| 131 | +%watermark -n -u -v -iv -w -p pytensor,arviz,pymc,numpyro,blackjax,nutpie |
| 132 | +``` |
| 133 | + |
| 134 | +:::{include} ../page_footer.md |
| 135 | +::: |
0 commit comments