@@ -122,6 +122,7 @@ def convert_flat_trace_to_idata(
122
122
samples ,
123
123
include_transformed = False ,
124
124
postprocessing_backend = "cpu" ,
125
+ inference_backend = "pymc" ,
125
126
model = None ,
126
127
):
127
128
model = modelcontext (model )
@@ -139,10 +140,21 @@ def convert_flat_trace_to_idata(
139
140
var_names = model .unobserved_value_vars
140
141
vars_to_sample = list (get_default_varnames (var_names , include_transformed = include_transformed ))
141
142
print ("Transforming variables..." , file = sys .stdout )
142
- jax_fn = get_jaxified_graph (inputs = model .value_vars , outputs = vars_to_sample )
143
- result = jax .vmap (jax .vmap (jax_fn ))(
144
- * jax .device_put (list (trace .values ()), jax .devices (postprocessing_backend )[0 ])
145
- )
143
+
144
+ if inference_backend == "pymc" :
145
+ # TODO: we need to remove JAX dependency as win32 users can now use Pathfinder with inference_backend="pymc".
146
+ jax_fn = get_jaxified_graph (inputs = model .value_vars , outputs = vars_to_sample )
147
+ result = jax .vmap (jax .vmap (jax_fn ))(
148
+ * jax .device_put (list (trace .values ()), jax .devices (postprocessing_backend )[0 ])
149
+ )
150
+ elif inference_backend == "blackjax" :
151
+ jax_fn = get_jaxified_graph (inputs = model .value_vars , outputs = vars_to_sample )
152
+ result = jax .vmap (jax .vmap (jax_fn ))(
153
+ * jax .device_put (list (trace .values ()), jax .devices (postprocessing_backend )[0 ])
154
+ )
155
+ else :
156
+ raise ValueError (f"Invalid inference_backend: { inference_backend } " )
157
+
146
158
trace = {v .name : r for v , r in zip (vars_to_sample , result )}
147
159
coords , dims = coords_and_dims_for_inferencedata (model )
148
160
idata = az .from_dict (trace , dims = dims , coords = coords )
@@ -742,7 +754,6 @@ def fit_pathfinder(
742
754
random_seed = random_seed ,
743
755
** pathfinder_kwargs ,
744
756
)
745
-
746
757
elif inference_backend == "blackjax" :
747
758
jitter_seed , pathfinder_seed , sample_seed = _get_seeds_per_chain (random_seed , 3 )
748
759
# TODO: extend initial points initialisation to blackjax
@@ -773,15 +784,15 @@ def fit_pathfinder(
773
784
state = pathfinder_state ,
774
785
num_samples = num_draws ,
775
786
)
776
-
777
787
else :
778
- raise ValueError (f"Inference backend { inference_backend } not supported " )
788
+ raise ValueError (f"Invalid inference_backend: { inference_backend } " )
779
789
780
790
print ("Running pathfinder..." , file = sys .stdout )
781
791
782
792
idata = convert_flat_trace_to_idata (
783
793
pathfinder_samples ,
784
794
postprocessing_backend = postprocessing_backend ,
795
+ inference_backend = inference_backend ,
785
796
model = model ,
786
797
)
787
798
return idata
0 commit comments