Skip to content

Commit 5003508

Browse files
committed
Add ZarrTrace compatibility checks
1 parent 2a9f318 commit 5003508

File tree

1 file changed

+121
-0
lines changed

1 file changed

+121
-0
lines changed

pymc/backends/zarr.py

Lines changed: 121 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1163,3 +1163,124 @@ def from_store(
11631163
self.coords = coords
11641164
self.vars_to_dims = vars_to_dims
11651165
return self
1166+
1167+
def assert_model_and_step_are_compatible(
    self,
    step: BlockedStep | CompoundStep,
    model: Model,
    vars: list[TensorVariable] | None = None,
) -> None:
    """Assert that ``model`` and ``step`` agree with what is stored in this trace.

    The following are compared, in order, between the zarr store and the
    supplied objects:

    1. Names of the sampled free random variables and deterministics
       (as returned by ``self.parse_varnames(model, vars)``) and their dtypes.
    2. Coordinates (ignoring autogenerated ``*_dim_<n>`` dimensions) and the
       explicit dimensions of each variable.
    3. Constant data names and values.
    4. Observed data names and values.
    5. Sample stats names and dtypes implied by the step method(s).
    6. The step method's sampling state against the stored sampling state.

    Parameters
    ----------
    step : BlockedStep | CompoundStep
        Step method whose sample stats and sampling state are checked. For a
        non-``BlockedStep`` the ``step.methods`` attribute is used.
    model : Model
        Model whose variables, coords, dims and data are checked.
    vars : list[TensorVariable] | None
        Forwarded to ``self.parse_varnames`` to determine the sampled
        variable names.

    Raises
    ------
    AssertionError
        If any of the checks above fails.
    """
    zarr_groups = set(self.root.group_keys())
    arrays_ = itertools.chain(
        self.posterior.arrays(),
        self.constant_data.arrays() if "constant_data" in zarr_groups else [],
        self.observed_data.arrays() if "observed_data" in zarr_groups else [],
    )
    if self.include_transformed:
        arrays_ = itertools.chain(arrays_, self.unconstrained_posterior.arrays())
    # Materialize once: the (name, array) pairs are iterated twice below.
    arrays = list(arrays_)

    zarr_coords = {}
    zarr_vars_to_dims = {}
    zarr_deterministics = []
    zarr_free_vars = []
    for name, array in arrays:
        dims = array.attrs["_ARRAY_DIMENSIONS"]
        if dims[:2] == ["chain", "draw"]:
            # Random Variable
            zarr_vars_to_dims[name] = dims[2:]
            if array.attrs["kind"] == "freeRV":
                zarr_free_vars.append(name)
            else:
                zarr_deterministics.append(name)
        elif len(dims) == 1 and name == dims[0]:
            # Coordinate
            if name not in ["chain", "draw"]:
                zarr_coords[name] = np.asarray(array)
        else:
            # Constant data or observation
            zarr_vars_to_dims[name] = dims
    zarr_constant_data = (
        [name for name in self.constant_data.array_keys() if name not in zarr_coords]
        if "constant_data" in zarr_groups
        else []
    )
    zarr_observed_data = (
        [name for name in self.observed_data.array_keys() if name not in zarr_coords]
        if "observed_data" in zarr_groups
        else []
    )
    autogenerated_dims = {dim for dim in zarr_coords if re.search(r"_dim_\d+$", dim)}

    # Check deterministics, free RVs and transformed RVs
    _, var_names = self.parse_varnames(model, vars)
    zarr_sampled_vars = set(zarr_free_vars) | set(zarr_deterministics)
    assert set(var_names) == zarr_sampled_vars, (
        "The model deterministics and random variables given the sampled var_names "
        "do not match with the stored deterministics variables in the trace."
    )
    for name, array in arrays:
        # Only sampled variables get a dtype check; coordinates, constant data
        # and observed data are skipped. A name is stored in exactly one of
        # the free-variable / deterministic lists, so the skip must trigger
        # only when it is in neither of them.
        if name not in zarr_sampled_vars:
            continue
        model_var = model[name]
        assert np.dtype(model_var.dtype) == np.dtype(array.dtype), (
            f"The dtype of the model variable '{name}' does not match the dtype "
            "of the variable stored in the trace."
        )

    # Check coordinates
    assert (set(zarr_coords) - set(autogenerated_dims)) == set(model.coords) and all(
        np.array_equal(np.asarray(zarr_coords[dim]), np.asarray(coord))
        for dim, coord in model.coords.items()
    ), "Model coordinates don't match the coordinates stored in the trace"
    vars_to_explicit_dims = {}
    for name, dims in zarr_vars_to_dims.items():
        if len(dims) == 0 or all(dim in autogenerated_dims for dim in dims):
            # These variables won't be included in the named_vars_to_dims
            continue
        vars_to_explicit_dims[name] = [
            dim if dim not in autogenerated_dims else None for dim in dims
        ]
    assert set(vars_to_explicit_dims) == set(model.named_vars_to_dims) and all(
        vars_to_explicit_dims[name] == list(dims)
        for name, dims in model.named_vars_to_dims.items()
    ), "Some model variables have different dimensions than those stored in the trace."

    # Check constant data
    model_constant_data = find_constants(model)
    assert set(zarr_constant_data) == set(model_constant_data), (
        "The model constant data does not match with the stored constant data"
    )
    for name, model_data in model_constant_data.items():
        assert np.array_equal(self.constant_data[name], model_data, equal_nan=True), (
            "The model constant data does not match with the stored constant data"
        )

    # Check observed data
    model_observed_data = find_observations(model)
    assert set(zarr_observed_data) == set(model_observed_data), (
        "The model observed data does not match with the stored observed data"
    )
    for name, model_data in model_observed_data.items():
        assert np.array_equal(self.observed_data[name], model_data, equal_nan=True), (
            "The model observed data does not match with the stored observed data"
        )

    # Check sample stats given the step method
    stats_dtypes_shapes = get_stats_dtypes_shapes_from_steps(
        [step] if isinstance(step, BlockedStep) else step.methods
    )
    assert (set(stats_dtypes_shapes) | {"chain", "draw"}) == set(
        self.sample_stats.array_keys()
    ), "The step method sample stats do not match the ones stored in the trace."
    for name, array in self.sample_stats.arrays():
        if name in ("chain", "draw"):
            continue
        # The first element of each stats entry is used as the dtype.
        assert np.dtype(stats_dtypes_shapes[name][0]) == np.dtype(array.dtype), (
            "The step method sample stats do not match the ones stored in the trace."
        )

    assert step.sampling_state.is_compatible(self._sampling_state.sampling_state[0]), (
        "The step method sampling state class is incompatible with what's stored in the trace."
    )

0 commit comments

Comments
 (0)