Issues with dimensions of unobserved components in versions 0.1.0 and 0.1.1 #338

Closed
@rklees

Description

I ran my standard test example using version 0.1.0 and got an error message from pm.sample; see below. Any advice on how to proceed?

TypeError Traceback (most recent call last)
Cell In[19], line 4
2 sampler = 'numpyro'
3 with pymc_model:
----> 4 idata = pm.sample(nuts_sampler=sampler, tune=500, draws=1000, chains=4, progressbar=True, target_accept=0.95)
6 # idata is an "inference data object", provided by the sampler. Sampling statistics are provided in idata.sample_stats. For more information
7 # about the sampling statistics, see
8 # https://www.pymc.io/projects/docs/en/v3/pymc-examples/examples/diagnostics_and_criticism/sampler-stats.html

File ~/miniforge3/envs/BSSM/lib/python3.11/site-packages/pymc/sampling/mcmc.py:691, in sample(draws, tune, chains, cores, random_seed, progressbar, progressbar_theme, step, var_names, nuts_sampler, initvals, init, jitter_max_retries, n_init, trace, discard_tuned_samples, compute_convergence_checks, keep_warning_stat, return_inferencedata, idata_kwargs, nuts_sampler_kwargs, callback, mp_ctx, model, **kwargs)
687 if not isinstance(step, NUTS):
688 raise ValueError(
689 "Model can not be sampled with NUTS alone. Your model is probably not continuous."
690 )
--> 691 return _sample_external_nuts(
692 sampler=nuts_sampler,
693 draws=draws,
694 tune=tune,
695 chains=chains,
696 target_accept=kwargs.pop("nuts", {}).get("target_accept", 0.8),
697 random_seed=random_seed,
698 initvals=initvals,
699 model=model,
700 var_names=var_names,
701 progressbar=progressbar,
702 idata_kwargs=idata_kwargs,
703 nuts_sampler_kwargs=nuts_sampler_kwargs,
704 **kwargs,
705 )
707 if isinstance(step, list):
708 step = CompoundStep(step)

File ~/miniforge3/envs/BSSM/lib/python3.11/site-packages/pymc/sampling/mcmc.py:351, in _sample_external_nuts(sampler, draws, tune, chains, target_accept, random_seed, initvals, model, var_names, progressbar, idata_kwargs, nuts_sampler_kwargs, **kwargs)
348 elif sampler in ("numpyro", "blackjax"):
349 import pymc.sampling.jax as pymc_jax
--> 351 idata = pymc_jax.sample_jax_nuts(
352 draws=draws,
353 tune=tune,
354 chains=chains,
355 target_accept=target_accept,
356 random_seed=random_seed,
357 initvals=initvals,
358 model=model,
359 var_names=var_names,
360 progressbar=progressbar,
361 nuts_sampler=sampler,
362 idata_kwargs=idata_kwargs,
363 **nuts_sampler_kwargs,
364 )
365 return idata
367 else:

File ~/miniforge3/envs/BSSM/lib/python3.11/site-packages/pymc/sampling/jax.py:567, in sample_jax_nuts(draws, tune, chains, target_accept, random_seed, initvals, jitter, model, var_names, nuts_kwargs, progressbar, keep_untransformed, chain_method, postprocessing_backend, postprocessing_vectorize, postprocessing_chunks, idata_kwargs, compute_convergence_checks, nuts_sampler)
564 raise ValueError(f"{nuts_sampler=} not recognized")
566 tic1 = datetime.now()
--> 567 raw_mcmc_samples, sample_stats, library = sampler_fn(
568 model=model,
569 target_accept=target_accept,
570 tune=tune,
571 draws=draws,
572 chains=chains,
573 chain_method=chain_method,
574 progressbar=progressbar,
575 random_seed=random_seed,
576 initial_points=initial_points,
577 nuts_kwargs=nuts_kwargs,
578 )
579 tic2 = datetime.now()
581 jax_fn = get_jaxified_graph(inputs=model.value_vars, outputs=vars_to_sample)

File ~/miniforge3/envs/BSSM/lib/python3.11/site-packages/pymc/sampling/jax.py:458, in _sample_numpyro_nuts(model, target_accept, tune, draws, chains, chain_method, progressbar, random_seed, initial_points, nuts_kwargs)
454 import numpyro
456 from numpyro.infer import MCMC, NUTS
--> 458 logp_fn = get_jaxified_logp(model, negative_logp=False)
460 nuts_kwargs.setdefault("adapt_step_size", True)
461 nuts_kwargs.setdefault("adapt_mass_matrix", True)

File ~/miniforge3/envs/BSSM/lib/python3.11/site-packages/pymc/sampling/jax.py:153, in get_jaxified_logp(model, negative_logp)
151 if not negative_logp:
152 model_logp = -model_logp
--> 153 logp_fn = get_jaxified_graph(inputs=model.value_vars, outputs=[model_logp])
155 def logp_fn_wrap(x):
156 return logp_fn(*x)[0]

File ~/miniforge3/envs/BSSM/lib/python3.11/site-packages/pymc/sampling/jax.py:128, in get_jaxified_graph(inputs, outputs)
122 def get_jaxified_graph(
123 inputs: list[TensorVariable] | None = None,
124 outputs: list[TensorVariable] | None = None,
125 ) -> list[TensorVariable]:
126 """Compile an PyTensor graph into an optimized JAX function"""
--> 128 graph = _replace_shared_variables(outputs) if outputs is not None else None
130 fgraph = FunctionGraph(inputs=inputs, outputs=graph, clone=True)
131 # We need to add a Supervisor to the fgraph to be able to run the
132 # JAX sequential optimizer without warnings. We made sure there
133 # are no mutable input variables, so we only need to check for
134 # "destroyers". This should be automatically handled by PyTensor
135 # once aesara-devs/aesara#637 is fixed.

File ~/miniforge3/envs/BSSM/lib/python3.11/site-packages/pymc/sampling/jax.py:118, in _replace_shared_variables(graph)
111 raise ValueError(
112 "Graph contains shared variables with default_update which cannot "
113 "be safely replaced."
114 )
116 replacements = {var: pt.constant(var.get_value(borrow=True)) for var in shared_variables}
--> 118 new_graph = clone_replace(graph, replace=replacements)
119 return new_graph

File ~/miniforge3/envs/BSSM/lib/python3.11/site-packages/pytensor/graph/replace.py:85, in clone_replace(output, replace, **rebuild_kwds)
82 _, _outs, _ = rebuild_collect_shared(output, [], tmp_replace, [], **rebuild_kwds)
84 # TODO Explain why we call it twice ?!
---> 85 _, outs, _ = rebuild_collect_shared(_outs, [], new_replace, [], **rebuild_kwds)
87 return outs

File ~/miniforge3/envs/BSSM/lib/python3.11/site-packages/pytensor/compile/function/pfunc.py:313, in rebuild_collect_shared(outputs, inputs, replace, updates, rebuild_strict, copy_inputs_over, no_default_updates, clone_inner_graphs)
311 for v in outputs:
312 if isinstance(v, Variable):
--> 313 cloned_v = clone_v_get_shared_updates(v, copy_inputs_over)
314 cloned_outputs.append(cloned_v)
315 elif isinstance(v, Out):

File ~/miniforge3/envs/BSSM/lib/python3.11/site-packages/pytensor/compile/function/pfunc.py:189, in rebuild_collect_shared.<locals>.clone_v_get_shared_updates(v, copy_inputs_over)
187 if owner not in clone_d:
188 for i in owner.inputs:
--> 189 clone_v_get_shared_updates(i, copy_inputs_over)
190 clone_node_and_cache(
191 owner,
192 clone_d,
193 strict=rebuild_strict,
194 clone_inner_graphs=clone_inner_graphs,
195 )
196 return clone_d.setdefault(v, v)

File ~/miniforge3/envs/BSSM/lib/python3.11/site-packages/pytensor/compile/function/pfunc.py:189, in rebuild_collect_shared.<locals>.clone_v_get_shared_updates(v, copy_inputs_over)
187 if owner not in clone_d:
188 for i in owner.inputs:
--> 189 clone_v_get_shared_updates(i, copy_inputs_over)
190 clone_node_and_cache(
191 owner,
192 clone_d,
193 strict=rebuild_strict,
194 clone_inner_graphs=clone_inner_graphs,
195 )
196 return clone_d.setdefault(v, v)

[... skipping similar frames: rebuild_collect_shared.<locals>.clone_v_get_shared_updates at line 189 (2 times)]

File ~/miniforge3/envs/BSSM/lib/python3.11/site-packages/pytensor/compile/function/pfunc.py:189, in rebuild_collect_shared.<locals>.clone_v_get_shared_updates(v, copy_inputs_over)
187 if owner not in clone_d:
188 for i in owner.inputs:
--> 189 clone_v_get_shared_updates(i, copy_inputs_over)
190 clone_node_and_cache(
191 owner,
192 clone_d,
193 strict=rebuild_strict,
194 clone_inner_graphs=clone_inner_graphs,
195 )
196 return clone_d.setdefault(v, v)

File ~/miniforge3/envs/BSSM/lib/python3.11/site-packages/pytensor/compile/function/pfunc.py:190, in rebuild_collect_shared.<locals>.clone_v_get_shared_updates(v, copy_inputs_over)
188 for i in owner.inputs:
189 clone_v_get_shared_updates(i, copy_inputs_over)
--> 190 clone_node_and_cache(
191 owner,
192 clone_d,
193 strict=rebuild_strict,
194 clone_inner_graphs=clone_inner_graphs,
195 )
196 return clone_d.setdefault(v, v)
197 elif isinstance(v, SharedVariable):

File ~/miniforge3/envs/BSSM/lib/python3.11/site-packages/pytensor/graph/basic.py:1201, in clone_node_and_cache(node, clone_d, clone_inner_graphs, **kwargs)
1197 new_op: "Op" | None = cast(Optional["Op"], clone_d.get(node.op))
1199 cloned_inputs: list[Variable] = [cast(Variable, clone_d[i]) for i in node.inputs]
-> 1201 new_node = node.clone_with_new_inputs(
1202 cloned_inputs,
1203 # Only clone inner-graph Ops when there isn't a cached clone (and
1204 # when clone_inner_graphs is enabled)
1205 clone_inner_graph=clone_inner_graphs if new_op is None else False,
1206 **kwargs,
1207 )
1209 if new_op:
1210 # If we didn't clone the inner-graph Op above, because
1211 # there was a cached version, set the cloned Apply to use
1212 # the cached clone Op
1213 new_node.op = new_op

File ~/miniforge3/envs/BSSM/lib/python3.11/site-packages/pytensor/graph/basic.py:285, in Apply.clone_with_new_inputs(self, inputs, strict, clone_inner_graph)
282 if isinstance(new_op, HasInnerGraph) and clone_inner_graph: # type: ignore
283 new_op = new_op.clone() # type: ignore
--> 285 new_node = new_op.make_node(*new_inputs)
286 new_node.tag = copy(self.tag).update(new_node.tag)
287 else:

File ~/miniforge3/envs/BSSM/lib/python3.11/site-packages/pytensor/scan/op.py:964, in Scan.make_node(self, *inputs)
960 argoffset = 0
961 for inner_seq, outer_seq in zip(
962 self.inner_seqs(self.inner_inputs), self.outer_seqs(inputs)
963 ):
--> 964 check_broadcast(outer_seq, inner_seq)
965 new_inputs.append(copy_var_format(outer_seq, as_var=inner_seq))
967 argoffset += len(self.outer_seqs(inputs))

File ~/miniforge3/envs/BSSM/lib/python3.11/site-packages/pytensor/scan/op.py:179, in check_broadcast(v1, v2)
177 a1 = n + size - v1.type.ndim + 1
178 a2 = n + size - v2.type.ndim + 1
--> 179 raise TypeError(msg % (v1.type, v2.type, a1, b1, b2, a2))

TypeError: The broadcast pattern of the output of scan (Matrix(float64, shape=(144, 1))) is inconsistent with the one provided in output_info (Vector(float64, shape=(?,))). The output on axis 0 is True, but it is False on axis 1 in output_info. This can happen if one of the dimension is fixed to 1 in the input, while it is still variable in the output, or vice-verca. You have to make them consistent, e.g. using pytensor.tensor.{unbroadcast, specify_broadcastable}.
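
For context, the final TypeError names the two PyTensor helpers that reconcile broadcast patterns. Below is a minimal, hypothetical sketch (not the model from this issue) showing what `pytensor.tensor.specify_broadcastable` and `pytensor.tensor.unbroadcast` do to a tensor's type:

```python
import pytensor.tensor as pt

# Hypothetical illustration, unrelated to the reporter's model:
# reconciling a length-1 axis with a variable-length one, the two
# directions the error message describes.

x = pt.matrix("x")                        # axis 1 is variable-length
x_fixed = pt.specify_broadcastable(x, 1)  # assert axis 1 has length 1
print(x.broadcastable)                    # (False, False)
print(x_fixed.broadcastable)              # (False, True)

y = pt.col("y")                           # axis 1 is fixed to length 1
y_var = pt.unbroadcast(y, 1)              # treat axis 1 as variable-length
print(y.broadcastable)                    # (False, True)
print(y_var.broadcastable)                # (False, False)
```

Note that in this traceback the mismatch is raised from a Scan node while PyMC jaxifies the model's logp graph, i.e. inside library-built graph code rather than user code, so the sketch above only illustrates the mechanism the error message refers to.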
