Open
Description
Describe the issue:
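When we call `pymc.do` or `pymc.observe`, we might also want to change some coordinate values. At the moment, this can't be done, so replacing a variable with a value of a different length along a named dimension leaves the model's coordinates out of sync with the data (see the error below).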
Reproducible code example:
import pymc as pm
with pm.Model(coords={"A": range(2)}) as m:
a = pm.Normal("a", dims="A")
b = pm.Deterministic("b", a + 1)
with pm.do(m, {"a": list(range(5))}) as m2:
samples = pm.draw(m2["b"], 10)
print(samples.shape) # prints (10, 5) as expected
samples = pm.sample_prior_predictive(var_names="b", draws=10) # errors
Error message:
ValueError Traceback (most recent call last)
Cell In[5], line 10
8 samples = pm.draw(m2["b"], 10)
9 print(samples.shape)
---> 10 samples = pm.sample_prior_predictive(var_names="b", draws=10)
File ~/repos/pymc/pymc/sampling/forward.py:466, in sample_prior_predictive(draws, model, var_names, random_seed, return_inferencedata, idata_kwargs, compile_kwargs, samples)
464 if idata_kwargs:
465 ikwargs.update(idata_kwargs)
--> 466 return pm.to_inference_data(prior=prior, **ikwargs)
File ~/repos/pymc/pymc/backends/arviz.py:532, in to_inference_data(trace, prior, posterior_predictive, log_likelihood, log_prior, coords, dims, sample_dims, model, save_warmup, include_transformed)
517 if isinstance(trace, InferenceData):
518 return trace
520 return InferenceDataConverter(
521 trace=trace,
522 prior=prior,
523 posterior_predictive=posterior_predictive,
524 log_likelihood=log_likelihood,
525 log_prior=log_prior,
526 coords=coords,
527 dims=dims,
528 sample_dims=sample_dims,
529 model=model,
530 save_warmup=save_warmup,
531 include_transformed=include_transformed,
--> 532 ).to_inference_data()
File ~/repos/pymc/pymc/backends/arviz.py:434, in InferenceDataConverter.to_inference_data(self)
432 id_dict["predictions_constant_data"] = self.constant_data_to_xarray()
433 else:
--> 434 id_dict["constant_data"] = self.constant_data_to_xarray()
435 idata = InferenceData(save_warmup=self.save_warmup, **id_dict)
436 if self.log_likelihood:
File ~/miniforge3/lib/python3.10/site-packages/arviz/data/base.py:67, in requires.__call__.<locals>.wrapped(cls)
65 if all((getattr(cls, prop_i) is None for prop_i in prop)):
66 return None
---> 67 return func(cls)
File ~/repos/pymc/pymc/backends/arviz.py:398, in InferenceDataConverter.constant_data_to_xarray(self)
395 if not constant_data:
396 return None
--> 398 xarray_dataset = dict_to_dataset(
399 constant_data,
400 library=pymc,
401 coords=self.coords,
402 dims=self.dims,
403 default_dims=[],
404 )
406 # provisional handling of scalars in constant
407 # data to prevent promotion to rank 1
408 # in the future this will be handled by arviz
409 scalars = [var_name for var_name, value in constant_data.items() if np.ndim(value) == 0]
File ~/miniforge3/lib/python3.10/site-packages/arviz/data/base.py:394, in pytree_to_dataset(data, attrs, library, coords, dims, default_dims, index_origin, skip_event_dims)
391 except TypeError: # probably unsortable keys -- the function will still work if
392 pass # it is an honest dictionary.
--> 394 data_vars = {
395 key: numpy_to_data_array(
396 values,
397 var_name=key,
398 coords=coords,
399 dims=dims.get(key),
400 default_dims=default_dims,
401 index_origin=index_origin,
402 skip_event_dims=skip_event_dims,
403 )
404 for key, values in data.items()
405 }
406 return xr.Dataset(data_vars=data_vars, attrs=make_attrs(attrs=attrs, library=library))
File ~/miniforge3/lib/python3.10/site-packages/arviz/data/base.py:395, in <dictcomp>(.0)
391 except TypeError: # probably unsortable keys -- the function will still work if
392 pass # it is an honest dictionary.
394 data_vars = {
--> 395 key: numpy_to_data_array(
396 values,
397 var_name=key,
398 coords=coords,
399 dims=dims.get(key),
400 default_dims=default_dims,
401 index_origin=index_origin,
402 skip_event_dims=skip_event_dims,
403 )
404 for key, values in data.items()
405 }
406 return xr.Dataset(data_vars=data_vars, attrs=make_attrs(attrs=attrs, library=library))
File ~/miniforge3/lib/python3.10/site-packages/arviz/data/base.py:299, in numpy_to_data_array(ary, var_name, coords, dims, default_dims, index_origin, skip_event_dims)
297 # filter coords based on the dims
298 coords = {key: xr.IndexVariable((key,), data=np.asarray(coords[key])) for key in dims}
--> 299 return xr.DataArray(ary, coords=coords, dims=dims)
File ~/miniforge3/lib/python3.10/site-packages/xarray/core/dataarray.py:455, in DataArray.__init__(self, data, coords, dims, name, attrs, indexes, fastpath)
453 data = _check_data_shape(data, coords, dims)
454 data = as_compatible_data(data)
--> 455 coords, dims = _infer_coords_and_dims(data.shape, coords, dims)
456 variable = Variable(dims, data, attrs, fastpath=True)
458 if not isinstance(coords, Coordinates):
File ~/miniforge3/lib/python3.10/site-packages/xarray/core/dataarray.py:194, in _infer_coords_and_dims(shape, coords, dims)
191 var.dims = (dim,)
192 new_coords[dim] = var.to_index_variable()
--> 194 _check_coords_dims(shape, new_coords, dims_tuple)
196 return new_coords, dims_tuple
File ~/miniforge3/lib/python3.10/site-packages/xarray/core/dataarray.py:128, in _check_coords_dims(shape, coords, dim)
126 for d, s in v.sizes.items():
127 if s != sizes[d]:
--> 128 raise ValueError(
129 f"conflicting sizes for dimension {d!r}: "
130 f"length {sizes[d]} on the data but length {s} on "
131 f"coordinate {k!r}"
132 )
ValueError: conflicting sizes for dimension 'A': length 5 on the data but length 2 on coordinate 'A'
PyMC version information:
main
Context for the issue:
No response