Skip to content

Add optional support for McBackend-based trace backends #6501

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .github/workflows/tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,7 @@ jobs:
pymc/tests/step_methods/hmc/test_quadpotential.py

- |
pymc/tests/backends/test_mcbackend.py
pymc/tests/distributions/test_truncated.py
pymc/tests/logprob/test_abstract.py
pymc/tests/logprob/test_censoring.py
Expand Down
1 change: 1 addition & 0 deletions conda-envs/environment-dev.yml
Original file line number Diff line number Diff line change
Expand Up @@ -41,3 +41,4 @@ dependencies:
- pip:
- git+https://github.com/pymc-devs/pymc-sphinx-theme
- numdifftools>=0.9.40
- mcbackend>=0.3.0
1 change: 1 addition & 0 deletions conda-envs/environment-test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -31,3 +31,4 @@ dependencies:
- types-cachetools
- pip:
- numdifftools>=0.9.40
- mcbackend>=0.3.0
1 change: 1 addition & 0 deletions conda-envs/windows-environment-dev.yml
Original file line number Diff line number Diff line change
Expand Up @@ -38,3 +38,4 @@ dependencies:
- pip:
- git+https://github.com/pymc-devs/pymc-sphinx-theme
- numdifftools>=0.9.40
- mcbackend>=0.3.0
1 change: 1 addition & 0 deletions conda-envs/windows-environment-test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -31,3 +31,4 @@ dependencies:
- types-cachetools
- pip:
- numdifftools>=0.9.40
- mcbackend>=0.3.0
37 changes: 35 additions & 2 deletions pymc/backends/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,19 @@
from pymc.backends.base import BaseTrace, IBaseTrace
from pymc.backends.ndarray import NDArray
from pymc.model import Model
from pymc.step_methods.compound import BlockedStep, CompoundStep
from pymc.step_methods.compound import BlockedStep, CompoundStep, StatsBijection

HAS_MCB = False
try:
from mcbackend import Backend, NumPyBackend

from pymc.backends.mcbackend import ChainRecordAdapter, make_runmeta

TraceOrBackend = Union[BaseTrace, Backend]
HAS_MCB = True
except ImportError:
TraceOrBackend = BaseTrace # type: ignore


__all__ = ["to_inference_data", "predictions_to_inference_data"]

