From 27ea7815021ad720be08c60b9e82feba5c52bd0f Mon Sep 17 00:00:00 2001 From: Michael Osthege Date: Sun, 5 Feb 2023 16:42:28 +0100 Subject: [PATCH 1/4] More robust tune stat fetching and type hints The mcmc module relies on the `"tune"` stat to figure out the number of tune/draw iterations post sampling. These changes remove reliance on any weird squeeze-cat magic. --- pymc/backends/base.py | 93 +++++++++++++++++++++++++--------------- pymc/backends/ndarray.py | 4 +- pymc/sampling/mcmc.py | 6 +-- 3 files changed, 64 insertions(+), 39 deletions(-) diff --git a/pymc/backends/base.py b/pymc/backends/base.py index 2ee3193f1f..2929d3672b 100644 --- a/pymc/backends/base.py +++ b/pymc/backends/base.py @@ -58,7 +58,7 @@ class IBaseTrace(ABC, Sized): varnames: List[str] """Names of tracked variables.""" - sampler_vars: List[Dict[str, type]] + sampler_vars: List[Dict[str, Union[type, np.dtype]]] """Sampler stats for each sampler.""" def __len__(self): @@ -79,23 +79,27 @@ def get_values(self, varname: str, burn=0, thin=1) -> np.ndarray: """ raise NotImplementedError() - def get_sampler_stats(self, stat_name: str, sampler_idx: Optional[int] = None, burn=0, thin=1): + def get_sampler_stats( + self, stat_name: str, sampler_idx: Optional[int] = None, burn=0, thin=1 + ) -> np.ndarray: """Get sampler statistics from the trace. Parameters ---------- - stat_name: str - sampler_idx: int or None - burn: int - thin: int + stat_name : str + Name of the stat to fetch. + sampler_idx : int or None + Index of the sampler to get the stat from. + burn : int + Draws to skip from the start. + thin : int + Stepsize for the slice. Returns ------- - If the `sampler_idx` is specified, return the statistic with - the given name in a numpy array. If it is not specified and there - is more than one sampler that provides this statistic, return - a numpy array of shape (m, n), where `m` is the number of - such samplers, and `n` is the number of samples. 
+ stats : np.ndarray + If `sampler_idx` was specified, the shape should be `(draws,)`. + Otherwise, the shape should be `(draws, samplers)`. """ raise NotImplementedError() @@ -220,23 +224,31 @@ def __getitem__(self, idx): except (ValueError, TypeError): # Passed variable or variable name. raise ValueError("Can only index with slice or integer") - def get_sampler_stats(self, stat_name, sampler_idx=None, burn=0, thin=1): + def get_sampler_stats( + self, stat_name: str, sampler_idx: Optional[int] = None, burn=0, thin=1 + ) -> np.ndarray: """Get sampler statistics from the trace. + Note: This implementation attempts to squeeze object arrays into a consistent dtype, + which can change their shape in hard-to-predict ways. + See https://github.com/pymc-devs/pymc/issues/6207 + Parameters ---------- - stat_name: str - sampler_idx: int or None - burn: int - thin: int + stat_name : str + Name of the stat to fetch. + sampler_idx : int or None + Index of the sampler to get the stat from. + burn : int + Draws to skip from the start. + thin : int + Stepsize for the slice. Returns ------- - If the `sampler_idx` is specified, return the statistic with - the given name in a numpy array. If it is not specified and there - is more than one sampler that provides this statistic, return - a numpy array of shape (m, n), where `m` is the number of - such samplers, and `n` is the number of samples. + stats : np.ndarray + If `sampler_idx` was specified, the shape should be `(draws,)`. + Otherwise, the shape should be `(draws, samplers)`. """ if sampler_idx is not None: return self._get_sampler_stats(stat_name, sampler_idx, burn, thin) @@ -254,14 +266,16 @@ def get_sampler_stats(self, stat_name, sampler_idx=None, burn=0, thin=1): if vals.dtype == np.dtype(object): try: - vals = np.vstack(vals) + vals = np.vstack(list(vals)) except ValueError: # Most likely due to non-identical shapes. Just stick with the object-array.
pass return vals - def _get_sampler_stats(self, stat_name, sampler_idx, burn, thin): + def _get_sampler_stats( + self, stat_name: str, sampler_idx: int, burn: int, thin: int + ) -> np.ndarray: """Get sampler statistics.""" raise NotImplementedError() @@ -476,23 +490,34 @@ def get_sampler_stats( combine: bool = True, chains: Optional[Union[int, Sequence[int]]] = None, squeeze: bool = True, - ): + ) -> Union[List[np.ndarray], np.ndarray]: """Get sampler statistics from the trace. + Note: This implementation attempts to squeeze object arrays into a consistent dtype, + which can change their shape in hard-to-predict ways. + See https://github.com/pymc-devs/pymc/issues/6207 + Parameters ---------- - stat_name: str - sampler_idx: int or None - burn: int - thin: int + stat_name : str + Name of the stat to fetch. + sampler_idx : int or None + Index of the sampler to get the stat from. + burn : int + Draws to skip from the start. + thin : int + Stepsize for the slice. + combine : bool + If True, results from `chains` will be concatenated. + squeeze : bool + Return a single array element if the resulting list of + values only has one element. If False, the result will + always be a list of arrays, even if `combine` is True. Returns ------- - If the `sampler_idx` is specified, return the statistic with - the given name in a numpy array. If it is not specified and there - is more than one sampler that provides this statistic, return - a numpy array of shape (m, n), where `m` is the number of - such samplers, and `n` is the number of samples. + stats : list of np.ndarray or np.ndarray + List or ndarray depending on parameters.
""" if stat_name not in self.stat_names: raise KeyError("Unknown sampler statistic %s" % stat_name) @@ -543,7 +568,7 @@ def points(self, chains=None): return itl.chain.from_iterable(self._straces[chain] for chain in chains) -def _squeeze_cat(results, combine, squeeze): +def _squeeze_cat(results, combine: bool, squeeze: bool): """Squeeze and concatenate the results depending on values of `combine` and `squeeze`.""" if combine: diff --git a/pymc/backends/ndarray.py b/pymc/backends/ndarray.py index bacaf23c53..4f68d79e8c 100644 --- a/pymc/backends/ndarray.py +++ b/pymc/backends/ndarray.py @@ -119,7 +119,9 @@ def record(self, point, sampler_stats=None) -> None: data[key][self.draw_idx] = val self.draw_idx += 1 - def _get_sampler_stats(self, varname, sampler_idx, burn, thin): + def _get_sampler_stats( + self, varname: str, sampler_idx: int, burn: int, thin: int + ) -> np.ndarray: return self._stats[sampler_idx][varname][burn::thin] def close(self): diff --git a/pymc/sampling/mcmc.py b/pymc/sampling/mcmc.py index 24e744b85b..8cac5bc385 100644 --- a/pymc/sampling/mcmc.py +++ b/pymc/sampling/mcmc.py @@ -580,10 +580,8 @@ def sample( # count the number of tune/draw iterations that happened # ideally via the "tune" statistic, but not all samplers record it! if "tune" in mtrace.stat_names: - stat = mtrace.get_sampler_stats("tune", chains=0) - # when CompoundStep is used, the stat is 2 dimensional! 
- if len(stat.shape) == 2: - stat = stat[:, 0] + # Get the tune stat directly from chain 0, sampler 0 + stat = mtrace._straces[0].get_sampler_stats("tune", sampler_idx=0) stat = tuple(stat) n_tune = stat.count(True) n_draws = stat.count(False) From 6f795260164dfdb9db74b9d888ed36c0033fe767 Mon Sep 17 00:00:00 2001 From: Michael Osthege Date: Sun, 5 Feb 2023 17:32:49 +0100 Subject: [PATCH 2/4] Add optional McBackend support --- .github/workflows/tests.yml | 1 + pymc/backends/__init__.py | 35 ++++- pymc/backends/mcbackend.py | 203 ++++++++++++++++++++++++++ pymc/sampling/mcmc.py | 6 +- pymc/tests/backends/test_mcbackend.py | 120 +++++++++++++++ pymc/tests/sampling/test_mcmc.py | 4 + 6 files changed, 364 insertions(+), 5 deletions(-) create mode 100644 pymc/backends/mcbackend.py create mode 100644 pymc/tests/backends/test_mcbackend.py diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 443e3e376b..d30fb22b87 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -86,6 +86,7 @@ jobs: pymc/tests/step_methods/hmc/test_quadpotential.py - | + pymc/tests/backends/test_mcbackend.py pymc/tests/distributions/test_truncated.py pymc/tests/logprob/test_abstract.py pymc/tests/logprob/test_censoring.py diff --git a/pymc/backends/__init__.py b/pymc/backends/__init__.py index fbfe8914a9..61ec861f44 100644 --- a/pymc/backends/__init__.py +++ b/pymc/backends/__init__.py @@ -69,7 +69,19 @@ from pymc.backends.base import BaseTrace, IBaseTrace from pymc.backends.ndarray import NDArray from pymc.model import Model -from pymc.step_methods.compound import BlockedStep, CompoundStep +from pymc.step_methods.compound import BlockedStep, CompoundStep, StatsBijection + +HAS_MCB = False +try: + from mcbackend import Backend + + from pymc.backends.mcbackend import ChainRecordAdapter, make_runmeta + + TraceOrBackend = Union[BaseTrace, Backend] + HAS_MCB = True +except ImportError: + TraceOrBackend = BaseTrace # type: ignore + __all__ = 
["to_inference_data", "predictions_to_inference_data"] @@ -99,7 +111,7 @@ def _init_trace( def init_traces( *, - backend: Optional[BaseTrace], + backend: Optional[TraceOrBackend], chains: int, expected_length: int, step: Union[BlockedStep, CompoundStep], @@ -108,6 +120,25 @@ def init_traces( model: Model, ) -> Sequence[IBaseTrace]: """Initializes a trace recorder for each chain.""" + if HAS_MCB and isinstance(backend, Backend): + run = backend.init_run( + make_runmeta( + var_dtypes=var_dtypes, + var_shapes=var_shapes, + step=step, + model=model, + ) + ) + statsbj = StatsBijection(step.stats_dtypes) + return [ + ChainRecordAdapter( + chain=run.init_chain(chain_number=chain_number), + stats_bijection=statsbj, + ) + for chain_number in range(chains) + ] + + assert backend is None or isinstance(backend, BaseTrace) return [ _init_trace( expected_length=expected_length, diff --git a/pymc/backends/mcbackend.py b/pymc/backends/mcbackend.py new file mode 100644 index 0000000000..c065bc9618 --- /dev/null +++ b/pymc/backends/mcbackend.py @@ -0,0 +1,203 @@ +# Copyright 2023 The PyMC Developers +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from typing import Any, Dict, List, Mapping, Optional, Sequence, Union + +import hagelkorn +import mcbackend as mcb +import numpy as np + +from mcbackend.npproto.utils import ndarray_from_numpy +from pytensor.compile.sharedvalue import SharedVariable +from pytensor.graph.basic import Constant + +from pymc.backends.base import IBaseTrace +from pymc.model import Model +from pymc.step_methods.compound import ( + BlockedStep, + CompoundStep, + StatsBijection, + flatten_steps, +) + + +def find_data(pmodel: Model) -> List[mcb.DataVariable]: + """Extracts data variables from a model.""" + observed_rvs = {pmodel.rvs_to_values[rv] for rv in pmodel.observed_RVs} + dvars = [] + # All data containers are named vars! + for name, var in pmodel.named_vars.items(): + dv = mcb.DataVariable(name) + if isinstance(var, Constant): + dv.value = ndarray_from_numpy(var.data) + elif isinstance(var, SharedVariable): + dv.value = ndarray_from_numpy(var.get_value()) + else: + continue + dv.dims = list(pmodel.named_vars_to_dims.get(name, [])) + dv.is_observed = var in observed_rvs + dvars.append(dv) + return dvars + + +def make_runmeta( + *, + var_dtypes: Dict[str, np.dtype], + var_shapes: Dict[str, Sequence[int]], + step: Union[CompoundStep, BlockedStep], + model: Model, +) -> mcb.RunMeta: + """Create an McBackend metadata description for the MCMC run. + + Parameters + ---------- + var_dtypes : dict + Variable names and corresponding NumPy dtypes. + var_shapes : dict + Variable names and corresponding shape tuples. + step : CompoundStep or BlockedStep + The step method that iterates the MCMC. + model : pm.Model + The current PyMC model. + + Returns + ------- + rmeta : mcb.RunMeta + Metadata about the model and MCMC sampling run. + """ + # Replace None with "" in RV dims. 
+ rv_dims = { + name: ((dname or "") for dname in dims) for name, dims in model.named_vars_to_dims.items() + } + free_rv_names = [rv.name for rv in model.free_RVs] + variables = [ + mcb.Variable( + name, + str(var_dtypes[name]), + list(var_shapes[name]), + dims=list(rv_dims[name]) if name in rv_dims else [], + is_deterministic=(name not in free_rv_names), + ) + for name in var_dtypes.keys() + ] + + sample_stats = [ + mcb.Variable("tune", "bool"), + ] + + # In PyMC the sampler stats are grouped by the sampler. + # ⚠ PyMC currently does not inform backends about shapes/dims of sampler stats. + steps = flatten_steps(step) + for s, sm in enumerate(steps): + for statstypes in sm.stats_dtypes: + for statname, dtype in statstypes.items(): + sname = f"sampler_{s}__{statname}" + svar = mcb.Variable( + name=sname, + dtype=np.dtype(dtype).name, + # This 👇 is needed until samplers provide shapes ahead of time. + undefined_ndim=True, + ) + sample_stats.append(svar) + + coordinates = [ + mcb.Coordinate(dname, mcb.npproto.utils.ndarray_from_numpy(np.array(cvals))) + for dname, cvals in model.coords.items() + if cvals is not None + ] + meta = mcb.RunMeta( + rid=hagelkorn.random(), + variables=variables, + coordinates=coordinates, + sample_stats=sample_stats, + data=find_data(model), + ) + return meta + + +class ChainRecordAdapter(IBaseTrace): + """Wraps an McBackend ``Chain`` as an ``IBaseTrace``.""" + + def __init__(self, chain: mcb.Chain, stats_bijection: StatsBijection) -> None: + # Assign attributes required by IBaseTrace + self.chain = chain.cmeta.chain_number + self.varnames = [v.name for v in chain.rmeta.variables] + stats_dtypes = {s.name: np.dtype(s.dtype) for s in chain.rmeta.sample_stats} + self.sampler_vars = [ + {stepname: stats_dtypes[flatname] for flatname, stepname in sstats} + for sstats in stats_bijection._stat_groups + ] + + self._chain = chain + self._statsbj = stats_bijection + super().__init__() + + def record(self, draw: Mapping[str, np.ndarray], stats: 
Sequence[Mapping[str, Any]]): + return self._chain.append(draw, self._statsbj.map(stats)) + + def __len__(self): + return len(self._chain) + + def get_values(self, varname: str, burn=0, thin=1) -> np.ndarray: + return self._chain.get_draws(varname, slice(burn, None, thin)) + + def get_sampler_stats( + self, stat_name: str, sampler_idx: Optional[int] = None, burn=0, thin=1 + ) -> np.ndarray: + # Fetching for a specific sampler is easy + if sampler_idx is not None: + return self._chain.get_stats( + f"sampler_{sampler_idx}__{stat_name}", slice(burn, None, thin) + ) + # To fetch for all samplers, we must collect the arrays one by one. + stats_dict = { + stat.name: self._chain.get_stats(stat.name, slice(burn, None, thin)) + for stat in self._chain.rmeta.sample_stats + if stat_name in stat.name + } + if not stats_dict: + raise KeyError(f"No stat '{stat_name}' was recorded.") + stats_list = self._statsbj.rmap(stats_dict) + stats_arrays = [] + for sd in stats_list: + if not sd: + stats_arrays.append(np.empty((), dtype=object)) + else: + stats_arrays.append(tuple(sd.values())[0]) + if sampler_idx is not None: + return stats_arrays[sampler_idx] + return np.array(stats_arrays).T + + def _slice(self, idx: slice) -> "IBaseTrace": + # Get the integer indices + start, stop, step = idx.indices(len(self)) + indices = np.arange(start, stop, step) + + # Create a NumPyChain for the sliced data + nchain = mcb.backends.numpy.NumPyChain( + self._chain.cmeta, self._chain.rmeta, preallocate=len(indices) + ) + + # Copy at selected indices and append them to the new chain. + # This may be slow, but NumPyChain currently doesn't have a batch-insert or slice API.
+ vnames = [v.name for v in nchain.variables.values()] + snames = [s.name for s in nchain.sample_stats.values()] + for i in indices: + draw = self._chain.get_draws_at(i, var_names=vnames) + stats = self._chain.get_stats_at(i, stat_names=snames) + nchain.append(draw, stats) + return ChainRecordAdapter(nchain, self._statsbj) + + def point(self, idx: int) -> Dict[str, np.ndarray]: + return self._chain.get_draws_at(idx, [v.name for v in self._chain.variables.values()]) diff --git a/pymc/sampling/mcmc.py b/pymc/sampling/mcmc.py index 8cac5bc385..cb5b79260e 100644 --- a/pymc/sampling/mcmc.py +++ b/pymc/sampling/mcmc.py @@ -32,8 +32,8 @@ import pymc as pm -from pymc.backends import init_traces -from pymc.backends.base import BaseTrace, IBaseTrace, MultiTrace, _choose_chains +from pymc.backends import TraceOrBackend, init_traces +from pymc.backends.base import IBaseTrace, MultiTrace, _choose_chains from pymc.blocking import DictToArrayBijection from pymc.exceptions import SamplingError from pymc.initial_point import PointType, StartDict, make_initial_point_fns_per_chain @@ -227,7 +227,7 @@ def sample( init: str = "auto", n_init: int = 200_000, initvals: Optional[Union[StartDict, Sequence[Optional[StartDict]]]] = None, - trace: Optional[BaseTrace] = None, + trace: Optional[TraceOrBackend] = None, chains: Optional[int] = None, cores: Optional[int] = None, tune: int = 1000, diff --git a/pymc/tests/backends/test_mcbackend.py b/pymc/tests/backends/test_mcbackend.py new file mode 100644 index 0000000000..2cbc6121c7 --- /dev/null +++ b/pymc/tests/backends/test_mcbackend.py @@ -0,0 +1,120 @@ +# Copyright 2023 The PyMC Developers +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import arviz +import numpy as np +import pytest + +import pymc as pm + +from pymc.backends import init_traces + +try: + import mcbackend as mcb +except ImportError: + pytest.skip("Requires McBackend to be installed.", allow_module_level=True) + +from pymc.backends.mcbackend import ChainRecordAdapter, make_runmeta + + +@pytest.fixture +def simple_model(): + seconds = np.linspace(0, 5) + observations = np.random.normal(0.5 + np.random.uniform(size=3)[:, None] * seconds[None, :]) + with pm.Model( + coords={ + "condition": ["A", "B", "C"], + } + ) as pmodel: + x = pm.ConstantData("seconds", seconds, dims="time") + a = pm.Normal("scalar") + b = pm.Uniform("vector", dims="condition") + pm.Deterministic("matrix", a + b[:, None] * x[None, :], dims=("condition", "time")) + pm.Bernoulli("integer", p=0.5) + obs = pm.MutableData("obs", observations, dims=("condition", "time")) + pm.Normal("L", pmodel["matrix"], observed=obs, dims=("condition", "time")) + return pmodel + + +def test_make_runmeta(simple_model): + with simple_model: + step = pm.DEMetropolisZ() + dtypes = {rv.name: rv.dtype for rv in step.vars} + shapes = {rv.name: rv.shape.eval() for rv in step.vars} + rmeta = make_runmeta( + var_dtypes=dtypes, + var_shapes=shapes, + step=step, + model=simple_model, + ) + assert isinstance(rmeta, mcb.RunMeta) + assert len(rmeta.variables) == len(dtypes) + assert len(rmeta.sample_stats) == 1 + len(step.stats_dtypes[0]) + pass + + +def test_init_traces(simple_model): + with simple_model: + step = pm.DEMetropolisZ() + dtypes = {rv.name: rv.dtype for rv in step.vars} + shapes = {rv.name:
rv.shape.eval() for rv in step.vars} + traces = init_traces( + backend=mcb.NumPyBackend(), + chains=2, + expected_length=70, + step=step, + var_dtypes=dtypes, + var_shapes=shapes, + model=simple_model, + ) + assert isinstance(traces, list) + assert len(traces) == 2 + assert isinstance(traces[0], ChainRecordAdapter) + assert isinstance(traces[0]._chain, mcb.backends.numpy.NumPyChain) + pass + + +class TestMcBackendSampling: + def test_multitrace_wrap(self, simple_model): + with simple_model: + mtrace = pm.sample( + trace=mcb.NumPyBackend(), + tune=5, + draws=7, + cores=1, + chains=3, + step=pm.Metropolis(), + discard_tuned_samples=False, + return_inferencedata=False, + ) + assert isinstance(mtrace, pm.backends.base.MultiTrace) + tune = mtrace._straces[0].get_sampler_stats("tune") + assert isinstance(tune, np.ndarray) + assert tune.shape == (12, 3) + pass + + @pytest.mark.parametrize("cores", [1, 3]) + def test_simple_model(self, simple_model, cores): + with simple_model: + idata = pm.sample( + trace=mcb.NumPyBackend(), + tune=5, + draws=7, + cores=cores, + chains=3, + discard_tuned_samples=False, + ) + assert isinstance(idata, arviz.InferenceData) + assert idata.warmup_posterior.sizes["draw"] == 5 + assert idata.posterior.sizes["draw"] == 7 + pass diff --git a/pymc/tests/sampling/test_mcmc.py b/pymc/tests/sampling/test_mcmc.py index 85686d235e..cbda1399ab 100644 --- a/pymc/tests/sampling/test_mcmc.py +++ b/pymc/tests/sampling/test_mcmc.py @@ -701,6 +701,10 @@ def test_keep_warning_stat_setting(keep_warning_stat): # This tests flattens so we don't have to be exact in accessing (non-)squeezed items. # Also see https://github.com/pymc-devs/pymc/issues/6207. warn_objs = list(idata.sample_stats.warning.sel(chain=0).values.flatten()) + assert warn_objs + if isinstance(warn_objs[0], np.ndarray): + # Squeeze warning stats. 
See https://github.com/pymc-devs/pymc/issues/6207 + warn_objs = [a.tolist() for a in warn_objs] assert any(isinstance(w, SamplerWarning) for w in warn_objs) assert any("Asteroid" in w.message for w in warn_objs) else: From b12fe67fbad376b97718273f9c6f37f90a196073 Mon Sep 17 00:00:00 2001 From: Michael Osthege Date: Sun, 5 Feb 2023 17:32:49 +0100 Subject: [PATCH 3/4] Test McBackend support --- conda-envs/environment-dev.yml | 1 + conda-envs/environment-test.yml | 1 + conda-envs/windows-environment-dev.yml | 1 + conda-envs/windows-environment-test.yml | 1 + requirements-dev.txt | 1 + 5 files changed, 5 insertions(+) diff --git a/conda-envs/environment-dev.yml b/conda-envs/environment-dev.yml index 44c5d73ecc..80d4288db1 100644 --- a/conda-envs/environment-dev.yml +++ b/conda-envs/environment-dev.yml @@ -41,3 +41,4 @@ dependencies: - pip: - git+https://github.com/pymc-devs/pymc-sphinx-theme - numdifftools>=0.9.40 + - mcbackend>=0.3.0 diff --git a/conda-envs/environment-test.yml b/conda-envs/environment-test.yml index b70ad4cee3..4e8a9cd070 100644 --- a/conda-envs/environment-test.yml +++ b/conda-envs/environment-test.yml @@ -31,3 +31,4 @@ dependencies: - types-cachetools - pip: - numdifftools>=0.9.40 + - mcbackend>=0.3.0 diff --git a/conda-envs/windows-environment-dev.yml b/conda-envs/windows-environment-dev.yml index 5b6e5749e5..101dbabc14 100644 --- a/conda-envs/windows-environment-dev.yml +++ b/conda-envs/windows-environment-dev.yml @@ -38,3 +38,4 @@ dependencies: - pip: - git+https://github.com/pymc-devs/pymc-sphinx-theme - numdifftools>=0.9.40 + - mcbackend>=0.3.0 diff --git a/conda-envs/windows-environment-test.yml b/conda-envs/windows-environment-test.yml index fd88e054ce..f7c639eae0 100644 --- a/conda-envs/windows-environment-test.yml +++ b/conda-envs/windows-environment-test.yml @@ -31,3 +31,4 @@ dependencies: - types-cachetools - pip: - numdifftools>=0.9.40 + - mcbackend>=0.3.0 diff --git a/requirements-dev.txt b/requirements-dev.txt index 
4d3ba57f9b..f25ca7b908 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -9,6 +9,7 @@ git+https://github.com/pymc-devs/pymc-sphinx-theme h5py>=2.7 ipython>=7.16 jupyter-sphinx +mcbackend>=0.3.0 mypy==0.990 myst-nb numdifftools>=0.9.40 From 19df6205182295a5b1700fde5dd843530c9c14cf Mon Sep 17 00:00:00 2001 From: Michael Osthege Date: Sun, 5 Feb 2023 17:33:03 +0100 Subject: [PATCH 4/4] Default to using McBackend --- pymc/backends/__init__.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/pymc/backends/__init__.py b/pymc/backends/__init__.py index 61ec861f44..18cef87452 100644 --- a/pymc/backends/__init__.py +++ b/pymc/backends/__init__.py @@ -73,7 +73,7 @@ HAS_MCB = False try: - from mcbackend import Backend + from mcbackend import Backend, NumPyBackend from pymc.backends.mcbackend import ChainRecordAdapter, make_runmeta @@ -120,6 +120,8 @@ def init_traces( model: Model, ) -> Sequence[IBaseTrace]: """Initializes a trace recorder for each chain.""" + if HAS_MCB and backend is None: + backend = NumPyBackend(preallocate=expected_length) if HAS_MCB and isinstance(backend, Backend): run = backend.init_run( make_runmeta(