Description
I ran my standard test example using version 0.1.0 and got an error from `pm.sample` (see below). Any advice on how to proceed?
TypeError Traceback (most recent call last)
Cell In[19], line 4
2 sampler = 'numpyro'
3 with pymc_model:
----> 4 idata = pm.sample(nuts_sampler=sampler, tune=500, draws=1000, chains=4, progressbar=True, target_accept=0.95)
6 # idate is an "inference data object", provided by the sampler. Sampling statistics are provided in idata.sample_stats. For more information
7 # about the sampling statistics, see
8 # https://www.pymc.io/projects/docs/en/v3/pymc-examples/examples/diagnostics_and_criticism/sampler-stats.html
File ~/miniforge3/envs/BSSM/lib/python3.11/site-packages/pymc/sampling/mcmc.py:691, in sample(draws, tune, chains, cores, random_seed, progressbar, progressbar_theme, step, var_names, nuts_sampler, initvals, init, jitter_max_retries, n_init, trace, discard_tuned_samples, compute_convergence_checks, keep_warning_stat, return_inferencedata, idata_kwargs, nuts_sampler_kwargs, callback, mp_ctx, model, **kwargs)
687 if not isinstance(step, NUTS):
688 raise ValueError(
689 "Model can not be sampled with NUTS alone. Your model is probably not continuous."
690 )
--> 691 return _sample_external_nuts(
692 sampler=nuts_sampler,
693 draws=draws,
694 tune=tune,
695 chains=chains,
696 target_accept=kwargs.pop("nuts", {}).get("target_accept", 0.8),
697 random_seed=random_seed,
698 initvals=initvals,
699 model=model,
700 var_names=var_names,
701 progressbar=progressbar,
702 idata_kwargs=idata_kwargs,
703 nuts_sampler_kwargs=nuts_sampler_kwargs,
704 **kwargs,
705 )
707 if isinstance(step, list):
708 step = CompoundStep(step)
File ~/miniforge3/envs/BSSM/lib/python3.11/site-packages/pymc/sampling/mcmc.py:351, in _sample_external_nuts(sampler, draws, tune, chains, target_accept, random_seed, initvals, model, var_names, progressbar, idata_kwargs, nuts_sampler_kwargs, **kwargs)
348 elif sampler in ("numpyro", "blackjax"):
349 import pymc.sampling.jax as pymc_jax
--> 351 idata = pymc_jax.sample_jax_nuts(
352 draws=draws,
353 tune=tune,
354 chains=chains,
355 target_accept=target_accept,
356 random_seed=random_seed,
357 initvals=initvals,
358 model=model,
359 var_names=var_names,
360 progressbar=progressbar,
361 nuts_sampler=sampler,
362 idata_kwargs=idata_kwargs,
363 **nuts_sampler_kwargs,
364 )
365 return idata
367 else:
File ~/miniforge3/envs/BSSM/lib/python3.11/site-packages/pymc/sampling/jax.py:567, in sample_jax_nuts(draws, tune, chains, target_accept, random_seed, initvals, jitter, model, var_names, nuts_kwargs, progressbar, keep_untransformed, chain_method, postprocessing_backend, postprocessing_vectorize, postprocessing_chunks, idata_kwargs, compute_convergence_checks, nuts_sampler)
564 raise ValueError(f"{nuts_sampler=} not recognized")
566 tic1 = datetime.now()
--> 567 raw_mcmc_samples, sample_stats, library = sampler_fn(
568 model=model,
569 target_accept=target_accept,
570 tune=tune,
571 draws=draws,
572 chains=chains,
573 chain_method=chain_method,
574 progressbar=progressbar,
575 random_seed=random_seed,
576 initial_points=initial_points,
577 nuts_kwargs=nuts_kwargs,
578 )
579 tic2 = datetime.now()
581 jax_fn = get_jaxified_graph(inputs=model.value_vars, outputs=vars_to_sample)
File ~/miniforge3/envs/BSSM/lib/python3.11/site-packages/pymc/sampling/jax.py:458, in _sample_numpyro_nuts(model, target_accept, tune, draws, chains, chain_method, progressbar, random_seed, initial_points, nuts_kwargs)
454 import numpyro
456 from numpyro.infer import MCMC, NUTS
--> 458 logp_fn = get_jaxified_logp(model, negative_logp=False)
460 nuts_kwargs.setdefault("adapt_step_size", True)
461 nuts_kwargs.setdefault("adapt_mass_matrix", True)
File ~/miniforge3/envs/BSSM/lib/python3.11/site-packages/pymc/sampling/jax.py:153, in get_jaxified_logp(model, negative_logp)
151 if not negative_logp:
152 model_logp = -model_logp
--> 153 logp_fn = get_jaxified_graph(inputs=model.value_vars, outputs=[model_logp])
155 def logp_fn_wrap(x):
156 return logp_fn(*x)[0]
File ~/miniforge3/envs/BSSM/lib/python3.11/site-packages/pymc/sampling/jax.py:128, in get_jaxified_graph(inputs, outputs)
122 def get_jaxified_graph(
123 inputs: list[TensorVariable] | None = None,
124 outputs: list[TensorVariable] | None = None,
125 ) -> list[TensorVariable]:
126 """Compile an PyTensor graph into an optimized JAX function"""
--> 128 graph = _replace_shared_variables(outputs) if outputs is not None else None
130 fgraph = FunctionGraph(inputs=inputs, outputs=graph, clone=True)
131 # We need to add a Supervisor to the fgraph to be able to run the
132 # JAX sequential optimizer without warnings. We made sure there
133 # are no mutable input variables, so we only need to check for
134 # "destroyers". This should be automatically handled by PyTensor
135 # once aesara-devs/aesara#637 is fixed.
File ~/miniforge3/envs/BSSM/lib/python3.11/site-packages/pymc/sampling/jax.py:118, in _replace_shared_variables(graph)
111 raise ValueError(
112 "Graph contains shared variables with default_update which cannot "
113 "be safely replaced."
114 )
116 replacements = {var: pt.constant(var.get_value(borrow=True)) for var in shared_variables}
--> 118 new_graph = clone_replace(graph, replace=replacements)
119 return new_graph
File ~/miniforge3/envs/BSSM/lib/python3.11/site-packages/pytensor/graph/replace.py:85, in clone_replace(output, replace, **rebuild_kwds)
82 _, _outs, _ = rebuild_collect_shared(output, [], tmp_replace, [], **rebuild_kwds)
84 # TODO Explain why we call it twice ?!
---> 85 _, outs, _ = rebuild_collect_shared(_outs, [], new_replace, [], **rebuild_kwds)
87 return outs
File ~/miniforge3/envs/BSSM/lib/python3.11/site-packages/pytensor/compile/function/pfunc.py:313, in rebuild_collect_shared(outputs, inputs, replace, updates, rebuild_strict, copy_inputs_over, no_default_updates, clone_inner_graphs)
311 for v in outputs:
312 if isinstance(v, Variable):
--> 313 cloned_v = clone_v_get_shared_updates(v, copy_inputs_over)
314 cloned_outputs.append(cloned_v)
315 elif isinstance(v, Out):
File ~/miniforge3/envs/BSSM/lib/python3.11/site-packages/pytensor/compile/function/pfunc.py:189, in rebuild_collect_shared.<locals>.clone_v_get_shared_updates(v, copy_inputs_over)
187 if owner not in clone_d:
188 for i in owner.inputs:
--> 189 clone_v_get_shared_updates(i, copy_inputs_over)
190 clone_node_and_cache(
191 owner,
192 clone_d,
193 strict=rebuild_strict,
194 clone_inner_graphs=clone_inner_graphs,
195 )
196 return clone_d.setdefault(v, v)
File ~/miniforge3/envs/BSSM/lib/python3.11/site-packages/pytensor/compile/function/pfunc.py:189, in rebuild_collect_shared.<locals>.clone_v_get_shared_updates(v, copy_inputs_over)
187 if owner not in clone_d:
188 for i in owner.inputs:
--> 189 clone_v_get_shared_updates(i, copy_inputs_over)
190 clone_node_and_cache(
191 owner,
192 clone_d,
193 strict=rebuild_strict,
194 clone_inner_graphs=clone_inner_graphs,
195 )
196 return clone_d.setdefault(v, v)
[... skipping similar frames: rebuild_collect_shared.<locals>.clone_v_get_shared_updates at line 189 (2 times)]
File ~/miniforge3/envs/BSSM/lib/python3.11/site-packages/pytensor/compile/function/pfunc.py:189, in rebuild_collect_shared.<locals>.clone_v_get_shared_updates(v, copy_inputs_over)
187 if owner not in clone_d:
188 for i in owner.inputs:
--> 189 clone_v_get_shared_updates(i, copy_inputs_over)
190 clone_node_and_cache(
191 owner,
192 clone_d,
193 strict=rebuild_strict,
194 clone_inner_graphs=clone_inner_graphs,
195 )
196 return clone_d.setdefault(v, v)
File ~/miniforge3/envs/BSSM/lib/python3.11/site-packages/pytensor/compile/function/pfunc.py:190, in rebuild_collect_shared.<locals>.clone_v_get_shared_updates(v, copy_inputs_over)
188 for i in owner.inputs:
189 clone_v_get_shared_updates(i, copy_inputs_over)
--> 190 clone_node_and_cache(
191 owner,
192 clone_d,
193 strict=rebuild_strict,
194 clone_inner_graphs=clone_inner_graphs,
195 )
196 return clone_d.setdefault(v, v)
197 elif isinstance(v, SharedVariable):
File ~/miniforge3/envs/BSSM/lib/python3.11/site-packages/pytensor/graph/basic.py:1201, in clone_node_and_cache(node, clone_d, clone_inner_graphs, **kwargs)
1197 new_op: "Op" | None = cast(Optional["Op"], clone_d.get(node.op))
1199 cloned_inputs: list[Variable] = [cast(Variable, clone_d[i]) for i in node.inputs]
-> 1201 new_node = node.clone_with_new_inputs(
1202 cloned_inputs,
1203 # Only clone inner-graph `Op`s when there isn't a cached clone (and
1204 # when `clone_inner_graphs` is enabled)
1205 clone_inner_graph=clone_inner_graphs if new_op is None else False,
1206 **kwargs,
1207 )
1209 if new_op:
1210 # If we didn't clone the inner-graph `Op` above, because
1211 # there was a cached version, set the cloned `Apply` to use
1212 # the cached clone `Op`
1213 new_node.op = new_op
File ~/miniforge3/envs/BSSM/lib/python3.11/site-packages/pytensor/graph/basic.py:285, in Apply.clone_with_new_inputs(self, inputs, strict, clone_inner_graph)
282 if isinstance(new_op, HasInnerGraph) and clone_inner_graph: # type: ignore
283 new_op = new_op.clone() # type: ignore
--> 285 new_node = new_op.make_node(*new_inputs)
286 new_node.tag = copy(self.tag).update(new_node.tag)
287 else:
File ~/miniforge3/envs/BSSM/lib/python3.11/site-packages/pytensor/scan/op.py:964, in Scan.make_node(self, *inputs)
960 argoffset = 0
961 for inner_seq, outer_seq in zip(
962 self.inner_seqs(self.inner_inputs), self.outer_seqs(inputs)
963 ):
--> 964 check_broadcast(outer_seq, inner_seq)
965 new_inputs.append(copy_var_format(outer_seq, as_var=inner_seq))
967 argoffset += len(self.outer_seqs(inputs))
File ~/miniforge3/envs/BSSM/lib/python3.11/site-packages/pytensor/scan/op.py:179, in check_broadcast(v1, v2)
177 a1 = n + size - v1.type.ndim + 1
178 a2 = n + size - v2.type.ndim + 1
--> 179 raise TypeError(msg % (v1.type, v2.type, a1, b1, b2, a2))
TypeError: The broadcast pattern of the output of scan (Matrix(float64, shape=(144, 1))) is inconsistent with the one provided in `output_info` (Vector(float64, shape=(?,))). The output on axis 0 is True, but it is False on axis 1 in `output_info`. This can happen if one of the dimension is fixed to 1 in the input, while it is still variable in the output, or vice-verca. You have to make them consistent, e.g. using pytensor.tensor.{unbroadcast, specify_broadcastable}.