Skip to content

Accept new coords in model transforms #7549

Open
@lucianopaz

Description

@lucianopaz

Describe the issue:

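When we call pymc.do or pymc.observe, we might also want to change some coordinate values, for example when the replacement value has a different length along a named dimension than the original variable. At the moment, this can't be done: the transformed model keeps the original coords, which then conflict with the shape of the new data.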

Reproducible code example:

import pymc as pm

with pm.Model(coords={"A": range(2)}) as m:
    a = pm.Normal("a", dims="A")
    b = pm.Deterministic("b", a + 1)

with pm.do(m, {"a": list(range(5))}) as m2:
    samples = pm.draw(m2["b"], 10)
    print(samples.shape)  # prints (10, 5) as expected
    samples = pm.sample_prior_predictive(var_names="b", draws=10)  # errors

Error message:

ValueError                                Traceback (most recent call last)
Cell In[5], line 10
      8 samples = pm.draw(m2["b"], 10)
      9 print(samples.shape)
---> 10 samples = pm.sample_prior_predictive(var_names="b", draws=10)

File ~/repos/pymc/pymc/sampling/forward.py:466, in sample_prior_predictive(draws, model, var_names, random_seed, return_inferencedata, idata_kwargs, compile_kwargs, samples)
    464 if idata_kwargs:
    465     ikwargs.update(idata_kwargs)
--> 466 return pm.to_inference_data(prior=prior, **ikwargs)

File ~/repos/pymc/pymc/backends/arviz.py:532, in to_inference_data(trace, prior, posterior_predictive, log_likelihood, log_prior, coords, dims, sample_dims, model, save_warmup, include_transformed)
    517 if isinstance(trace, InferenceData):
    518     return trace
    520 return InferenceDataConverter(
    521     trace=trace,
    522     prior=prior,
    523     posterior_predictive=posterior_predictive,
    524     log_likelihood=log_likelihood,
    525     log_prior=log_prior,
    526     coords=coords,
    527     dims=dims,
    528     sample_dims=sample_dims,
    529     model=model,
    530     save_warmup=save_warmup,
    531     include_transformed=include_transformed,
--> 532 ).to_inference_data()

File ~/repos/pymc/pymc/backends/arviz.py:434, in InferenceDataConverter.to_inference_data(self)
    432     id_dict["predictions_constant_data"] = self.constant_data_to_xarray()
    433 else:
--> 434     id_dict["constant_data"] = self.constant_data_to_xarray()
    435 idata = InferenceData(save_warmup=self.save_warmup, **id_dict)
    436 if self.log_likelihood:

File ~/miniforge3/lib/python3.10/site-packages/arviz/data/base.py:67, in requires.__call__.<locals>.wrapped(cls)
     65     if all((getattr(cls, prop_i) is None for prop_i in prop)):
     66         return None
---> 67 return func(cls)

File ~/repos/pymc/pymc/backends/arviz.py:398, in InferenceDataConverter.constant_data_to_xarray(self)
    395 if not constant_data:
    396     return None
--> 398 xarray_dataset = dict_to_dataset(
    399     constant_data,
    400     library=pymc,
    401     coords=self.coords,
    402     dims=self.dims,
    403     default_dims=[],
    404 )
    406 # provisional handling of scalars in constant
    407 # data to prevent promotion to rank 1
    408 # in the future this will be handled by arviz
    409 scalars = [var_name for var_name, value in constant_data.items() if np.ndim(value) == 0]

File ~/miniforge3/lib/python3.10/site-packages/arviz/data/base.py:394, in pytree_to_dataset(data, attrs, library, coords, dims, default_dims, index_origin, skip_event_dims)
    391 except TypeError:  # probably unsortable keys -- the function will still work if
    392     pass  # it is an honest dictionary.
--> 394 data_vars = {
    395     key: numpy_to_data_array(
    396         values,
    397         var_name=key,
    398         coords=coords,
    399         dims=dims.get(key),
    400         default_dims=default_dims,
    401         index_origin=index_origin,
    402         skip_event_dims=skip_event_dims,
    403     )
    404     for key, values in data.items()
    405 }
    406 return xr.Dataset(data_vars=data_vars, attrs=make_attrs(attrs=attrs, library=library))

File ~/miniforge3/lib/python3.10/site-packages/arviz/data/base.py:395, in <dictcomp>(.0)
    391 except TypeError:  # probably unsortable keys -- the function will still work if
    392     pass  # it is an honest dictionary.
    394 data_vars = {
--> 395     key: numpy_to_data_array(
    396         values,
    397         var_name=key,
    398         coords=coords,
    399         dims=dims.get(key),
    400         default_dims=default_dims,
    401         index_origin=index_origin,
    402         skip_event_dims=skip_event_dims,
    403     )
    404     for key, values in data.items()
    405 }
    406 return xr.Dataset(data_vars=data_vars, attrs=make_attrs(attrs=attrs, library=library))

File ~/miniforge3/lib/python3.10/site-packages/arviz/data/base.py:299, in numpy_to_data_array(ary, var_name, coords, dims, default_dims, index_origin, skip_event_dims)
    297 # filter coords based on the dims
    298 coords = {key: xr.IndexVariable((key,), data=np.asarray(coords[key])) for key in dims}
--> 299 return xr.DataArray(ary, coords=coords, dims=dims)

File ~/miniforge3/lib/python3.10/site-packages/xarray/core/dataarray.py:455, in DataArray.__init__(self, data, coords, dims, name, attrs, indexes, fastpath)
    453 data = _check_data_shape(data, coords, dims)
    454 data = as_compatible_data(data)
--> 455 coords, dims = _infer_coords_and_dims(data.shape, coords, dims)
    456 variable = Variable(dims, data, attrs, fastpath=True)
    458 if not isinstance(coords, Coordinates):

File ~/miniforge3/lib/python3.10/site-packages/xarray/core/dataarray.py:194, in _infer_coords_and_dims(shape, coords, dims)
    191             var.dims = (dim,)
    192             new_coords[dim] = var.to_index_variable()
--> 194 _check_coords_dims(shape, new_coords, dims_tuple)
    196 return new_coords, dims_tuple

File ~/miniforge3/lib/python3.10/site-packages/xarray/core/dataarray.py:128, in _check_coords_dims(shape, coords, dim)
    126 for d, s in v.sizes.items():
    127     if s != sizes[d]:
--> 128         raise ValueError(
    129             f"conflicting sizes for dimension {d!r}: "
    130             f"length {sizes[d]} on the data but length {s} on "
    131             f"coordinate {k!r}"
    132         )

ValueError: conflicting sizes for dimension 'A': length 5 on the data but length 2 on coordinate 'A'

PyMC version information:

main

Context for the issue:

No response

Metadata

Metadata

Assignees

No one assigned

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions