Refactor jax internals to support dense_mass kwarg for numpyro #7050

Draft · wants to merge 5 commits into main
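This draft refactors the jitted logp and postprocessing helpers to pass samples around as point dicts keyed by value-variable name instead of positional sequences. Named sites are what numpyro's structured mass-matrix adaptation works with, so this unlocks passing `dense_mass` as a list of site-name tuples. A minimal sketch of the user-facing call this enables (the model and variable names are illustrative):

```python
import pymc as pm
from pymc.sampling.jax import sample_numpyro_nuts

with pm.Model():
    a = pm.Normal("a")
    b = pm.Normal("b", a, 1.0)
    # A list of site-name tuples asks numpyro to adapt one dense
    # mass-matrix block jointly over the listed sites; all other
    # sites keep a diagonal mass matrix.
    idata = sample_numpyro_nuts(
        tune=500,
        draws=500,
        nuts_kwargs=dict(dense_mass=[("a", "b")]),
    )
```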
142 changes: 108 additions & 34 deletions pymc/sampling/jax.py
@@ -17,7 +17,17 @@

from datetime import datetime
from functools import partial
from typing import Any, Callable, Dict, List, Literal, Optional, Sequence, Union
from typing import (
Any,
Callable,
Dict,
List,
Literal,
Optional,
Sequence,
Union,
overload,
)

import arviz as az
import jax
@@ -43,7 +53,7 @@
find_observations,
)
from pymc.distributions.multivariate import PosDefMatrix
from pymc.initial_point import StartDict
from pymc.initial_point import PointType, StartDict
from pymc.logprob.utils import CheckParameterValue
from pymc.sampling.mcmc import _init_jitter
from pymc.util import (
@@ -144,14 +154,45 @@ def get_jaxified_graph(
return jax_funcify(fgraph)


def get_jaxified_logp(model: Model, negative_logp=True) -> Callable:
@overload
def get_jaxified_logp(
model: Model,
negative_logp: bool = ...,
point_fn: Literal[False] = ...,
) -> Callable[[Sequence[np.ndarray]], jnp.ndarray]:
...


@overload
def get_jaxified_logp(
model: Model,
negative_logp: bool = ...,
point_fn: Literal[True] = ...,
) -> Callable[[PointType], jnp.ndarray]:
...


def get_jaxified_logp(
model: Model,
negative_logp: bool = True,
point_fn: bool = False,
) -> Union[Callable[[PointType], jnp.ndarray], Callable[[Sequence[np.ndarray]], jnp.ndarray]]:
model_logp = model.logp()
if not negative_logp:
model_logp = -model_logp
logp_fn = get_jaxified_graph(inputs=model.value_vars, outputs=[model_logp])
names = [v.name for v in model.value_vars]

def logp_fn_wrap(x):
return logp_fn(*x)[0]
if point_fn:

def logp_fn_wrap(x: PointType) -> jnp.ndarray:
p = [x[n] for n in names]
return logp_fn(*p)[0]

else:

def logp_fn_wrap(x: Sequence[np.ndarray]) -> jnp.ndarray:
return logp_fn(*x)[0]

return logp_fn_wrap
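For illustration, the two calling conventions side by side, assuming a hypothetical model whose value variables are named `x` and `y`:

```python
# Sketch only: x_val and y_val stand for arrays matching the value variables.
logp_seq = get_jaxified_logp(model)  # default: takes a sequence of arrays
logp_seq([x_val, y_val])

logp_point = get_jaxified_logp(model, point_fn=True)  # takes a point dict
logp_point({"x": x_val, "y": y_val})
```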

@@ -182,22 +223,22 @@ def _device_put(input, device: str):


def _postprocess_samples(
jax_fn: Callable,
raw_mcmc_samples: List[TensorVariable],
jax_fn: Callable[[PointType], List[jnp.ndarray]],
raw_mcmc_samples: PointType,
postprocessing_backend: Optional[Literal["cpu", "gpu"]] = None,
postprocessing_vectorize: Literal["vmap", "scan"] = "scan",
) -> List[TensorVariable]:
) -> List[jnp.ndarray]:
if postprocessing_vectorize == "scan":
t_raw_mcmc_samples = [jnp.swapaxes(t, 0, 1) for t in raw_mcmc_samples]
t_raw_mcmc_samples = {t: jnp.swapaxes(a, 0, 1) for t, a in raw_mcmc_samples.items()}
jax_vfn = jax.vmap(jax_fn)
_, outs = scan(
lambda _, x: ((), jax_vfn(*x)),
lambda _, x: ((), jax_vfn(x)),
(),
_device_put(t_raw_mcmc_samples, postprocessing_backend),
)
return [jnp.swapaxes(t, 0, 1) for t in outs]
elif postprocessing_vectorize == "vmap":
return jax.vmap(jax.vmap(jax_fn))(*_device_put(raw_mcmc_samples, postprocessing_backend))
return jax.vmap(jax.vmap(jax_fn))(_device_put(raw_mcmc_samples, postprocessing_backend))
else:
raise ValueError(f"Unrecognized postprocessing_vectorize: {postprocessing_vectorize}")

@@ -238,27 +279,56 @@ def _blackjax_stats_to_dict(sample_stats, potential_energy) -> Dict:

def _get_log_likelihood(
model: Model,
samples,
samples: PointType,
backend: Optional[Literal["cpu", "gpu"]] = None,
postprocessing_vectorize: Literal["vmap", "scan"] = "scan",
) -> Dict:
) -> Dict[str, jnp.ndarray]:
"""Compute log-likelihood for all observations"""
elemwise_logp = model.logp(model.observed_RVs, sum=False)
jax_fn = get_jaxified_graph(inputs=model.value_vars, outputs=elemwise_logp)
names = [v.name for v in model.value_vars]

def jax_fn_wrap(x: PointType) -> List[jnp.ndarray]:
p = [x[n] for n in names]
return jax_fn(*p)

result = _postprocess_samples(
jax_fn, samples, backend, postprocessing_vectorize=postprocessing_vectorize
jax_fn_wrap, samples, backend, postprocessing_vectorize=postprocessing_vectorize
)
return {v.name: r for v, r in zip(model.observed_RVs, result)}


def _get_transformed_values(
model: Model,
samples: PointType,
    vars_to_sample: List[TensorVariable],
backend: Optional[Literal["cpu", "gpu"]] = None,
postprocessing_vectorize: Literal["vmap", "scan"] = "scan",
) -> Dict[str, jnp.ndarray]:
jax_fn = get_jaxified_graph(inputs=model.value_vars, outputs=vars_to_sample)
names = [v.name for v in model.value_vars]

def jax_fn_wrap(x: PointType) -> List[jnp.ndarray]:
p = [x[n] for n in names]
return jax_fn(*p)

result = _postprocess_samples(
jax_fn_wrap,
samples,
postprocessing_backend=backend,
postprocessing_vectorize=postprocessing_vectorize,
)
return {v.name: r for v, r in zip(vars_to_sample, result)}
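This helper centralises the "Transforming variables..." step that the blackjax and numpyro paths previously duplicated. Roughly, under the shapes assumed here:

```python
# raw_mcmc_samples: {"x": (chains, draws, *shape), ...} straight from the kernel
mcmc_samples = _get_transformed_values(
    model=model,
    samples=raw_mcmc_samples,
    vars_to_sample=vars_to_sample,  # deterministics and untransformed values
)
# -> {var.name: array of shape (chains, draws, *shape)}
```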


def _get_batched_jittered_initial_points(
model: Model,
chains: int,
initvals: Optional[Union[StartDict, Sequence[Optional[StartDict]]]],
random_seed: RandomSeed,
jitter: bool = True,
jitter_max_retries: int = 10,
) -> Union[np.ndarray, List[np.ndarray]]:
) -> Dict[str, np.ndarray]:
"""Get jittered initial point in format expected by NumPyro MCMC kernel

Returns
Expand All @@ -275,10 +345,10 @@ def _get_batched_jittered_initial_points(
jitter=jitter,
jitter_max_retries=jitter_max_retries,
)
initial_points_values = [list(initial_point.values()) for initial_point in initial_points]
if chains == 1:
return initial_points_values[0]
return [np.stack(init_state) for init_state in zip(*initial_points_values)]
return initial_points[0]
else:
return {k: np.stack([ip[k] for ip in initial_points]) for k in initial_points[0].keys()}
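So for two chains and a value variable `x` of shape `(2, 3)`, the result is `{"x": array of shape (2, 2, 3)}`, while `chains=1` returns the single point un-stacked; the updated tests below exercise exactly these cases.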


def _update_coords_and_dims(
@@ -420,7 +490,12 @@ def sample_blackjax_nuts(
if var_names is None:
var_names = model.unobserved_value_vars

vars_to_sample = list(get_default_varnames(var_names, include_transformed=keep_untransformed))
vars_to_sample = list(
get_default_varnames(
var_names,
include_transformed=keep_untransformed,
)
)

(random_seed,) = _get_seeds_per_chain(random_seed, 1)

@@ -435,9 +510,9 @@
)

if chains == 1:
init_params = [np.stack(init_state) for init_state in zip(init_params)]
init_params = {k: np.stack([v]) for k, v in init_params.items()}

logprob_fn = get_jaxified_logp(model)
logprob_fn = get_jaxified_logp(model, point_fn=True)

seed = jax.random.PRNGKey(random_seed)
keys = jax.random.split(seed, chains)
@@ -485,14 +560,14 @@
logger.info(f"Sampling time = {tic3 - tic2}")

logger.info("Transforming variables...")
jax_fn = get_jaxified_graph(inputs=model.value_vars, outputs=vars_to_sample)
result = _postprocess_samples(
jax_fn,
raw_mcmc_samples,
postprocessing_backend=postprocessing_backend,

mcmc_samples = _get_transformed_values(
model=model,
vars_to_sample=vars_to_sample,
samples=raw_mcmc_samples,
backend=postprocessing_backend,
postprocessing_vectorize=postprocessing_vectorize,
)
mcmc_samples = {v.name: r for v, r in zip(vars_to_sample, result)}
mcmc_stats = _blackjax_stats_to_dict(stats, potential_energy)
tic4 = datetime.now()
logger.info(f"Transformation time = {tic4 - tic3}")
@@ -666,7 +741,7 @@ def sample_numpyro_nuts(
random_seed=random_seed,
)

logp_fn = get_jaxified_logp(model, negative_logp=False)
logp_fn = get_jaxified_logp(model, negative_logp=False, point_fn=True)

nuts_kwargs = _update_numpyro_nuts_kwargs(nuts_kwargs)
nuts_kernel = NUTS(
@@ -713,14 +788,13 @@
logger.info(f"Sampling time = {tic3 - tic2}")

logger.info("Transforming variables...")
jax_fn = get_jaxified_graph(inputs=model.value_vars, outputs=vars_to_sample)
result = _postprocess_samples(
jax_fn,
raw_mcmc_samples,
postprocessing_backend=postprocessing_backend,
mcmc_samples = _get_transformed_values(
model=model,
vars_to_sample=vars_to_sample,
samples=raw_mcmc_samples,
backend=postprocessing_backend,
postprocessing_vectorize=postprocessing_vectorize,
)
mcmc_samples = {v.name: r for v, r in zip(vars_to_sample, result)}

tic4 = datetime.now()
logger.info(f"Transformation time = {tic4 - tic3}")
40 changes: 33 additions & 7 deletions tests/sampling/test_jax.py
@@ -185,7 +185,7 @@ def test_get_log_likelihood():
b_true = trace.log_likelihood.b.values
a = np.array(trace.posterior.a)
sigma_log_ = np.log(np.array(trace.posterior.sigma))
b_jax = _get_log_likelihood(model, [a, sigma_log_])["b"]
b_jax = _get_log_likelihood(model, {"a": a, "sigma_log__": sigma_log_})["b"]

assert np.allclose(b_jax.reshape(-1), b_true.reshape(-1))

@@ -215,7 +215,12 @@ def test_get_jaxified_logp():

jax_fn = get_jaxified_logp(m)
# This would underflow if not optimized
    assert not np.isinf(jax_fn((np.array(5000.0), np.array(5000.0))))

    # With point_fn=True the returned function takes a point dict instead
    jax_fn = get_jaxified_logp(m, point_fn=True)
    # This would underflow if not optimized
    assert not np.isinf(jax_fn(dict(x=np.array(5000.0), y=np.array(5000.0))))


@pytest.fixture(scope="module")
@@ -302,19 +307,19 @@ def test_get_batched_jittered_initial_points():
ips = _get_batched_jittered_initial_points(
model=model, chains=1, random_seed=1, initvals=None, jitter=False
)
assert np.all(ips[0] == 0)
assert np.all(ips["x"] == 0)

# Single chain
ips = _get_batched_jittered_initial_points(model=model, chains=1, random_seed=1, initvals=None)

assert ips[0].shape == (2, 3)
assert np.all(ips[0] != 0)
assert ips["x"].shape == (2, 3)
assert np.all(ips["x"] != 0)

# Multiple chains
ips = _get_batched_jittered_initial_points(model=model, chains=2, random_seed=1, initvals=None)

assert ips[0].shape == (2, 2, 3)
assert np.all(ips[0][0] != ips[0][1])
assert ips["x"].shape == (2, 2, 3)
assert np.all(ips["x"][0] != ips["x"][1])


@pytest.mark.parametrize(
@@ -459,3 +464,24 @@ def test_idata_contains_stats(sampler_name: str):
for stat_var, stat_var_dims in stat_vars.items():
assert stat_var in stats.variables
assert stats.get(stat_var).values.shape == stat_var_dims


def test_sample_numpyro_nuts_block_adapt():
with pm.Model(
coords=dict(level=["Basement", "Floor"], county=[1, 2]),
) as model:
# multilevel modelling
a = pm.Normal("a")
s = pm.HalfNormal("s")
a_g = pm.Normal("a_g", a, s, dims="level")
s_g = pm.HalfNormal("s_g")
a_ig = pm.Normal("a_ig", a_g, s_g, dims=("county", "level"))
trace = sample_numpyro_nuts(
tune=10,
draws=10,
nuts_kwargs=dict(
dense_mass=[
("a", "a_g"),
]
),
)
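For reference, a sketch of the numpyro kernel wiring this test exercises: with a structured `dense_mass`, numpyro's potential function must accept a `{site_name: value}` dict, which is exactly what `get_jaxified_logp(..., point_fn=True)` now returns (`logp_fn` below is assumed to come from that call):

```python
from numpyro.infer import NUTS

kernel = NUTS(
    potential_fn=logp_fn,  # takes {"a": ..., "a_g": ..., ...}
    # Adapt one dense mass-matrix block jointly over "a" and "a_g";
    # all remaining sites keep a diagonal mass matrix.
    dense_mass=[("a", "a_g")],
)
```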