Skip to content

Commit 07fac5e

Browse files
committed
refactor jax internals to support more kwargs to numpyro
1 parent 2e05854 commit 07fac5e

File tree

2 files changed

+85
-31
lines changed

2 files changed

+85
-31
lines changed

pymc/sampling/jax.py

Lines changed: 66 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@
4343
find_observations,
4444
)
4545
from pymc.distributions.multivariate import PosDefMatrix
46-
from pymc.initial_point import StartDict
46+
from pymc.initial_point import PointType, StartDict
4747
from pymc.logprob.utils import CheckParameterValue
4848
from pymc.sampling.mcmc import _init_jitter
4949
from pymc.util import (
@@ -144,14 +144,16 @@ def get_jaxified_graph(
144144
return jax_funcify(fgraph)
145145

146146

147-
def get_jaxified_logp(model: Model, negative_logp=True) -> Callable[[PointType], jnp.ndarray]:
    """Build a JAX-callable scalar log-probability function for ``model``.

    Parameters
    ----------
    model
        Model whose ``logp()`` graph is compiled with ``get_jaxified_graph``.
    negative_logp
        NOTE: the flag reads inverted relative to its name. With the default
        ``True`` the graph returned by ``model.logp()`` is compiled unchanged;
        with ``False`` that graph is negated before compilation.

    Returns
    -------
    Callable
        Function that takes a point keyed by value-variable name and returns
        the first (and only) output of the compiled graph.
    """
    model_logp = model.logp()
    if not negative_logp:
        model_logp = -model_logp
    logp_fn = get_jaxified_graph(inputs=model.value_vars, outputs=[model_logp])
    # The compiled graph takes positional inputs in ``model.value_vars`` order.
    # Capture the names once so a dict-shaped point can be unpacked in that
    # same order on every call.
    names = [v.name for v in model.value_vars]

    def logp_fn_wrap(x: PointType) -> jnp.ndarray:
        p = [x[n] for n in names]
        return logp_fn(*p)[0]

    return logp_fn_wrap
157159

@@ -182,22 +184,22 @@ def _device_put(input, device: str):
182184

183185

184186
def _postprocess_samples(
    jax_fn: Callable[[PointType], List[jnp.ndarray]],
    raw_mcmc_samples: PointType,
    postprocessing_backend: Optional[Literal["cpu", "gpu"]] = None,
    postprocessing_vectorize: Literal["vmap", "scan"] = "scan",
) -> List[jnp.ndarray]:
    """Apply ``jax_fn`` across the two leading axes of every array in a point.

    Parameters
    ----------
    jax_fn
        Function that maps one point (a dict of arrays) to a list of arrays.
    raw_mcmc_samples
        Dict of arrays, each with at least two leading batch axes
        (presumably ``(chain, draw, ...)`` -- confirm against the callers).
    postprocessing_backend
        Forwarded to ``_device_put`` together with the samples before
        ``jax_fn`` is evaluated.
    postprocessing_vectorize
        ``"scan"`` loops over axis 1 with ``jax.lax.scan`` and vectorizes over
        axis 0 with ``jax.vmap``; ``"vmap"`` vectorizes over both axes at once.

    Returns
    -------
    list of jnp.ndarray
        Outputs of ``jax_fn`` with the two leading axes in the same order as
        in the input.

    Raises
    ------
    ValueError
        If ``postprocessing_vectorize`` is neither ``"scan"`` nor ``"vmap"``.
    """
    if postprocessing_vectorize == "scan":
        # ``scan`` iterates over the leading axis of its inputs, so bring
        # axis 1 to the front; the inner ``vmap`` then covers original axis 0.
        t_raw_mcmc_samples = {t: jnp.swapaxes(a, 0, 1) for t, a in raw_mcmc_samples.items()}
        jax_vfn = jax.vmap(jax_fn)
        _, outs = scan(
            lambda _, x: ((), jax_vfn(x)),
            (),
            _device_put(t_raw_mcmc_samples, postprocessing_backend),
        )
        # Undo the swap so the outputs line up with the input layout.
        return [jnp.swapaxes(t, 0, 1) for t in outs]
    elif postprocessing_vectorize == "vmap":
        return jax.vmap(jax.vmap(jax_fn))(_device_put(raw_mcmc_samples, postprocessing_backend))
    else:
        raise ValueError(f"Unrecognized postprocessing_vectorize: {postprocessing_vectorize}")
203205

@@ -238,27 +240,56 @@ def _blackjax_stats_to_dict(sample_stats, potential_energy) -> Dict:
238240

239241
def _get_log_likelihood(
    model: Model,
    samples: PointType,
    backend: Optional[Literal["cpu", "gpu"]] = None,
    postprocessing_vectorize: Literal["vmap", "scan"] = "scan",
) -> Dict[str, jnp.ndarray]:
    """Compute log-likelihood for all observations.

    Parameters
    ----------
    model
        Model providing ``observed_RVs`` and ``value_vars``.
    samples
        Dict of sampled arrays keyed by value-variable name.
    backend
        Passed to ``_postprocess_samples`` as its postprocessing backend.
    postprocessing_vectorize
        Vectorization strategy, forwarded to ``_postprocess_samples``.

    Returns
    -------
    dict
        One entry per observed RV, keyed by the RV's name.
    """
    elemwise_logp = model.logp(model.observed_RVs, sum=False)
    jax_fn = get_jaxified_graph(inputs=model.value_vars, outputs=elemwise_logp)
    # The compiled graph is positional; unpack dict-shaped points in
    # ``model.value_vars`` order.
    names = [v.name for v in model.value_vars]

    def jax_fn_wrap(x: PointType) -> List[jnp.ndarray]:
        p = [x[n] for n in names]
        return jax_fn(*p)

    result = _postprocess_samples(
        jax_fn_wrap, samples, backend, postprocessing_vectorize=postprocessing_vectorize
    )
    return {v.name: r for v, r in zip(model.observed_RVs, result)}
252260

253261

262+
def _get_transformed_values(
    model: Model,
    samples: PointType,
    vars_to_sample: List[TensorVariable],
    backend: Optional[Literal["cpu", "gpu"]] = None,
    postprocessing_vectorize: Literal["vmap", "scan"] = "scan",
) -> Dict[str, jnp.ndarray]:
    """Evaluate ``vars_to_sample`` at every sampled point.

    Parameters
    ----------
    model
        Model whose ``value_vars`` are the inputs of the compiled graph.
    samples
        Dict of sampled arrays keyed by value-variable name.
    vars_to_sample
        Graph variables to evaluate. They are used as the outputs of the
        compiled graph and their ``.name`` attributes key the result, so they
        are variables rather than plain strings.
    backend
        Passed to ``_postprocess_samples`` as ``postprocessing_backend``.
    postprocessing_vectorize
        Vectorization strategy, forwarded to ``_postprocess_samples``.

    Returns
    -------
    dict
        One entry per element of ``vars_to_sample``, keyed by its name.
    """
    jax_fn = get_jaxified_graph(inputs=model.value_vars, outputs=vars_to_sample)
    # The compiled graph is positional; unpack dict-shaped points in
    # ``model.value_vars`` order.
    names = [v.name for v in model.value_vars]

    def jax_fn_wrap(x: PointType) -> List[jnp.ndarray]:
        p = [x[n] for n in names]
        return jax_fn(*p)

    result = _postprocess_samples(
        jax_fn_wrap,
        samples,
        postprocessing_backend=backend,
        postprocessing_vectorize=postprocessing_vectorize,
    )
    return {v.name: r for v, r in zip(vars_to_sample, result)}
283+
284+
254285
def _get_batched_jittered_initial_points(
255286
model: Model,
256287
chains: int,
257288
initvals: Optional[Union[StartDict, Sequence[Optional[StartDict]]]],
258289
random_seed: RandomSeed,
259290
jitter: bool = True,
260291
jitter_max_retries: int = 10,
261-
) -> Union[np.ndarray, List[np.ndarray]]:
292+
) -> Dict[str, np.ndarray]:
262293
"""Get jittered initial point in format expected by NumPyro MCMC kernel
263294
264295
Returns
@@ -275,10 +306,10 @@ def _get_batched_jittered_initial_points(
275306
jitter=jitter,
276307
jitter_max_retries=jitter_max_retries,
277308
)
278-
initial_points_values = [list(initial_point.values()) for initial_point in initial_points]
279309
if chains == 1:
280-
return initial_points_values[0]
281-
return [np.stack(init_state) for init_state in zip(*initial_points_values)]
310+
return initial_points[0]
311+
else:
312+
return {k: np.stack([ip[k] for ip in initial_points]) for k in initial_points[0].keys()}
282313

283314

284315
def _update_coords_and_dims(
@@ -420,7 +451,12 @@ def sample_blackjax_nuts(
420451
if var_names is None:
421452
var_names = model.unobserved_value_vars
422453

423-
vars_to_sample = list(get_default_varnames(var_names, include_transformed=keep_untransformed))
454+
vars_to_sample = list(
455+
get_default_varnames(
456+
var_names,
457+
include_transformed=keep_untransformed,
458+
)
459+
)
424460

425461
(random_seed,) = _get_seeds_per_chain(random_seed, 1)
426462

@@ -435,7 +471,7 @@ def sample_blackjax_nuts(
435471
)
436472

437473
if chains == 1:
438-
init_params = [np.stack(init_state) for init_state in zip(init_params)]
474+
init_params = {k: np.stack([v]) for k, v in init_params.items()}
439475

440476
logprob_fn = get_jaxified_logp(model)
441477

@@ -485,14 +521,14 @@ def sample_blackjax_nuts(
485521
logger.info(f"Sampling time = {tic3 - tic2}")
486522

487523
logger.info("Transforming variables...")
488-
jax_fn = get_jaxified_graph(inputs=model.value_vars, outputs=vars_to_sample)
489-
result = _postprocess_samples(
490-
jax_fn,
491-
raw_mcmc_samples,
492-
postprocessing_backend=postprocessing_backend,
524+
525+
mcmc_samples = _get_transformed_values(
526+
model=model,
527+
vars_to_sample=vars_to_sample,
528+
samples=raw_mcmc_samples,
529+
backend=postprocessing_backend,
493530
postprocessing_vectorize=postprocessing_vectorize,
494531
)
495-
mcmc_samples = {v.name: r for v, r in zip(vars_to_sample, result)}
496532
mcmc_stats = _blackjax_stats_to_dict(stats, potential_energy)
497533
tic4 = datetime.now()
498534
logger.info(f"Transformation time = {tic4 - tic3}")
@@ -713,14 +749,13 @@ def sample_numpyro_nuts(
713749
logger.info(f"Sampling time = {tic3 - tic2}")
714750

715751
logger.info("Transforming variables...")
716-
jax_fn = get_jaxified_graph(inputs=model.value_vars, outputs=vars_to_sample)
717-
result = _postprocess_samples(
718-
jax_fn,
719-
raw_mcmc_samples,
720-
postprocessing_backend=postprocessing_backend,
752+
mcmc_samples = _get_transformed_values(
753+
model=model,
754+
vars_to_sample=vars_to_sample,
755+
samples=raw_mcmc_samples,
756+
backend=postprocessing_backend,
721757
postprocessing_vectorize=postprocessing_vectorize,
722758
)
723-
mcmc_samples = {v.name: r for v, r in zip(vars_to_sample, result)}
724759

725760
tic4 = datetime.now()
726761
logger.info(f"Transformation time = {tic4 - tic3}")

tests/sampling/test_jax.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -459,3 +459,22 @@ def test_idata_contains_stats(sampler_name: str):
459459
for stat_var, stat_var_dims in stat_vars.items():
460460
assert stat_var in stats.variables
461461
assert stats.get(stat_var).values.shape == stat_var_dims
462+
463+
464+
def test_sample_numpyro_nuts_block_adapt():
    """Smoke test: sampling runs when ``dense_mass`` is a list of name tuples.

    The structured ``dense_mass`` value is forwarded through ``nuts_kwargs``.
    No assertion is made on the draws; the test passes if sampling completes
    without raising.
    """
    coords = {"level": ["Basement", "Floor"], "county": [1, 2]}
    with pm.Model(coords=coords) as model:
        # multilevel modelling
        a = pm.Normal("a")
        s = pm.HalfNormal("s")
        a_g = pm.Normal("a_g", a, s, dims="level")
        s_g = pm.HalfNormal("s_g")
        a_ig = pm.Normal("a_ig", a_g, s_g, dims=("county", "level"))
        # NOTE(review): indentation was lost in this rendering; the call is
        # assumed to sit inside the model context because no ``model=`` is
        # passed -- confirm against the repository.
        trace = sample_numpyro_nuts(
            nuts_kwargs={"dense_mass": [("a", "a_g")]},
        )

0 commit comments

Comments
 (0)