30
30
from rich .theme import Theme
31
31
from threadpoolctl import threadpool_limits
32
32
33
+ from pymc .backends .base import IBaseTrace
33
34
from pymc .backends .zarr import ZarrChain
34
35
from pymc .blocking import DictToArrayBijection
35
36
from pymc .exceptions import SamplingError
@@ -105,8 +106,9 @@ def __init__(
105
106
rng_state : RandomGeneratorState ,
106
107
blas_cores ,
107
108
chain : int ,
108
- zarr_chains : list [ZarrChain ] | bytes | None = None ,
109
- zarr_chains_is_pickled : bool = False ,
109
+ traces : list [IBaseTrace ] | bytes | None = None ,
110
+ traces_is_pickled : bool = False ,
111
+ zarr_recording : bool = False ,
110
112
):
111
113
# Because of https://github.com/numpy/numpy/issues/27727, we can't send
112
114
# the rng instance to the child process because pickling (copying) loses
@@ -117,13 +119,12 @@ def __init__(
117
119
self ._step_method = step_method
118
120
self ._step_method_is_pickled = step_method_is_pickled
119
121
self .chain = chain
120
- self ._zarr_recording = False
121
- self ._zarr_chain : ZarrChain | None = None
122
- if zarr_chains_is_pickled :
123
- self ._zarr_chain = cloudpickle .loads (zarr_chains )[self .chain ]
124
- elif zarr_chains is not None :
125
- self ._zarr_chain = cast (list [ZarrChain ], zarr_chains )[self .chain ]
126
- self ._zarr_recording = self ._zarr_chain is not None
122
+ self ._zarr_recording = zarr_recording
123
+ self ._trace : IBaseTrace | None = None
124
+ if traces_is_pickled :
125
+ self ._trace = cloudpickle .loads (traces )[self .chain ]
126
+ elif traces is not None :
127
+ self ._trace = cast (list [IBaseTrace ], traces )[self .chain ]
127
128
128
129
self ._shared_point = shared_point
129
130
self ._rng = rng
@@ -165,7 +166,7 @@ def run(self):
165
166
166
167
def _link_step_to_zarrchain (self ):
167
168
if self ._zarr_recording :
168
- self ._zarr_chain .link_stepper (self ._step_method )
169
+ self ._trace .link_stepper (self ._step_method )
169
170
170
171
def _wait_for_abortion (self ):
171
172
while True :
@@ -194,6 +195,24 @@ def _start_loop(self):
194
195
195
196
draw = 0
196
197
tuning = True
198
+ if self ._zarr_recording :
199
+ trace = self ._trace
200
+ stored_draw_idx = trace ._sampling_state .draw_idx [self .chain ]
201
+ stored_sampling_state = trace ._sampling_state .sampling_state [self .chain ]
202
+ if stored_draw_idx > 0 :
203
+ if stored_sampling_state is not None :
204
+ self ._step_method .sampling_state = stored_sampling_state
205
+ else :
206
+ raise RuntimeError (
207
+ "Cannot use the supplied ZarrTrace to restart sampling because "
208
+ "it has no sampling_state information stored. You will have to "
209
+ "resample from scratch."
210
+ )
211
+ draw = stored_draw_idx
212
+ self ._write_point (trace .get_mcmc_point ())
213
+ else :
214
+ # Store starting point in trace's mcmc_point
215
+ trace .set_mcmc_point (self ._point )
197
216
198
217
msg = self ._recv_msg ()
199
218
if msg [0 ] == "abort" :
@@ -220,7 +239,7 @@ def _start_loop(self):
220
239
raise KeyboardInterrupt ()
221
240
elif msg [0 ] == "write_next" :
222
241
if zarr_recording :
223
- self ._zarr_chain .record (point , stats )
242
+ self ._trace .record (point , stats )
224
243
self ._write_point (point )
225
244
is_last = draw + 1 == self ._draws + self ._tune
226
245
self ._msg_pipe .send (("writing_done" , is_last , draw , tuning , stats ))
@@ -247,8 +266,9 @@ def __init__(
247
266
start : dict [str , np .ndarray ],
248
267
blas_cores ,
249
268
mp_ctx ,
250
- zarr_chains : list [ZarrChain ] | None = None ,
251
- zarr_chains_pickled : bytes | None = None ,
269
+ traces : Sequence [IBaseTrace ] | None = None ,
270
+ traces_pickled : bytes | None = None ,
271
+ zarr_recording : bool = False ,
252
272
):
253
273
self .chain = chain
254
274
process_name = f"worker_chain_{ chain } "
@@ -271,15 +291,15 @@ def __init__(
271
291
self ._readable = True
272
292
self ._num_samples = 0
273
293
274
- zarr_chains_send : list [ ZarrChain ] | bytes | None = None
275
- if zarr_chains_pickled is not None :
276
- zarr_chains_send = zarr_chains_pickled
277
- elif zarr_chains is not None :
294
+ traces_send : Sequence [ IBaseTrace ] | bytes | None = None
295
+ if traces_pickled is not None :
296
+ traces_send = traces_pickled
297
+ elif traces is not None :
278
298
if mp_ctx .get_start_method () == "spawn" :
279
299
raise ValueError (
280
- "please provide a pre-pickled zarr_chains when multiprocessing start method is 'spawn'"
300
+ "please provide a pre-pickled traces when multiprocessing start method is 'spawn'"
281
301
)
282
- zarr_chains_send = zarr_chains
302
+ traces_send = traces
283
303
284
304
if step_method_pickled is not None :
285
305
step_method_send = step_method_pickled
@@ -305,8 +325,9 @@ def __init__(
305
325
get_state_from_generator (rng ),
306
326
blas_cores ,
307
327
self .chain ,
308
- zarr_chains_send ,
309
- zarr_chains_pickled is not None ,
328
+ traces_send ,
329
+ traces_pickled is not None ,
330
+ zarr_recording ,
310
331
),
311
332
)
312
333
self ._process .start ()
@@ -429,7 +450,7 @@ def __init__(
429
450
progressbar_theme : Theme | None = default_progress_theme ,
430
451
blas_cores : int | None = None ,
431
452
mp_ctx = None ,
432
- zarr_chains : list [ ZarrChain ] | None = None ,
453
+ traces : Sequence [ IBaseTrace ] | None = None ,
433
454
):
434
455
if any (len (arg ) != chains for arg in [rngs , start_points ]):
435
456
raise ValueError (f"Number of rngs and start_points must be { chains } ." )
@@ -450,15 +471,14 @@ def __init__(
450
471
mp_ctx = multiprocessing .get_context (mp_ctx )
451
472
452
473
step_method_pickled = None
453
- zarr_chains_pickled = None
454
- self .zarr_recording = False
455
- if zarr_chains is not None :
456
- assert all (isinstance (zarr_chain , ZarrChain ) for zarr_chain in zarr_chains )
457
- self .zarr_recording = True
474
+ traces_pickled = None
475
+ self .zarr_recording = traces is not None and all (
476
+ isinstance (trace , ZarrChain ) for trace in traces
477
+ )
458
478
if mp_ctx .get_start_method () != "fork" :
459
479
step_method_pickled = cloudpickle .dumps (step_method , protocol = - 1 )
460
- if zarr_chains is not None :
461
- zarr_chains_pickled = cloudpickle .dumps (zarr_chains , protocol = - 1 )
480
+ if traces is not None :
481
+ traces_pickled = cloudpickle .dumps (traces , protocol = - 1 )
462
482
463
483
self ._samplers = [
464
484
ProcessAdapter (
@@ -471,8 +491,9 @@ def __init__(
471
491
start ,
472
492
blas_cores ,
473
493
mp_ctx ,
474
- zarr_chains = zarr_chains ,
475
- zarr_chains_pickled = zarr_chains_pickled ,
494
+ traces = traces ,
495
+ traces_pickled = traces_pickled ,
496
+ zarr_recording = self .zarr_recording ,
476
497
)
477
498
for chain , rng , start in zip (range (chains ), rngs , start_points )
478
499
]
0 commit comments