|
17 | 17 |
|
18 | 18 | from datetime import datetime
|
19 | 19 | from functools import partial
|
20 |
| -from typing import Any, Callable, Dict, List, Literal, Optional, Sequence, Union |
| 20 | +from typing import ( |
| 21 | + Any, |
| 22 | + Callable, |
| 23 | + Dict, |
| 24 | + List, |
| 25 | + Literal, |
| 26 | + Optional, |
| 27 | + Sequence, |
| 28 | + Union, |
| 29 | + overload, |
| 30 | +) |
21 | 31 |
|
22 | 32 | import arviz as az
|
23 | 33 | import jax
|
@@ -144,16 +154,45 @@ def get_jaxified_graph(
|
144 | 154 | return jax_funcify(fgraph)
|
145 | 155 |
|
146 | 156 |
|
147 |
| -def get_jaxified_logp(model: Model, negative_logp=True) -> Callable[[PointType], jnp.ndarray]: |
| 157 | +@overload |
| 158 | +def get_jaxified_logp( |
| 159 | + model: Model, |
| 160 | + negative_logp: bool = ..., |
| 161 | + point_fn: Literal[False] = ..., |
| 162 | +) -> Callable[[Sequence[np.ndarray]], jnp.ndarray]: |
| 163 | + ... |
| 164 | + |
| 165 | + |
| 166 | +@overload |
| 167 | +def get_jaxified_logp( |
| 168 | + model: Model, |
| 169 | + negative_logp: bool = ..., |
| 170 | + point_fn: Literal[True] = ..., |
| 171 | +) -> Callable[[PointType], jnp.ndarray]: |
| 172 | + ... |
| 173 | + |
| 174 | + |
| 175 | +def get_jaxified_logp( |
| 176 | + model: Model, |
| 177 | + negative_logp: bool = True, |
| 178 | + point_fn: bool = False, |
| 179 | +) -> Union[Callable[[PointType], jnp.ndarray], Callable[[Sequence[np.ndarray]], jnp.ndarray]]: |
148 | 180 | model_logp = model.logp()
|
149 | 181 | if not negative_logp:
|
150 | 182 | model_logp = -model_logp
|
151 | 183 | logp_fn = get_jaxified_graph(inputs=model.value_vars, outputs=[model_logp])
|
152 | 184 | names = [v.name for v in model.value_vars]
|
153 | 185 |
|
154 |
| - def logp_fn_wrap(x: PointType) -> jnp.ndarray: |
155 |
| - p = [x[n] for n in names] |
156 |
| - return logp_fn(*p)[0] |
| 186 | + if point_fn: |
| 187 | + |
| 188 | + def logp_fn_wrap(x: PointType) -> jnp.ndarray: |
| 189 | + p = [x[n] for n in names] |
| 190 | + return logp_fn(*p)[0] |
| 191 | + |
| 192 | + else: |
| 193 | + |
| 194 | + def logp_fn_wrap(x: Sequence[np.ndarray]) -> jnp.ndarray: |
| 195 | + return logp_fn(*x)[0] |
157 | 196 |
|
158 | 197 | return logp_fn_wrap
|
159 | 198 |
|
@@ -473,7 +512,7 @@ def sample_blackjax_nuts(
|
473 | 512 | if chains == 1:
|
474 | 513 | init_params = {k: np.stack([v]) for k, v in init_params.items()}
|
475 | 514 |
|
476 |
| - logprob_fn = get_jaxified_logp(model) |
| 515 | + logprob_fn = get_jaxified_logp(model, point_fn=True) |
477 | 516 |
|
478 | 517 | seed = jax.random.PRNGKey(random_seed)
|
479 | 518 | keys = jax.random.split(seed, chains)
|
@@ -702,7 +741,7 @@ def sample_numpyro_nuts(
|
702 | 741 | random_seed=random_seed,
|
703 | 742 | )
|
704 | 743 |
|
705 |
| - logp_fn = get_jaxified_logp(model, negative_logp=False) |
| 744 | + logp_fn = get_jaxified_logp(model, negative_logp=False, point_fn=True) |
706 | 745 |
|
707 | 746 | nuts_kwargs = _update_numpyro_nuts_kwargs(nuts_kwargs)
|
708 | 747 | nuts_kernel = NUTS(
|
|
0 commit comments