Skip to content

Commit 2a9f318

Browse files
committed
Add ZarrTrace.from_store
1 parent c3e62dc commit 2a9f318

File tree

2 files changed

+109
-0
lines changed

2 files changed

+109
-0
lines changed

pymc/backends/zarr.py

Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1092,3 +1092,74 @@ def to_inferencedata(self, save_warmup: bool = False) -> az.InferenceData:
10921092
data.attrs = make_attrs(attrs=attrs, library=pymc)
10931093
groups[name] = data.load() if az.rcParams["data.load"] == "eager" else data
10941094
return az.InferenceData(**groups)
1095+
1096+
@classmethod
1097+
def from_store(
1098+
cls: type["ZarrTrace"],
1099+
store: BaseStore | MutableMapping,
1100+
synchronizer: Synchronizer | None = None,
1101+
) -> "ZarrTrace":
1102+
if not _zarr_available:
1103+
raise RuntimeError("You must install zarr to be able to create ZarrTrace instances")
1104+
self: ZarrTrace = object.__new__(cls)
1105+
self.root = zarr.group(
1106+
store=store,
1107+
overwrite=False,
1108+
synchronizer=synchronizer,
1109+
)
1110+
self.synchronizer = synchronizer
1111+
self.compressor = default_compressor
1112+
1113+
groups = set(self.root.group_keys())
1114+
assert groups >= {
1115+
"posterior",
1116+
"sample_stats",
1117+
"warmup_posterior",
1118+
"warmup_sample_stats",
1119+
"constant_data",
1120+
"observed_data",
1121+
"_sampling_state",
1122+
}
1123+
1124+
if "posterior" in groups:
1125+
for _, array in self.posterior.arrays():
1126+
dims = array.attrs.get("_ARRAY_DIMENSIONS", [])
1127+
if len(dims) >= 2 and dims[1] == "draw":
1128+
draws_per_chunk = int(array.chunks[1])
1129+
break
1130+
else:
1131+
draws_per_chunk = 1
1132+
1133+
self.draws_per_chunk = int(draws_per_chunk)
1134+
assert self.draws_per_chunk >= 1
1135+
1136+
self.include_transformed = "unconstrained_posterior" in groups
1137+
arrays = itertools.chain(
1138+
self.posterior.arrays(),
1139+
self.constant_data.arrays(),
1140+
self.observed_data.arrays(),
1141+
)
1142+
if self.include_transformed:
1143+
arrays = itertools.chain(arrays, self.unconstrained_posterior.arrays())
1144+
varnames = []
1145+
coords = {}
1146+
vars_to_dims = {}
1147+
for name, array in arrays:
1148+
dims = array.attrs["_ARRAY_DIMENSIONS"]
1149+
if dims[:2] == ["chain", "draw"]:
1150+
# Random Variable
1151+
vars_to_dims[name] = dims[2:]
1152+
varnames.append(name)
1153+
elif len(dims) == 1 and name == dims[0]:
1154+
# Coordinate
1155+
# We store all model coordinates, which means we have to exclude chain
1156+
# and draw
1157+
if name not in ["chain", "draw"]:
1158+
coords[name] = np.asarray(array)
1159+
else:
1160+
# Constant data or observation
1161+
vars_to_dims[name] = dims
1162+
self.varnames = varnames
1163+
self.coords = coords
1164+
self.vars_to_dims = vars_to_dims
1165+
return self

tests/backends/test_zarr.py

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -618,3 +618,41 @@ def test_sampling_consistency(
618618
sequential_trace._sampling_state.sampling_state[chain],
619619
)
620620
xr.testing.assert_equal(parallel_idata.posterior, sequential_idata.posterior)
621+
622+
623+
def test_from_store(populated_trace):
624+
trace, total_steps, tune, draws = populated_trace
625+
loaded_trace = ZarrTrace.from_store(
626+
trace.root.store,
627+
)
628+
assert loaded_trace.is_root_populated and not loaded_trace._is_base_setup
629+
assert trace.draws_per_chunk == loaded_trace.draws_per_chunk
630+
assert trace.include_transformed == loaded_trace.include_transformed
631+
assert set(trace.varnames) == set(loaded_trace.varnames)
632+
assert set(trace.coords) == set(loaded_trace.coords) and (
633+
all(
634+
np.array_equal(np.asarray(coord), np.asarray(loaded_trace.coords[dim]))
635+
for dim, coord in trace.coords.items()
636+
)
637+
)
638+
assert trace.vars_to_dims == loaded_trace.vars_to_dims
639+
640+
assert not hasattr(loaded_trace, "straces")
641+
assert set(trace.root.group_keys()) == set(loaded_trace.root.group_keys())
642+
for group_name, group in trace.root.groups():
643+
loaded_group = loaded_trace.root[group_name]
644+
if group_name == "_sampling_state":
645+
assert all(
646+
equal_sampling_states(this, other) if this is not None else this is other
647+
for this, other in zip(group.sampling_state[:], loaded_group.sampling_state[:])
648+
)
649+
np.testing.assert_array_equal(group.draw_idx, loaded_group.draw_idx)
650+
assert trace.tuning_steps == loaded_trace.tuning_steps
651+
assert trace.draws == loaded_trace.draws
652+
assert trace.sampling_time == loaded_trace.sampling_time
653+
else:
654+
assert set(group.array_keys()) == set(loaded_group.array_keys())
655+
for name, array in group.arrays():
656+
loaded_array = loaded_group[name]
657+
assert dict(array.attrs) == dict(loaded_array.attrs)
658+
np.testing.assert_array_equal(np.asarray(array), np.asarray(loaded_array))

0 commit comments

Comments
 (0)