43
43
find_observations ,
44
44
)
45
45
from pymc .distributions .multivariate import PosDefMatrix
46
- from pymc .initial_point import StartDict
46
+ from pymc .initial_point import PointType , StartDict
47
47
from pymc .logprob .utils import CheckParameterValue
48
48
from pymc .sampling .mcmc import _init_jitter
49
49
from pymc .util import (
@@ -144,14 +144,16 @@ def get_jaxified_graph(
144
144
return jax_funcify (fgraph )
145
145
146
146
147
def get_jaxified_logp(model: Model, negative_logp=True) -> Callable[[PointType], jnp.ndarray]:
    """Compile the model's joint log-probability graph into a JAX callable.

    Parameters
    ----------
    model
        Model whose ``logp()`` graph is compiled.
    negative_logp
        When true (the default) the graph returned by ``model.logp()`` is
        compiled unchanged; when false that graph is negated first.

    Returns
    -------
    Callable
        Function that takes a mapping from value-variable name to value and
        returns the first output of the compiled graph.
    """
    logp_graph = model.logp()
    if not negative_logp:
        # NOTE: the sign is flipped when the flag is *False*, not when it is True.
        logp_graph = -logp_graph
    compiled_logp = get_jaxified_graph(inputs=model.value_vars, outputs=[logp_graph])
    value_names = [value_var.name for value_var in model.value_vars]

    def logp_fn_wrap(x: PointType) -> jnp.ndarray:
        # The compiled graph takes positional inputs in ``model.value_vars``
        # order, so pull the values out of the mapping in that same order.
        ordered_values = [x[name] for name in value_names]
        outputs = compiled_logp(*ordered_values)
        return outputs[0]

    return logp_fn_wrap
157
159
@@ -182,22 +184,22 @@ def _device_put(input, device: str):
182
184
183
185
184
186
def _postprocess_samples(
    jax_fn: Callable[[PointType], List[jnp.ndarray]],
    raw_mcmc_samples: PointType,
    postprocessing_backend: Optional[Literal["cpu", "gpu"]] = None,
    postprocessing_vectorize: Literal["vmap", "scan"] = "scan",
) -> List[jnp.ndarray]:
    """Apply ``jax_fn`` across the two leading axes of every sample array.

    Parameters
    ----------
    jax_fn
        Function mapping a single point (name -> value) to a list of arrays.
    raw_mcmc_samples
        Mapping from name to sample array.
        NOTE(review): presumably each array is laid out (chain, draw, ...) --
        confirm against the callers.
    postprocessing_backend
        Forwarded to ``_device_put`` together with the samples.
    postprocessing_vectorize
        ``"vmap"`` applies a doubly-vmapped ``jax_fn`` in one call; ``"scan"``
        scans over one axis and vmaps over the other.

    Raises
    ------
    ValueError
        If ``postprocessing_vectorize`` is neither ``"scan"`` nor ``"vmap"``.
    """
    if postprocessing_vectorize == "vmap":
        doubly_vectorized = jax.vmap(jax.vmap(jax_fn))
        return doubly_vectorized(_device_put(raw_mcmc_samples, postprocessing_backend))

    if postprocessing_vectorize != "scan":
        raise ValueError(f"Unrecognized postprocessing_vectorize: { postprocessing_vectorize } ")

    # Swap axes 0 and 1 so that scan iterates over what was axis 1 while the
    # vmapped function handles what was axis 0.
    swapped_samples = {
        name: jnp.swapaxes(values, 0, 1) for name, values in raw_mcmc_samples.items()
    }
    vectorized_fn = jax.vmap(jax_fn)

    def _scan_step(carry, point):
        # No state is threaded through the scan; only the stacked outputs matter.
        return (), vectorized_fn(point)

    _, stacked_outputs = scan(
        _scan_step,
        (),
        _device_put(swapped_samples, postprocessing_backend),
    )
    # Undo the earlier swap so outputs have the same leading layout as the inputs.
    return [jnp.swapaxes(output, 0, 1) for output in stacked_outputs]
203
205
@@ -238,27 +240,56 @@ def _blackjax_stats_to_dict(sample_stats, potential_energy) -> Dict:
238
240
239
241
def _get_log_likelihood(
    model: Model,
    samples: PointType,
    backend: Optional[Literal["cpu", "gpu"]] = None,
    postprocessing_vectorize: Literal["vmap", "scan"] = "scan",
) -> Dict[str, jnp.ndarray]:
    """Evaluate the element-wise log-likelihood of every observed variable.

    The un-summed logp graphs of ``model.observed_RVs`` are compiled to JAX and
    evaluated over ``samples`` through ``_postprocess_samples``.

    Returns
    -------
    dict
        Observed-variable name -> evaluated log-likelihood array.
    """
    elemwise_logp_graphs = model.logp(model.observed_RVs, sum=False)
    compiled_fn = get_jaxified_graph(inputs=model.value_vars, outputs=elemwise_logp_graphs)
    value_names = [value_var.name for value_var in model.value_vars]

    def jax_fn_wrap(x: PointType) -> List[jnp.ndarray]:
        # Feed the compiled graph positionally, in ``model.value_vars`` order.
        ordered_values = [x[name] for name in value_names]
        return compiled_fn(*ordered_values)

    log_likelihoods = _postprocess_samples(
        jax_fn_wrap,
        samples,
        backend,
        postprocessing_vectorize=postprocessing_vectorize,
    )
    observed_names = (observed_rv.name for observed_rv in model.observed_RVs)
    return dict(zip(observed_names, log_likelihoods))
252
260
253
261
262
def _get_transformed_values(
    model: Model,
    samples: PointType,
    vars_to_sample: List[TensorVariable],
    backend: Optional[Literal["cpu", "gpu"]] = None,
    postprocessing_vectorize: Literal["vmap", "scan"] = "scan",
) -> Dict[str, jnp.ndarray]:
    """Evaluate the requested model variables for every raw sample.

    Parameters
    ----------
    model
        Model whose ``value_vars`` are the inputs of the compiled graph.
    samples
        Mapping from value-variable name to sample array.
    vars_to_sample
        Graph variables to evaluate. They are used as the outputs of the
        compiled graph and their ``.name`` attributes key the result, so they
        must be variables rather than plain strings.
    backend
        Passed to ``_postprocess_samples`` as ``postprocessing_backend``.
    postprocessing_vectorize
        Vectorization strategy forwarded to ``_postprocess_samples``.

    Returns
    -------
    dict
        Variable name -> evaluated array, in the order of ``vars_to_sample``.
    """
    jax_fn = get_jaxified_graph(inputs=model.value_vars, outputs=vars_to_sample)
    names = [v.name for v in model.value_vars]

    def jax_fn_wrap(x: PointType) -> List[jnp.ndarray]:
        # The compiled graph takes positional inputs in ``model.value_vars``
        # order; look the values up by name to build that ordering.
        p = [x[n] for n in names]
        return jax_fn(*p)

    result = _postprocess_samples(
        jax_fn_wrap,
        samples,
        postprocessing_backend=backend,
        postprocessing_vectorize=postprocessing_vectorize,
    )
    return {v.name: r for v, r in zip(vars_to_sample, result)}
283
+
284
+
254
285
def _get_batched_jittered_initial_points (
255
286
model : Model ,
256
287
chains : int ,
257
288
initvals : Optional [Union [StartDict , Sequence [Optional [StartDict ]]]],
258
289
random_seed : RandomSeed ,
259
290
jitter : bool = True ,
260
291
jitter_max_retries : int = 10 ,
261
- ) -> Union [ np . ndarray , List [ np .ndarray ] ]:
292
+ ) -> Dict [ str , np .ndarray ]:
262
293
"""Get jittered initial point in format expected by NumPyro MCMC kernel
263
294
264
295
Returns
@@ -275,10 +306,10 @@ def _get_batched_jittered_initial_points(
275
306
jitter = jitter ,
276
307
jitter_max_retries = jitter_max_retries ,
277
308
)
278
- initial_points_values = [list (initial_point .values ()) for initial_point in initial_points ]
279
309
if chains == 1 :
280
- return initial_points_values [0 ]
281
- return [np .stack (init_state ) for init_state in zip (* initial_points_values )]
310
+ return initial_points [0 ]
311
+ else :
312
+ return {k : np .stack ([ip [k ] for ip in initial_points ]) for k in initial_points [0 ].keys ()}
282
313
283
314
284
315
def _update_coords_and_dims (
@@ -420,7 +451,12 @@ def sample_blackjax_nuts(
420
451
if var_names is None :
421
452
var_names = model .unobserved_value_vars
422
453
423
- vars_to_sample = list (get_default_varnames (var_names , include_transformed = keep_untransformed ))
454
+ vars_to_sample = list (
455
+ get_default_varnames (
456
+ var_names ,
457
+ include_transformed = keep_untransformed ,
458
+ )
459
+ )
424
460
425
461
(random_seed ,) = _get_seeds_per_chain (random_seed , 1 )
426
462
@@ -435,7 +471,7 @@ def sample_blackjax_nuts(
435
471
)
436
472
437
473
if chains == 1 :
438
- init_params = [ np .stack (init_state ) for init_state in zip ( init_params )]
474
+ init_params = { k : np .stack ([ v ] ) for k , v in init_params . items ()}
439
475
440
476
logprob_fn = get_jaxified_logp (model )
441
477
@@ -485,14 +521,14 @@ def sample_blackjax_nuts(
485
521
logger .info (f"Sampling time = { tic3 - tic2 } " )
486
522
487
523
logger .info ("Transforming variables..." )
488
- jax_fn = get_jaxified_graph (inputs = model .value_vars , outputs = vars_to_sample )
489
- result = _postprocess_samples (
490
- jax_fn ,
491
- raw_mcmc_samples ,
492
- postprocessing_backend = postprocessing_backend ,
524
+
525
+ mcmc_samples = _get_transformed_values (
526
+ model = model ,
527
+ vars_to_sample = vars_to_sample ,
528
+ samples = raw_mcmc_samples ,
529
+ backend = postprocessing_backend ,
493
530
postprocessing_vectorize = postprocessing_vectorize ,
494
531
)
495
- mcmc_samples = {v .name : r for v , r in zip (vars_to_sample , result )}
496
532
mcmc_stats = _blackjax_stats_to_dict (stats , potential_energy )
497
533
tic4 = datetime .now ()
498
534
logger .info (f"Transformation time = { tic4 - tic3 } " )
@@ -713,14 +749,13 @@ def sample_numpyro_nuts(
713
749
logger .info (f"Sampling time = { tic3 - tic2 } " )
714
750
715
751
logger .info ("Transforming variables..." )
716
- jax_fn = get_jaxified_graph ( inputs = model . value_vars , outputs = vars_to_sample )
717
- result = _postprocess_samples (
718
- jax_fn ,
719
- raw_mcmc_samples ,
720
- postprocessing_backend = postprocessing_backend ,
752
+ mcmc_samples = _get_transformed_values (
753
+ model = model ,
754
+ vars_to_sample = vars_to_sample ,
755
+ samples = raw_mcmc_samples ,
756
+ backend = postprocessing_backend ,
721
757
postprocessing_vectorize = postprocessing_vectorize ,
722
758
)
723
- mcmc_samples = {v .name : r for v , r in zip (vars_to_sample , result )}
724
759
725
760
tic4 = datetime .now ()
726
761
logger .info (f"Transformation time = { tic4 - tic3 } " )
0 commit comments