Allow for pymc native samplers to resume sampling from ZarrTrace #7687

Open · wants to merge 13 commits into main

Changes from all commits
46 changes: 36 additions & 10 deletions pymc/backends/__init__.py
@@ -72,7 +72,7 @@
from pymc.backends.arviz import predictions_to_inference_data, to_inference_data
from pymc.backends.base import BaseTrace, IBaseTrace
from pymc.backends.ndarray import NDArray
from pymc.backends.zarr import ZarrTrace
from pymc.backends.zarr import TraceAlreadyInitialized, ZarrTrace
from pymc.blocking import PointType
from pymc.model import Model
from pymc.step_methods.compound import BlockedStep, CompoundStep
@@ -132,15 +132,41 @@ def init_traces(
) -> tuple[RunType | None, Sequence[IBaseTrace]]:
"""Initialize a trace recorder for each chain."""
if isinstance(backend, ZarrTrace):
backend.init_trace(
chains=chains,
draws=expected_length - tune,
tune=tune,
step=step,
model=model,
vars=trace_vars,
test_point=initial_point,
)
try:
backend.init_trace(
chains=chains,
draws=expected_length - tune,
tune=tune,
step=step,
model=model,
vars=trace_vars,
test_point=initial_point,
)
except TraceAlreadyInitialized:
> Review comment (Member): Maybe just InitializedTrace? Seems a little verbose!
> Reply (Member): Sounds fine to me, it's an internal thing.

# Trace has already been initialized. We need to make sure that the
# tracked variable names and the number of chains match, and then resize
# the zarr groups to the desired number of draws and tune.
backend.assert_model_and_step_are_compatible(
step=step,
model=model,
vars=trace_vars,
)
assert backend.posterior.chain.size == chains, (
f"The requested number of chains {chains} does not match the number "
f"of chains stored in the trace ({backend.posterior.chain.size})."
)
vars, var_names = backend.parse_varnames(model=model, vars=trace_vars)
backend.link_model_and_step(
chains=chains,
draws=expected_length - tune,
tune=tune,
step=step,
model=model,
vars=vars,
var_names=var_names,
test_point=initial_point,
)
backend.resize(tune=tune, draws=expected_length - tune)
return None, backend.straces
if HAS_MCB and isinstance(backend, Backend):
return init_chain_adapters(
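For orientation, here is a hedged sketch of what resuming looks like from the user's side once this branch lands. `pm.sample` and `ZarrTrace` are real PyMC names; the `store=` keyword follows the existing `ZarrTrace` signature, but the exact call pattern is illustrative, not this PR's documented API.

```python
import pymc as pm
from pymc.backends.zarr import ZarrTrace

with pm.Model() as model:
    x = pm.Normal("x")

    # First run: draws are flushed to the zarr store on disk as sampling
    # progresses, so an interrupted run leaves a partially filled trace.
    pm.sample(draws=1000, tune=1000, trace=ZarrTrace(store="trace.zarr"))

with model:
    # Second run: pointing a ZarrTrace at the same store makes init_traces
    # hit the TraceAlreadyInitialized branch above, which verifies that the
    # model, step method and chain count match before resuming sampling.
    idata = pm.sample(draws=1000, tune=1000, trace=ZarrTrace(store="trace.zarr"))
```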
66 changes: 66 additions & 0 deletions pymc/backends/base.py
@@ -35,6 +35,7 @@
from pymc.backends.report import SamplerReport
from pymc.model import modelcontext
from pymc.pytensorf import compile
from pymc.step_methods.compound import BlockedStep, CompoundStep, CompoundStepState, StepMethodState
from pymc.util import get_var_name

logger = logging.getLogger(__name__)
@@ -56,6 +57,8 @@
sampler_vars: list[dict[str, type | np.dtype]]
"""Sampler stats for each sampler."""

_step_method: BlockedStep | CompoundStep | None

def __len__(self):
"""Length of the chain."""
raise NotImplementedError()
@@ -132,6 +135,69 @@
"""
pass

def completed_draws_and_divergences(self, chain_specific: bool = True) -> tuple[int, int]:
"""Get number of completed draws and divergences in the trace.

This is a helper function to start the ProgressBarManager when resuming sampling
from an existing trace.

Parameters
----------
chain_specific : bool
If ``True``, only the completed draws and divergences on the current chain
are returned. If ``False``, the draws and divergences across all chains are
returned. WARNING: many BaseTrace backends are not aware of the information
stored in other chains and will raise a ``ValueError`` if passed ``False``.

Returns
-------
draws : int
Number of draws in the current chain or across all chains.
divergences : int
Number of divergences in the current chain or across all chains.
"""
raise NotImplementedError()

def link_stepper(self, step_method: BlockedStep | CompoundStep):
"""Provide a reference to the step method used during sampling.

This reference can be used to facilitate writing the stepper's sampling state
each time the samples are flushed into the storage.
"""
self._step_method = step_method

def store_sampling_state(self, sampling_state: StepMethodState | CompoundStepState):
    """Store the supplied step method sampling state on the trace."""
    self._sampling_state = sampling_state

def record_sampling_state(self, step: BlockedStep | CompoundStep | None = None):
"""Record the sampling state information to the store's ``_sampling_state`` group.

The sampling state includes the number of draws taken so far (``draw_idx``) and
the step method's ``sampling_state``.

Parameters
----------
step : BlockedStep | CompoundStep | None
The step method from which to take the ``sampling_state``. If ``None``,
the ``step`` is taken to be the step method that was linked to the
trace when calling :meth:`~IBaseTrace.link_stepper`. If this method was never
called, no step method ``sampling_state`` information is stored in the
chain.
"""
if step is None:
step = self._step_method

if step is not None:
self.store_sampling_state(step.sampling_state)

def get_stored_draw_and_state(self) -> tuple[int, StepMethodState | CompoundStepState | None]:
    """Return the stored draw index and step method sampling state.

    The base implementation returns ``(0, None)``, i.e. sampling starts from
    scratch; backends that support resuming should override this method.
    """
    return 0, None

def get_mcmc_point(self) -> dict[str, np.ndarray]:
    """Return the values of the model variables at the trace's current MCMC point."""
    raise NotImplementedError()

def set_mcmc_point(self, mcmc_point: Mapping[str, np.ndarray]):
    """Set the trace's current MCMC point. The base implementation is a no-op."""
    pass


class BaseTrace(IBaseTrace):
"""Base trace object.
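Taken together, the methods added to `IBaseTrace` above define the protocol a backend must speak to support resuming: hold a reference to the step method, persist its sampling state on each flush, and report how far sampling got. A minimal in-memory sketch, assuming this PR's `IBaseTrace`; the class and its attributes are invented for illustration:

```python
import numpy as np

from pymc.backends.base import IBaseTrace


class ToyResumableTrace(IBaseTrace):
    """Illustrative, partial backend wiring up the resume hooks."""

    def __init__(self):
        self._step_method = None
        self._sampling_state = None
        self._draws: list[dict[str, np.ndarray]] = []

    def __len__(self):
        return len(self._draws)

    def get_stored_draw_and_state(self):
        # Resume from however many draws were stored, restoring the step
        # method state last written via record_sampling_state().
        return len(self._draws), self._sampling_state

    def completed_draws_and_divergences(self, chain_specific: bool = True):
        # This toy backend keeps no sampler stats, so report zero divergences.
        return len(self._draws), 0


# The sampling loop would call trace.link_stepper(step) once up front, then
# trace.record_sampling_state() whenever draws are flushed; on resume it reads
# trace.get_stored_draw_and_state() to fast-forward the step method.
```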
36 changes: 36 additions & 0 deletions pymc/backends/mcbackend.py
@@ -201,6 +201,42 @@
def point(self, idx: int) -> dict[str, np.ndarray]:
return self._chain.get_draws_at(idx, [v.name for v in self._chain.variables.values()])

def completed_draws_and_divergences(self, chain_specific: bool = True) -> tuple[int, int]:
"""Get number of completed draws and divergences in the trace.

This is a helper function to start the ProgressBarManager when resuming sampling
from an existing trace.

Parameters
----------
chain_specific : bool
If ``True``, only the completed draws and divergences on the current chain
are returned. If ``False``, the draws and divergences across all chains are
returned. WARNING: many BaseTrace backends are not aware of the information
stored in other chains and will raise a ``ValueError`` if passed ``False``.

Returns
-------
draws : int
Number of draws in the current chain or across all chains.
divergences : int
Number of divergences in the current chain or across all chains.
"""
if not chain_specific:
raise ValueError(
"NDArray traces are not aware of the number of draws and divergences "
"recorded in other chains. Please call this method using "
"chain_specific=True"
)
try:
divergent_draws = self.get_sampler_stats("divergence")
if divergent_draws.ndim > 1:
divergent_draws = divergent_draws.sum(axis=-1)
divergences = sum(divergent_draws > 0)
except KeyError:
divergences = 0
return len(self), divergences


def make_runmeta_and_point_fn(
*,
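The reduction above treats a draw as divergent if any sampler in a compound step diverged on it. A standalone numpy sketch of that logic; the stats array is fabricated for illustration:

```python
import numpy as np

# Fabricated "divergence" stats for 6 draws from a compound step with two
# samplers: shape (draws, samplers), one column per sampler.
divergent_draws = np.array([[0, 0], [1, 0], [0, 0], [0, 1], [1, 1], [0, 0]])
if divergent_draws.ndim > 1:
    # Collapse the sampler axis: a draw is divergent if any sampler diverged.
    divergent_draws = divergent_draws.sum(axis=-1)
divergences = sum(divergent_draws > 0)
print(divergences)  # 3 of the 6 draws had at least one divergence
```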
35 changes: 35 additions & 0 deletions pymc/backends/ndarray.py
@@ -201,6 +201,41 @@
idx = int(idx)
return {varname: values[idx] for varname, values in self.samples.items()}

def completed_draws_and_divergences(self, chain_specific: bool = True) -> tuple[int, int]:
"""Get number of completed draws and divergences in the trace.

This is a helper function to start the ProgressBarManager when resuming sampling
from an existing trace.

Parameters
----------
chain_specific : bool
If ``True``, only the completed draws and divergences on the current chain
are returned. If ``False``, the draws and divergences across all chains are
returned. WARNING: many BaseTrace backends are not aware of the information
stored in other chains and will raise a ``ValueError`` if passed ``False``.

Returns
-------
draws : int
Number of draws in the current chain or across all chains.
divergences : int
Number of divergences in the current chain or across all chains.
"""
if not chain_specific:
raise ValueError(
"NDArray traces are not aware of the number of draws and divergences "
"recorded in other chains. Please call this method using "
"chain_specific=True"
)
divergent_draws = np.zeros(len(self), dtype="int")
for sampler_stats in self._stats:
for key, data in sampler_stats.items():
if "divergence" in key:
divergent_draws += np.asarray(data)
divergences = sum(divergent_draws > 0)
return len(self), divergences


def _slice_as_ndarray(strace, idx):
sliced = NDArray(model=strace.model, vars=strace.vars)
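Unlike the mcbackend version, the NDArray implementation has no single stat to query, so it scans every sampler's stats dict and accumulates any entry whose key mentions a divergence. A standalone sketch with fabricated per-sampler stats:

```python
import numpy as np

# Fabricated stats for 5 draws, shaped like NDArray._stats: one dict per
# step method, mapping stat names to per-draw arrays.
stats = [
    {"divergence": np.array([0, 1, 0, 0, 1])},
    {"divergence_count": np.array([0, 0, 1, 0, 1])},
]
divergent_draws = np.zeros(5, dtype="int")
for sampler_stats in stats:
    for key, data in sampler_stats.items():
        if "divergence" in key:
            divergent_draws += np.asarray(data)
# A draw counts once even if several samplers diverged on it.
divergences = sum(divergent_draws > 0)
print(divergences)  # 3
```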