Expand Down Expand Up @@ -99,7 +111,7 @@ def _init_trace(

def init_traces(
*,
backend: Optional[BaseTrace],
backend: Optional[TraceOrBackend],
chains: int,
expected_length: int,
step: Union[BlockedStep, CompoundStep],
Expand All @@ -108,6 +120,27 @@ def init_traces(
model: Model,
) -> Sequence[IBaseTrace]:
"""Initializes a trace recorder for each chain."""
if HAS_MCB and backend is None:
backend = NumPyBackend(preallocate=expected_length)
if HAS_MCB and isinstance(backend, Backend):
run = backend.init_run(
make_runmeta(
var_dtypes=var_dtypes,
var_shapes=var_shapes,
step=step,
model=model,
)
)
statsbj = StatsBijection(step.stats_dtypes)
return [
ChainRecordAdapter(
chain=run.init_chain(chain_number=chain_number),
stats_bijection=statsbj,
)
for chain_number in range(chains)
]

assert backend is None or isinstance(backend, BaseTrace)
return [
_init_trace(
expected_length=expected_length,
Expand Down
93 changes: 59 additions & 34 deletions pymc/backends/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ class IBaseTrace(ABC, Sized):
varnames: List[str]
"""Names of tracked variables."""

sampler_vars: List[Dict[str, type]]
sampler_vars: List[Dict[str, Union[type, np.dtype]]]
"""Sampler stats for each sampler."""

def __len__(self):
Expand All @@ -79,23 +79,27 @@ def get_values(self, varname: str, burn=0, thin=1) -> np.ndarray:
"""
raise NotImplementedError()

def get_sampler_stats(self, stat_name: str, sampler_idx: Optional[int] = None, burn=0, thin=1):
def get_sampler_stats(
self, stat_name: str, sampler_idx: Optional[int] = None, burn=0, thin=1
) -> np.ndarray:
"""Get sampler statistics from the trace.

Parameters
----------
stat_name: str
sampler_idx: int or None
burn: int
thin: int
stat_name : str
Name of the stat to fetch.
sampler_idx : int or None
Index of the sampler to get the stat from.
burn : int
Draws to skip from the start.
thin : int
Stepsize for the slice.

Returns
-------
If the `sampler_idx` is specified, return the statistic with
the given name in a numpy array. If it is not specified and there
is more than one sampler that provides this statistic, return
a numpy array of shape (m, n), where `m` is the number of
such samplers, and `n` is the number of samples.
stats : np.ndarray
If `sampler_idx` was specified, the shape should be `(draws,)`.
Otherwise, the shape should be `(draws, samplers)`.
"""
raise NotImplementedError()

Expand Down Expand Up @@ -220,23 +224,31 @@ def __getitem__(self, idx):
except (ValueError, TypeError): # Passed variable or variable name.
raise ValueError("Can only index with slice or integer")

def get_sampler_stats(self, stat_name, sampler_idx=None, burn=0, thin=1):
def get_sampler_stats(
self, stat_name: str, sampler_idx: Optional[int] = None, burn=0, thin=1
) -> np.ndarray:
"""Get sampler statistics from the trace.

Note: This implementation attempts to squeeze object arrays into a consistent dtype,
which can change their shape in hard-to-predict ways.
See https://github.com/pymc-devs/pymc/issues/6207

Parameters
----------
stat_name: str
sampler_idx: int or None
burn: int
thin: int
stat_name : str
Name of the stat to fetch.
sampler_idx : int or None
Index of the sampler to get the stat from.
burn : int
Draws to skip from the start.
thin : int
Stepsize for the slice.

Returns
-------
If the `sampler_idx` is specified, return the statistic with
the given name in a numpy array. If it is not specified and there
is more than one sampler that provides this statistic, return
a numpy array of shape (m, n), where `m` is the number of
such samplers, and `n` is the number of samples.
stats : np.ndarray
If `sampler_idx` was specified, the shape should be `(draws,)`.
Otherwise, the shape should be `(draws, samplers)`.
"""
if sampler_idx is not None:
return self._get_sampler_stats(stat_name, sampler_idx, burn, thin)
Expand All @@ -254,14 +266,16 @@ def get_sampler_stats(self, stat_name, sampler_idx=None, burn=0, thin=1):

if vals.dtype == np.dtype(object):
try:
vals = np.vstack(vals)
vals = np.vstack(list(vals))
except ValueError:
# Most likely due to non-identical shapes. Just stick with the object-array.
pass

return vals

def _get_sampler_stats(self, stat_name, sampler_idx, burn, thin):
def _get_sampler_stats(
self, stat_name: str, sampler_idx: int, burn: int, thin: int
) -> np.ndarray:
"""Get sampler statistics."""
raise NotImplementedError()

Expand Down Expand Up @@ -476,23 +490,34 @@ def get_sampler_stats(
combine: bool = True,
chains: Optional[Union[int, Sequence[int]]] = None,
squeeze: bool = True,
):
) -> Union[List[np.ndarray], np.ndarray]:
"""Get sampler statistics from the trace.

Note: This implementation attempts to squeeze object arrays into a consistent dtype,
which can change their shape in hard-to-predict ways.
See https://github.com/pymc-devs/pymc/issues/6207

Parameters
----------
stat_name: str
sampler_idx: int or None
burn: int
thin: int
stat_name : str
Name of the stat to fetch.
sampler_idx : int or None
Index of the sampler to get the stat from.
burn : int
Draws to skip from the start.
thin : int
Stepsize for the slice.
combine : bool
If True, results from `chains` will be concatenated.
squeeze : bool
Return a single array element if the resulting list of
values only has one element. If False, the result will
always be a list of arrays, even if `combine` is True.

Returns
-------
If the `sampler_idx` is specified, return the statistic with
the given name in a numpy array. If it is not specified and there
is more than one sampler that provides this statistic, return
a numpy array of shape (m, n), where `m` is the number of
such samplers, and `n` is the number of samples.
stats : np.ndarray or list of np.ndarray
A single array or a list of arrays, depending on `combine` and `squeeze`.
"""
if stat_name not in self.stat_names:
raise KeyError("Unknown sampler statistic %s" % stat_name)
Expand Down Expand Up @@ -543,7 +568,7 @@ def points(self, chains=None):
return itl.chain.from_iterable(self._straces[chain] for chain in chains)


def _squeeze_cat(results, combine, squeeze):
def _squeeze_cat(results, combine: bool, squeeze: bool):
"""Squeeze and concatenate the results depending on values of
`combine` and `squeeze`."""
if combine:
Expand Down
Loading