From 07fac5e4faa3ac0b1365cbb7aff2e2c9fa45644f Mon Sep 17 00:00:00 2001 From: Max Kochurov Date: Wed, 6 Dec 2023 14:37:25 +0000 Subject: [PATCH 1/5] refactor jax internals to support more kwargs to numpyro --- pymc/sampling/jax.py | 97 ++++++++++++++++++++++++++------------ tests/sampling/test_jax.py | 19 ++++++++ 2 files changed, 85 insertions(+), 31 deletions(-) diff --git a/pymc/sampling/jax.py b/pymc/sampling/jax.py index 7268555f91..296ebb4a54 100644 --- a/pymc/sampling/jax.py +++ b/pymc/sampling/jax.py @@ -43,7 +43,7 @@ find_observations, ) from pymc.distributions.multivariate import PosDefMatrix -from pymc.initial_point import StartDict +from pymc.initial_point import PointType, StartDict from pymc.logprob.utils import CheckParameterValue from pymc.sampling.mcmc import _init_jitter from pymc.util import ( @@ -144,14 +144,16 @@ def get_jaxified_graph( return jax_funcify(fgraph) -def get_jaxified_logp(model: Model, negative_logp=True) -> Callable: +def get_jaxified_logp(model: Model, negative_logp=True) -> Callable[[PointType], jnp.ndarray]: model_logp = model.logp() if not negative_logp: model_logp = -model_logp logp_fn = get_jaxified_graph(inputs=model.value_vars, outputs=[model_logp]) + names = [v.name for v in model.value_vars] - def logp_fn_wrap(x): - return logp_fn(*x)[0] + def logp_fn_wrap(x: PointType) -> jnp.ndarray: + p = [x[n] for n in names] + return logp_fn(*p)[0] return logp_fn_wrap @@ -182,22 +184,22 @@ def _device_put(input, device: str): def _postprocess_samples( - jax_fn: Callable, - raw_mcmc_samples: List[TensorVariable], + jax_fn: Callable[[PointType], List[jnp.ndarray]], + raw_mcmc_samples: PointType, postprocessing_backend: Optional[Literal["cpu", "gpu"]] = None, postprocessing_vectorize: Literal["vmap", "scan"] = "scan", -) -> List[TensorVariable]: +) -> List[jnp.ndarray]: if postprocessing_vectorize == "scan": - t_raw_mcmc_samples = [jnp.swapaxes(t, 0, 1) for t in raw_mcmc_samples] + t_raw_mcmc_samples = {t: jnp.swapaxes(a, 0, 1) for t, 
a in raw_mcmc_samples.items()} jax_vfn = jax.vmap(jax_fn) _, outs = scan( - lambda _, x: ((), jax_vfn(*x)), + lambda _, x: ((), jax_vfn(x)), (), _device_put(t_raw_mcmc_samples, postprocessing_backend), ) return [jnp.swapaxes(t, 0, 1) for t in outs] elif postprocessing_vectorize == "vmap": - return jax.vmap(jax.vmap(jax_fn))(*_device_put(raw_mcmc_samples, postprocessing_backend)) + return jax.vmap(jax.vmap(jax_fn))(_device_put(raw_mcmc_samples, postprocessing_backend)) else: raise ValueError(f"Unrecognized postprocessing_vectorize: {postprocessing_vectorize}") @@ -238,19 +240,48 @@ def _blackjax_stats_to_dict(sample_stats, potential_energy) -> Dict: def _get_log_likelihood( model: Model, - samples, + samples: PointType, backend: Optional[Literal["cpu", "gpu"]] = None, postprocessing_vectorize: Literal["vmap", "scan"] = "scan", -) -> Dict: +) -> Dict[str, jnp.ndarray]: """Compute log-likelihood for all observations""" elemwise_logp = model.logp(model.observed_RVs, sum=False) jax_fn = get_jaxified_graph(inputs=model.value_vars, outputs=elemwise_logp) + names = [v.name for v in model.value_vars] + + def jax_fn_wrap(x: PointType) -> List[jnp.ndarray]: + p = [x[n] for n in names] + return jax_fn(*p) + result = _postprocess_samples( - jax_fn, samples, backend, postprocessing_vectorize=postprocessing_vectorize + jax_fn_wrap, samples, backend, postprocessing_vectorize=postprocessing_vectorize ) return {v.name: r for v, r in zip(model.observed_RVs, result)} +def _get_transformed_values( + model: Model, + samples: PointType, + vars_to_sample: List[str], + backend: Optional[Literal["cpu", "gpu"]] = None, + postprocessing_vectorize: Literal["vmap", "scan"] = "scan", +) -> Dict[str, jnp.ndarray]: + jax_fn = get_jaxified_graph(inputs=model.value_vars, outputs=vars_to_sample) + names = [v.name for v in model.value_vars] + + def jax_fn_wrap(x: PointType) -> List[jnp.ndarray]: + p = [x[n] for n in names] + return jax_fn(*p) + + result = _postprocess_samples( + jax_fn_wrap, + 
samples, + postprocessing_backend=backend, + postprocessing_vectorize=postprocessing_vectorize, + ) + return {v.name: r for v, r in zip(vars_to_sample, result)} + + def _get_batched_jittered_initial_points( model: Model, chains: int, @@ -258,7 +289,7 @@ def _get_batched_jittered_initial_points( random_seed: RandomSeed, jitter: bool = True, jitter_max_retries: int = 10, -) -> Union[np.ndarray, List[np.ndarray]]: +) -> Dict[str, np.ndarray]: """Get jittered initial point in format expected by NumPyro MCMC kernel Returns @@ -275,10 +306,10 @@ def _get_batched_jittered_initial_points( jitter=jitter, jitter_max_retries=jitter_max_retries, ) - initial_points_values = [list(initial_point.values()) for initial_point in initial_points] if chains == 1: - return initial_points_values[0] - return [np.stack(init_state) for init_state in zip(*initial_points_values)] + return initial_points[0] + else: + return {k: np.stack([ip[k] for ip in initial_points]) for k in initial_points[0].keys()} def _update_coords_and_dims( @@ -420,7 +451,12 @@ def sample_blackjax_nuts( if var_names is None: var_names = model.unobserved_value_vars - vars_to_sample = list(get_default_varnames(var_names, include_transformed=keep_untransformed)) + vars_to_sample = list( + get_default_varnames( + var_names, + include_transformed=keep_untransformed, + ) + ) (random_seed,) = _get_seeds_per_chain(random_seed, 1) @@ -435,7 +471,7 @@ def sample_blackjax_nuts( ) if chains == 1: - init_params = [np.stack(init_state) for init_state in zip(init_params)] + init_params = {k: np.stack([v]) for k, v in init_params.items()} logprob_fn = get_jaxified_logp(model) @@ -485,14 +521,14 @@ def sample_blackjax_nuts( logger.info(f"Sampling time = {tic3 - tic2}") logger.info("Transforming variables...") - jax_fn = get_jaxified_graph(inputs=model.value_vars, outputs=vars_to_sample) - result = _postprocess_samples( - jax_fn, - raw_mcmc_samples, - postprocessing_backend=postprocessing_backend, + + mcmc_samples = 
_get_transformed_values( + model=model, + vars_to_sample=vars_to_sample, + samples=raw_mcmc_samples, + backend=postprocessing_backend, postprocessing_vectorize=postprocessing_vectorize, ) - mcmc_samples = {v.name: r for v, r in zip(vars_to_sample, result)} mcmc_stats = _blackjax_stats_to_dict(stats, potential_energy) tic4 = datetime.now() logger.info(f"Transformation time = {tic4 - tic3}") @@ -713,14 +749,13 @@ def sample_numpyro_nuts( logger.info(f"Sampling time = {tic3 - tic2}") logger.info("Transforming variables...") - jax_fn = get_jaxified_graph(inputs=model.value_vars, outputs=vars_to_sample) - result = _postprocess_samples( - jax_fn, - raw_mcmc_samples, - postprocessing_backend=postprocessing_backend, + mcmc_samples = _get_transformed_values( + model=model, + vars_to_sample=vars_to_sample, + samples=raw_mcmc_samples, + backend=postprocessing_backend, postprocessing_vectorize=postprocessing_vectorize, ) - mcmc_samples = {v.name: r for v, r in zip(vars_to_sample, result)} tic4 = datetime.now() logger.info(f"Transformation time = {tic4 - tic3}") diff --git a/tests/sampling/test_jax.py b/tests/sampling/test_jax.py index 8ec95552a0..1c2cfc9b0f 100644 --- a/tests/sampling/test_jax.py +++ b/tests/sampling/test_jax.py @@ -459,3 +459,22 @@ def test_idata_contains_stats(sampler_name: str): for stat_var, stat_var_dims in stat_vars.items(): assert stat_var in stats.variables assert stats.get(stat_var).values.shape == stat_var_dims + + +def test_sample_numpyro_nuts_block_adapt(): + with pm.Model( + coords=dict(level=["Basement", "Floor"], county=[1, 2]), + ) as model: + # multilevel modelling + a = pm.Normal("a") + s = pm.HalfNormal("s") + a_g = pm.Normal("a_g", a, s, dims="level") + s_g = pm.HalfNormal("s_g") + a_ig = pm.Normal("a_ig", a_g, s_g, dims=("county", "level")) + trace = sample_numpyro_nuts( + nuts_kwargs=dict( + dense_mass=[ + ("a", "a_g"), + ] + ) + ) From aa552fcc1bbd62e5b171c3d30bfd70c0a5a03f2f Mon Sep 17 00:00:00 2001 From: Max Kochurov Date: Wed, 6 Dec 
2023 15:02:01 +0000 Subject: [PATCH 2/5] fix tests --- tests/sampling/test_jax.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/tests/sampling/test_jax.py b/tests/sampling/test_jax.py index 1c2cfc9b0f..9349502b9d 100644 --- a/tests/sampling/test_jax.py +++ b/tests/sampling/test_jax.py @@ -185,7 +185,7 @@ def test_get_log_likelihood(): b_true = trace.log_likelihood.b.values a = np.array(trace.posterior.a) sigma_log_ = np.log(np.array(trace.posterior.sigma)) - b_jax = _get_log_likelihood(model, [a, sigma_log_])["b"] + b_jax = _get_log_likelihood(model, {"a": a, "sigma_log__": sigma_log_})["b"] assert np.allclose(b_jax.reshape(-1), b_true.reshape(-1)) @@ -215,7 +215,7 @@ def test_get_jaxified_logp(): jax_fn = get_jaxified_logp(m) # This would underflow if not optimized - assert not np.isinf(jax_fn((np.array(5000.0), np.array(5000.0)))) + assert not np.isinf(jax_fn(dict(x=np.array(5000.0), y=np.array(5000.0)))) @pytest.fixture(scope="module") @@ -302,19 +302,19 @@ def test_get_batched_jittered_initial_points(): ips = _get_batched_jittered_initial_points( model=model, chains=1, random_seed=1, initvals=None, jitter=False ) - assert np.all(ips[0] == 0) + assert np.all(ips["x"] == 0) # Single chain ips = _get_batched_jittered_initial_points(model=model, chains=1, random_seed=1, initvals=None) - assert ips[0].shape == (2, 3) - assert np.all(ips[0] != 0) + assert ips["x"].shape == (2, 3) + assert np.all(ips["x"] != 0) # Multiple chains ips = _get_batched_jittered_initial_points(model=model, chains=2, random_seed=1, initvals=None) - assert ips[0].shape == (2, 2, 3) - assert np.all(ips[0][0] != ips[0][1]) + assert ips["x"].shape == (2, 2, 3) + assert np.all(ips["x"][0] != ips["x"][1]) @pytest.mark.parametrize( From b02e045f25e1f51910103aa01f09d7e6cb7f5ff6 Mon Sep 17 00:00:00 2001 From: Max Kochurov Date: Sat, 9 Dec 2023 16:19:59 +0000 Subject: [PATCH 3/5] revert breaking change to have the default behaviour --- pymc/sampling/jax.py | 53 
+++++++++++++++++++++++++++++++++----- tests/sampling/test_jax.py | 5 ++++ 2 files changed, 51 insertions(+), 7 deletions(-) diff --git a/pymc/sampling/jax.py b/pymc/sampling/jax.py index 296ebb4a54..6aba501d62 100644 --- a/pymc/sampling/jax.py +++ b/pymc/sampling/jax.py @@ -17,7 +17,17 @@ from datetime import datetime from functools import partial -from typing import Any, Callable, Dict, List, Literal, Optional, Sequence, Union +from typing import ( + Any, + Callable, + Dict, + List, + Literal, + Optional, + Sequence, + Union, + overload, +) import arviz as az import jax @@ -144,16 +154,45 @@ def get_jaxified_graph( return jax_funcify(fgraph) -def get_jaxified_logp(model: Model, negative_logp=True) -> Callable[[PointType], jnp.ndarray]: +@overload +def get_jaxified_logp( + model: Model, + negative_logp: bool = ..., + point_fn: Literal[False] = ..., +) -> Callable[[Sequence[np.ndarray]], jnp.ndarray]: + ... + + +@overload +def get_jaxified_logp( + model: Model, + negative_logp: bool = ..., + point_fn: Literal[True] = ..., +) -> Callable[[PointType], jnp.ndarray]: + ... 
+ + +def get_jaxified_logp( + model: Model, + negative_logp: bool = True, + point_fn: bool = False, +) -> Union[Callable[[PointType], jnp.ndarray], Callable[[Sequence[np.ndarray]], jnp.ndarray]]: model_logp = model.logp() if not negative_logp: model_logp = -model_logp logp_fn = get_jaxified_graph(inputs=model.value_vars, outputs=[model_logp]) names = [v.name for v in model.value_vars] - def logp_fn_wrap(x: PointType) -> jnp.ndarray: - p = [x[n] for n in names] - return logp_fn(*p)[0] + if point_fn: + + def logp_fn_wrap(x: PointType) -> jnp.ndarray: + p = [x[n] for n in names] + return logp_fn(*p)[0] + + else: + + def logp_fn_wrap(x: Sequence[np.ndarray]) -> jnp.ndarray: + return logp_fn(*x)[0] return logp_fn_wrap @@ -473,7 +512,7 @@ def sample_blackjax_nuts( if chains == 1: init_params = {k: np.stack([v]) for k, v in init_params.items()} - logprob_fn = get_jaxified_logp(model) + logprob_fn = get_jaxified_logp(model, point_fn=True) seed = jax.random.PRNGKey(random_seed) keys = jax.random.split(seed, chains) @@ -702,7 +741,7 @@ def sample_numpyro_nuts( random_seed=random_seed, ) - logp_fn = get_jaxified_logp(model, negative_logp=False) + logp_fn = get_jaxified_logp(model, negative_logp=False, point_fn=True) nuts_kwargs = _update_numpyro_nuts_kwargs(nuts_kwargs) nuts_kernel = NUTS( diff --git a/tests/sampling/test_jax.py b/tests/sampling/test_jax.py index 9349502b9d..f8b230c80f 100644 --- a/tests/sampling/test_jax.py +++ b/tests/sampling/test_jax.py @@ -217,6 +217,11 @@ def test_get_jaxified_logp(): # This would underflow if not optimized assert not np.isinf(jax_fn(dict(x=np.array(5000.0), y=np.array(5000.0)))) + # with point_fn=True the returned function takes a point dict keyed by value-variable name + jax_fn = get_jaxified_logp(m, point_fn=True) + # This would underflow if not optimized + assert not np.isinf(jax_fn(dict(x=np.array(5000.0), y=np.array(5000.0)))) + @pytest.fixture(scope="module") def model_test_idata_kwargs() -> pm.Model: From 759c00e7fe7fd71d36b800150f293c25d412eef9 Mon Sep 17 00:00:00 2001 From: Maxim 
Kochurov Date: Mon, 11 Dec 2023 16:54:10 +0300 Subject: [PATCH 4/5] Update tests/sampling/test_jax.py Co-authored-by: Ricardo Vieira <28983449+ricardoV94@users.noreply.github.com> --- tests/sampling/test_jax.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tests/sampling/test_jax.py b/tests/sampling/test_jax.py index f8b230c80f..203564bd7a 100644 --- a/tests/sampling/test_jax.py +++ b/tests/sampling/test_jax.py @@ -477,6 +477,8 @@ def test_sample_numpyro_nuts_block_adapt(): s_g = pm.HalfNormal("s_g") a_ig = pm.Normal("a_ig", a_g, s_g, dims=("county", "level")) trace = sample_numpyro_nuts( + tune=10, + draws=10, nuts_kwargs=dict( dense_mass=[ ("a", "a_g"), From 8eb42841e9438b095479c8106d19708fb3ef44f8 Mon Sep 17 00:00:00 2001 From: Max Kochurov Date: Mon, 11 Dec 2023 14:21:17 +0000 Subject: [PATCH 5/5] lint --- tests/sampling/test_jax.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/sampling/test_jax.py b/tests/sampling/test_jax.py index 203564bd7a..996bbd9c88 100644 --- a/tests/sampling/test_jax.py +++ b/tests/sampling/test_jax.py @@ -483,5 +483,5 @@ def test_sample_numpyro_nuts_block_adapt(): dense_mass=[ ("a", "a_g"), ] - ) + ), )