Skip to content

Commit b02e045

Browse files
committed
revert breaking change to have the default behaviour
1 parent aa552fc commit b02e045

File tree

2 files changed

+51
-7
lines changed

2 files changed

+51
-7
lines changed

pymc/sampling/jax.py

Lines changed: 46 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,17 @@
1717

1818
from datetime import datetime
1919
from functools import partial
20-
from typing import Any, Callable, Dict, List, Literal, Optional, Sequence, Union
20+
from typing import (
21+
Any,
22+
Callable,
23+
Dict,
24+
List,
25+
Literal,
26+
Optional,
27+
Sequence,
28+
Union,
29+
overload,
30+
)
2131

2232
import arviz as az
2333
import jax
@@ -144,16 +154,45 @@ def get_jaxified_graph(
144154
return jax_funcify(fgraph)
145155

146156

147-
def get_jaxified_logp(model: Model, negative_logp=True) -> Callable[[PointType], jnp.ndarray]:
157+
@overload
158+
def get_jaxified_logp(
159+
model: Model,
160+
negative_logp: bool = ...,
161+
point_fn: Literal[False] = ...,
162+
) -> Callable[[Sequence[np.ndarray]], jnp.ndarray]:
163+
...
164+
165+
166+
@overload
167+
def get_jaxified_logp(
168+
model: Model,
169+
negative_logp: bool = ...,
170+
point_fn: Literal[True] = ...,
171+
) -> Callable[[PointType], jnp.ndarray]:
172+
...
173+
174+
175+
def get_jaxified_logp(
176+
model: Model,
177+
negative_logp: bool = True,
178+
point_fn: bool = False,
179+
) -> Union[Callable[[PointType], jnp.ndarray], Callable[[Sequence[np.ndarray]], jnp.ndarray]]:
148180
model_logp = model.logp()
149181
if not negative_logp:
150182
model_logp = -model_logp
151183
logp_fn = get_jaxified_graph(inputs=model.value_vars, outputs=[model_logp])
152184
names = [v.name for v in model.value_vars]
153185

154-
def logp_fn_wrap(x: PointType) -> jnp.ndarray:
155-
p = [x[n] for n in names]
156-
return logp_fn(*p)[0]
186+
if point_fn:
187+
188+
def logp_fn_wrap(x: PointType) -> jnp.ndarray:
189+
p = [x[n] for n in names]
190+
return logp_fn(*p)[0]
191+
192+
else:
193+
194+
def logp_fn_wrap(x: Sequence[np.ndarray]) -> jnp.ndarray:
195+
return logp_fn(*x)[0]
157196

158197
return logp_fn_wrap
159198

@@ -473,7 +512,7 @@ def sample_blackjax_nuts(
473512
if chains == 1:
474513
init_params = {k: np.stack([v]) for k, v in init_params.items()}
475514

476-
logprob_fn = get_jaxified_logp(model)
515+
logprob_fn = get_jaxified_logp(model, point_fn=True)
477516

478517
seed = jax.random.PRNGKey(random_seed)
479518
keys = jax.random.split(seed, chains)
@@ -702,7 +741,7 @@ def sample_numpyro_nuts(
702741
random_seed=random_seed,
703742
)
704743

705-
logp_fn = get_jaxified_logp(model, negative_logp=False)
744+
logp_fn = get_jaxified_logp(model, negative_logp=False, point_fn=True)
706745

707746
nuts_kwargs = _update_numpyro_nuts_kwargs(nuts_kwargs)
708747
nuts_kernel = NUTS(

tests/sampling/test_jax.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -217,6 +217,11 @@ def test_get_jaxified_logp():
217217
# This would underflow if not optimized
218218
assert not np.isinf(jax_fn(dict(x=np.array(5000.0), y=np.array(5000.0))))
219219

220+
# by default return array fn
221+
jax_fn = get_jaxified_logp(m, point_fn=True)
222+
# This would underflow if not optimized
223+
assert not np.isinf(jax_fn(dict(x=np.array(5000.0), y=np.array(5000.0))))
224+
220225

221226
@pytest.fixture(scope="module")
222227
def model_test_idata_kwargs() -> pm.Model:

0 commit comments

Comments
 (0)