Skip to content

Commit 5b69bb8

Browse files
committed
Resume sampling from existing ZarrTrace
1 parent 5003508 commit 5b69bb8

File tree

5 files changed

+380
-71
lines changed

5 files changed

+380
-71
lines changed

pymc/backends/__init__.py

Lines changed: 36 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -72,7 +72,7 @@
7272
from pymc.backends.arviz import predictions_to_inference_data, to_inference_data
7373
from pymc.backends.base import BaseTrace, IBaseTrace
7474
from pymc.backends.ndarray import NDArray
75-
from pymc.backends.zarr import ZarrTrace
75+
from pymc.backends.zarr import TraceAlreadyInitialized, ZarrTrace
7676
from pymc.blocking import PointType
7777
from pymc.model import Model
7878
from pymc.step_methods.compound import BlockedStep, CompoundStep
@@ -132,15 +132,41 @@ def init_traces(
132132
) -> tuple[RunType | None, Sequence[IBaseTrace]]:
133133
"""Initialize a trace recorder for each chain."""
134134
if isinstance(backend, ZarrTrace):
135-
backend.init_trace(
136-
chains=chains,
137-
draws=expected_length - tune,
138-
tune=tune,
139-
step=step,
140-
model=model,
141-
vars=trace_vars,
142-
test_point=initial_point,
143-
)
135+
try:
136+
backend.init_trace(
137+
chains=chains,
138+
draws=expected_length - tune,
139+
tune=tune,
140+
step=step,
141+
model=model,
142+
vars=trace_vars,
143+
test_point=initial_point,
144+
)
145+
except TraceAlreadyInitialized:
146+
# Trace has already been initialized. We need to make sure that the
147+
# tracked variable names and the number of chains match, and then resize
148+
# the zarr groups to the desired number of draws and tune.
149+
backend.assert_model_and_step_are_compatible(
150+
step=step,
151+
model=model,
152+
vars=trace_vars,
153+
)
154+
assert backend.posterior.chain.size == chains, (
155+
f"The requested number of chains {chains} does not match the number "
156+
f"of chains stored in the trace ({backend.posterior.chain.size})."
157+
)
158+
vars, var_names = backend.parse_varnames(model=model, vars=trace_vars)
159+
backend.link_model_and_step(
160+
chains=chains,
161+
draws=expected_length - tune,
162+
tune=tune,
163+
step=step,
164+
model=model,
165+
vars=vars,
166+
var_names=var_names,
167+
test_point=initial_point,
168+
)
169+
backend.resize(tune=tune, draws=expected_length - tune)
144170
return None, backend.straces
145171
if HAS_MCB and isinstance(backend, Backend):
146172
return init_chain_adapters(

pymc/sampling/mcmc.py

Lines changed: 31 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -767,7 +767,7 @@ def joined_blas_limiter():
767767
UserWarning,
768768
)
769769
rngs = get_random_generator(random_seed).spawn(chains)
770-
random_seed_list = [rng.integers(2**30) for rng in rngs]
770+
random_seed_list: list[int] = [int(rng.integers(2**30)) for rng in rngs]
771771

