Skip to content

Commit 1fd7a11

Browse files
committed
Added placeholder/reminder to remove jax dependency when converting trace data to InferenceData
1 parent fdc3f38 commit 1fd7a11

File tree

1 file changed

+18
-7
lines changed

1 file changed

+18
-7
lines changed

pymc_experimental/inference/pathfinder/pathfinder.py

Lines changed: 18 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -122,6 +122,7 @@ def convert_flat_trace_to_idata(
122122
samples,
123123
include_transformed=False,
124124
postprocessing_backend="cpu",
125+
inference_backend="pymc",
125126
model=None,
126127
):
127128
model = modelcontext(model)
@@ -139,10 +140,21 @@ def convert_flat_trace_to_idata(
139140
var_names = model.unobserved_value_vars
140141
vars_to_sample = list(get_default_varnames(var_names, include_transformed=include_transformed))
141142
print("Transforming variables...", file=sys.stdout)
142-
jax_fn = get_jaxified_graph(inputs=model.value_vars, outputs=vars_to_sample)
143-
result = jax.vmap(jax.vmap(jax_fn))(
144-
*jax.device_put(list(trace.values()), jax.devices(postprocessing_backend)[0])
145-
)
143+
144+
if inference_backend == "pymc":
145+
# TODO: we need to remove JAX dependency as win32 users can now use Pathfinder with inference_backend="pymc".
146+
jax_fn = get_jaxified_graph(inputs=model.value_vars, outputs=vars_to_sample)
147+
result = jax.vmap(jax.vmap(jax_fn))(
148+
*jax.device_put(list(trace.values()), jax.devices(postprocessing_backend)[0])
149+
)
150+
elif inference_backend == "blackjax":
151+
jax_fn = get_jaxified_graph(inputs=model.value_vars, outputs=vars_to_sample)
152+
result = jax.vmap(jax.vmap(jax_fn))(
153+
*jax.device_put(list(trace.values()), jax.devices(postprocessing_backend)[0])
154+
)
155+
else:
156+
raise ValueError(f"Invalid inference_backend: {inference_backend}")
157+
146158
trace = {v.name: r for v, r in zip(vars_to_sample, result)}
147159
coords, dims = coords_and_dims_for_inferencedata(model)
148160
idata = az.from_dict(trace, dims=dims, coords=coords)
@@ -742,7 +754,6 @@ def fit_pathfinder(
742754
random_seed=random_seed,
743755
**pathfinder_kwargs,
744756
)
745-
746757
elif inference_backend == "blackjax":
747758
jitter_seed, pathfinder_seed, sample_seed = _get_seeds_per_chain(random_seed, 3)
748759
# TODO: extend initial points initialisation to blackjax
@@ -773,15 +784,15 @@ def fit_pathfinder(
773784
state=pathfinder_state,
774785
num_samples=num_draws,
775786
)
776-
777787
else:
778-
raise ValueError(f"Inference backend {inference_backend} not supported")
788+
raise ValueError(f"Invalid inference_backend: {inference_backend}")
779789

780790
print("Running pathfinder...", file=sys.stdout)
781791

782792
idata = convert_flat_trace_to_idata(
783793
pathfinder_samples,
784794
postprocessing_backend=postprocessing_backend,
795+
inference_backend=inference_backend,
785796
model=model,
786797
)
787798
return idata

0 commit comments

Comments
 (0)