772772
if not discard_tuned_samples and not return_inferencedata and not isinstance(trace, ZarrTrace):
773773
warnings.warn(
@@ -993,11 +993,8 @@ def _sample_return(
993993
Final step of `pm.sample`.
994994
"""
995995
if isinstance(traces, ZarrTrace):
996-
# Split warmup from posterior samples
997-
traces.split_warmup_groups()
998-
999996
# Set sampling time
1000-
traces.sampling_time = t_sampling
997+
traces.sampling_time = traces.sampling_time + t_sampling
1001998

1002999
# Compute number of actual draws per chain
10031000
total_draws_per_chain = traces._sampling_state.draw_idx[:]
@@ -1226,7 +1223,7 @@ def _sample(
12261223
callback=callback,
12271224
)
12281225
try:
1229-
for it, stats in enumerate(sampling_gen):
1226+
for it, stats in sampling_gen:
12301227
progress_manager.update(
12311228
chain_idx=chain, is_last=False, draw=it, stats=stats, tuning=it > tune
12321229
)
@@ -1251,7 +1248,7 @@ def _iter_sample(
12511248
rng: np.random.Generator,
12521249
model: Model | None = None,
12531250
callback: SamplingIteratorCallback | None = None,
1254-
) -> Iterator[list[dict[str, Any]]]:
1251+
) -> Iterator[tuple[int, list[dict[str, Any]]]]:
12551252
"""Sample one chain with a generator (singleprocess).
12561253
12571254
Parameters
@@ -1285,14 +1282,29 @@ def _iter_sample(
12851282
step.set_rng(rng)
12861283

12871284
point = start
1288-
if isinstance(trace, ZarrChain):
1289-
trace.link_stepper(step)
1285+
initial_draw_idx = 0
1286+
step.tune = bool(tune)
1287+
if hasattr(step, "reset_tuning"):
1288+
step.reset_tuning()
1289+
trace.link_stepper(step)
1290+
stored_draw_idx, stored_sampling_state = trace.get_stored_draw_and_state()
1291+
if stored_draw_idx > 0:
1292+
if stored_sampling_state is not None:
1293+
step.sampling_state = stored_sampling_state
1294+
else:
1295+
raise RuntimeError(
1296+
"Cannot use the supplied ZarrTrace to restart sampling because "
1297+
"it has no sampling_state information stored. You will have to "
1298+
"resample from scratch."
1299+
)
1300+
initial_draw_idx = stored_draw_idx
1301+
point = trace.get_mcmc_point()
1302+
else:
1303+
# Store initial point in trace
1304+
trace.set_mcmc_point(point)
12901305

12911306
try:
1292-
step.tune = bool(tune)
1293-
if hasattr(step, "reset_tuning"):
1294-
step.reset_tuning()
1295-
for i in range(draws):
1307+
for i in range(initial_draw_idx, draws):
12961308
if i == 0 and hasattr(step, "iter_count"):
12971309
step.iter_count = 0
12981310
if i == tune:
@@ -1308,17 +1320,15 @@ def _iter_sample(
13081320
draw=Draw(chain, i == draws, i, i < tune, stats, point),
13091321
)
13101322

1311-
yield stats
1323+
yield i, stats
13121324

13131325
except (KeyboardInterrupt, BaseException):
1314-
if isinstance(trace, ZarrChain):
1315-
trace.record_sampling_state(step=step)
1326+
trace.record_sampling_state(step=step)
13161327
trace.close()
13171328
raise
13181329

13191330
else:
1320-
if isinstance(trace, ZarrChain):
1321-
trace.record_sampling_state(step=step)
1331+
trace.record_sampling_state(step=step)
13221332
trace.close()
13231333

13241334

@@ -1377,7 +1387,6 @@ def _mp_sample(
13771387

13781388
# We did draws += tune in pm.sample
13791389
draws -= tune
1380-
zarr_chains: list[ZarrChain] | None = None
13811390
zarr_recording = False
13821391
if all(isinstance(trace, ZarrChain) for trace in traces):
13831392
if isinstance(cast(ZarrChain, traces[0])._posterior.store, MemoryStore):
@@ -1388,7 +1397,6 @@ def _mp_sample(
13881397
"DirectoryStore or ZipStore"
13891398
)
13901399
else:
1391-
zarr_chains = cast(list[ZarrChain], traces)
13921400
zarr_recording = True
13931401

13941402
sampler = ps.ParallelSampler(
@@ -1403,7 +1411,9 @@ def _mp_sample(
14031411
progressbar_theme=progressbar_theme,
14041412
blas_cores=blas_cores,
14051413
mp_ctx=mp_ctx,
1406-
zarr_chains=zarr_chains,
1414+
# We only need to pass the traces when zarr_recording is happening because
1415+
# it's the only backend that can resume sampling
1416+
traces=traces if zarr_recording else None,
14071417
)
14081418
try:
14091419
try:

pymc/sampling/parallel.py

Lines changed: 52 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030
from rich.theme import Theme
3131
from threadpoolctl import threadpool_limits
3232

33+
from pymc.backends.base import IBaseTrace
3334
from pymc.backends.zarr import ZarrChain
3435
from pymc.blocking import DictToArrayBijection
3536
from pymc.exceptions import SamplingError
@@ -105,8 +106,9 @@ def __init__(
105106
rng_state: RandomGeneratorState,
106107
blas_cores,
107108
chain: int,
108-
zarr_chains: list[ZarrChain] | bytes | None = None,
109-
zarr_chains_is_pickled: bool = False,
109+
traces: list[IBaseTrace] | bytes | None = None,
110+
traces_is_pickled: bool = False,
111+
zarr_recording: bool = False,
110112
):
111113
# Because of https://github.com/numpy/numpy/issues/27727, we can't send
112114
# the rng instance to the child process because pickling (copying) loses
@@ -117,13 +119,12 @@ def __init__(
117119
self._step_method = step_method
118120
self._step_method_is_pickled = step_method_is_pickled
119121
self.chain = chain
120-
self._zarr_recording = False
121-
self._zarr_chain: ZarrChain | None = None
122-
if zarr_chains_is_pickled:
123-
self._zarr_chain = cloudpickle.loads(zarr_chains)[self.chain]
124-
elif zarr_chains is not None:
125-
self._zarr_chain = cast(list[ZarrChain], zarr_chains)[self.chain]
126-
self._zarr_recording = self._zarr_chain is not None
122+
self._zarr_recording = zarr_recording
123+
self._trace: IBaseTrace | None = None
124+
if traces_is_pickled:
125+
self._trace = cloudpickle.loads(traces)[self.chain]
126+
elif traces is not None:
127+
self._trace = cast(list[IBaseTrace], traces)[self.chain]
127128

128129
self._shared_point = shared_point
129130
self._rng = rng
@@ -165,7 +166,7 @@ def run(self):
165166

166167
def _link_step_to_zarrchain(self):
167168
if self._zarr_recording:
168-
self._zarr_chain.link_stepper(self._step_method)
169+
self._trace.link_stepper(self._step_method)
169170

170171
def _wait_for_abortion(self):
171172
while True:
@@ -194,6 +195,24 @@ def _start_loop(self):
194195

195196
draw = 0
196197
tuning = True
198+
if self._zarr_recording:
199+
trace = self._trace
200+
stored_draw_idx = trace._sampling_state.draw_idx[self.chain]
201+
stored_sampling_state = trace._sampling_state.sampling_state[self.chain]
202+
if stored_draw_idx > 0:
203+
if stored_sampling_state is not None:
204+
self._step_method.sampling_state = stored_sampling_state
205+
else:
206+
raise RuntimeError(
207+
"Cannot use the supplied ZarrTrace to restart sampling because "
208+
"it has no sampling_state information stored. You will have to "
209+
"resample from scratch."
210+
)
211+
draw = stored_draw_idx
212+
self._write_point(trace.get_mcmc_point())
213+
else:
214+
# Store starting point in trace's mcmc_point
215+
trace.set_mcmc_point(self._point)
197216

198217
msg = self._recv_msg()
199218
if msg[0] == "abort":
@@ -220,7 +239,7 @@ def _start_loop(self):
220239
raise KeyboardInterrupt()
221240
elif msg[0] == "write_next":
222241
if zarr_recording:
223-
self._zarr_chain.record(point, stats)
242+
self._trace.record(point, stats)
224243
self._write_point(point)
225244
is_last = draw + 1 == self._draws + self._tune
226245
self._msg_pipe.send(("writing_done", is_last, draw, tuning, stats))
@@ -247,8 +266,9 @@ def __init__(
247266
start: dict[str, np.ndarray],
248267
blas_cores,
249268
mp_ctx,
250-
zarr_chains: list[ZarrChain] | None = None,
251-
zarr_chains_pickled: bytes | None = None,
269+
traces: Sequence[IBaseTrace] | None = None,
270+
traces_pickled: bytes | None = None,
271+
zarr_recording: bool = False,
252272
):
253273
self.chain = chain
254274
process_name = f"worker_chain_{chain}"
@@ -271,15 +291,15 @@ def __init__(
271291
self._readable = True
272292
self._num_samples = 0
273293

274-
zarr_chains_send: list[ZarrChain] | bytes | None = None
275-
if zarr_chains_pickled is not None:
276-
zarr_chains_send = zarr_chains_pickled
277-
elif zarr_chains is not None:
294+
traces_send: Sequence[IBaseTrace] | bytes | None = None
295+
if traces_pickled is not None:
296+
traces_send = traces_pickled
297+
elif traces is not None:
278298
if mp_ctx.get_start_method() == "spawn":
279299
raise ValueError(
280-
"please provide a pre-pickled zarr_chains when multiprocessing start method is 'spawn'"
300+
"please provide a pre-pickled traces when multiprocessing start method is 'spawn'"
281301
)
282-
zarr_chains_send = zarr_chains
302+
traces_send = traces
283303

284304
if step_method_pickled is not None:
285305
step_method_send = step_method_pickled
@@ -305,8 +325,9 @@ def __init__(
305325
get_state_from_generator(rng),
306326
blas_cores,
307327
self.chain,
308-
zarr_chains_send,
309-
zarr_chains_pickled is not None,
328+
traces_send,
329+
traces_pickled is not None,
330+
zarr_recording,
310331
),
311332
)
312333
self._process.start()
@@ -429,7 +450,7 @@ def __init__(
429450
progressbar_theme: Theme | None = default_progress_theme,
430451
blas_cores: int | None = None,
431452
mp_ctx=None,
432-
zarr_chains: list[ZarrChain] | None = None,
453+
traces: Sequence[IBaseTrace] | None = None,
433454
):
434455
if any(len(arg) != chains for arg in [rngs, start_points]):
435456
raise ValueError(f"Number of rngs and start_points must be {chains}.")
@@ -450,15 +471,14 @@ def __init__(
450471
mp_ctx = multiprocessing.get_context(mp_ctx)
451472

452473
step_method_pickled = None
453-
zarr_chains_pickled = None
454-
self.zarr_recording = False
455-
if zarr_chains is not None:
456-
assert all(isinstance(zarr_chain, ZarrChain) for zarr_chain in zarr_chains)
457-
self.zarr_recording = True
474+
traces_pickled = None
475+
self.zarr_recording = traces is not None and all(
476+
isinstance(trace, ZarrChain) for trace in traces
477+
)
458478
if mp_ctx.get_start_method() != "fork":
459479
step_method_pickled = cloudpickle.dumps(step_method, protocol=-1)
460-
if zarr_chains is not None:
461-
zarr_chains_pickled = cloudpickle.dumps(zarr_chains, protocol=-1)
480+
if traces is not None:
481+
traces_pickled = cloudpickle.dumps(traces, protocol=-1)
462482

463483
self._samplers = [
464484
ProcessAdapter(
@@ -471,8 +491,9 @@ def __init__(
471491
start,
472492
blas_cores,
473493
mp_ctx,
474-
zarr_chains=zarr_chains,
475-
zarr_chains_pickled=zarr_chains_pickled,
494+
traces=traces,
495+
traces_pickled=traces_pickled,
496+
zarr_recording=self.zarr_recording,
476497
)
477498
for chain, rng, start in zip(range(chains), rngs, start_points)
478499
]

0 commit comments

Comments
 (0)