From 131df712fe8c04cc34c0d38763a8e8f5a2164120 Mon Sep 17 00:00:00 2001
From: Luciano Paz
Date: Wed, 25 Sep 2024 11:53:58 +0200
Subject: [PATCH 1/7] Fix dangling step in test_population

---
 tests/sampling/test_population.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/tests/sampling/test_population.py b/tests/sampling/test_population.py
index 1f145dbcaf..4e3d91bcbb 100644
--- a/tests/sampling/test_population.py
+++ b/tests/sampling/test_population.py
@@ -65,7 +65,7 @@ def test_nonparallelized_chains_are_random(self):
             cores=1,
             draws=20,
             tune=0,
-            step=DEMetropolis(),
+            step=step,
             compute_convergence_checks=False,
         )
         samples = idata.posterior["x"].values[:, 5]
@@ -82,7 +82,7 @@ def test_parallelized_chains_are_random(self):
             cores=4,
             draws=20,
             tune=0,
-            step=DEMetropolis(),
+            step=step,
             compute_convergence_checks=False,
         )
         samples = idata.posterior["x"].values[:, 5]

From 1c54e45aec273857ad33f78d4528488b3cd8a3ea Mon Sep 17 00:00:00 2001
From: Luciano Paz
Date: Mon, 23 Sep 2024 11:48:03 +0200
Subject: [PATCH 2/7] Detach step methods from numpy global random state

---
 pymc/math.py                           |  4 +-
 pymc/sampling/mcmc.py                  | 71 +++++++++++++---------
 pymc/sampling/parallel.py              | 28 +++++----
 pymc/sampling/population.py            | 25 ++++----
 pymc/step_methods/arraystep.py         | 43 ++++++++++----
 pymc/step_methods/compound.py          | 10 +++-
 pymc/step_methods/hmc/base_hmc.py      | 20 +++++--
 pymc/step_methods/hmc/hmc.py           | 14 ++++-
 pymc/step_methods/hmc/nuts.py          | 18 ++++--
 pymc/step_methods/hmc/quadpotential.py | 80 +++++++++++++++++++------
 pymc/step_methods/metropolis.py        | 82 ++++++++++++++++----------
 pymc/step_methods/slicer.py            | 19 +++---
 pymc/util.py                           | 64 ++++++++++++++++++--
 tests/sampling/test_forward.py         | 22 +++----
 tests/sampling/test_parallel.py        |  4 +-
 tests/step_methods/hmc/test_nuts.py    |  9 ++-
 tests/step_methods/test_metropolis.py  | 27 +++++----
 tests/step_methods/test_slicer.py      |  4 ++
 18 files changed, 379 insertions(+), 165 deletions(-)

diff --git a/pymc/math.py b/pymc/math.py
index b85ffe63ce..b5fc50a8eb 100644
--- a/pymc/math.py
+++ b/pymc/math.py
@@ -292,10 +292,10 @@ def logdiffexp_numpy(a, b):
 invlogit = sigmoid
 
 
-def logbern(log_p):
+def logbern(log_p, rng=None):
     if np.isnan(log_p):
         raise FloatingPointError("log_p can't be nan.")
-    return np.log(np.random.uniform()) < log_p
+    return np.log((rng or np.random).uniform()) < log_p
 
 
 def logit(p):
diff --git a/pymc/sampling/mcmc.py b/pymc/sampling/mcmc.py
index 32d2702ff2..02143ede05 100644
--- a/pymc/sampling/mcmc.py
+++ b/pymc/sampling/mcmc.py
@@ -71,6 +71,7 @@
     _get_seeds_per_chain,
     default_progress_theme,
     drop_warning_stat,
+    get_random_generator,
     get_untransformed_name,
     is_transformed_name,
 )
@@ -489,10 +490,15 @@ def sample(
     cores : int
         The number of chains to run in parallel. If ``None``, set to the number of CPUs in the
         system, but at most 4.
-    random_seed : int, array-like of int, RandomState or Generator, optional
-        Random seed(s) used by the sampling steps. If a list, tuple or array of ints
-        is passed, each entry will be used to seed each chain. A ValueError will be
-        raised if the length does not match the number of chains.
+    random_seed : int, array-like of int, or Generator, optional
+        Random seed(s) used by the sampling steps. Each step will create its own
+        :py:class:`~numpy.random.Generator` object to make its random draws in a way that is
+        independent from all other steppers and all other chains. If a list, tuple or array of ints
+        is passed, each entry will be used to seed the creation of ``Generator`` objects.
+ A ``ValueError`` will be raised if the length does not match the number of chains. + A ``TypeError`` will be raised if a :py:class:`~numpy.random.RandomState` object is passed. + We no longer support ``RandomState`` objects because their seeding mechanism does not allow + easy spawning of new independent random streams that are needed by the step methods. progressbar : bool, optional default=True Whether or not to display a progress bar in the command line. The bar shows the percentage of completion, the sampling speed in samples per second (SPS), and the estimated remaining @@ -684,7 +690,8 @@ def joined_blas_limiter(): if random_seed == -1: random_seed = None - random_seed_list = _get_seeds_per_chain(random_seed, chains) + rngs = get_random_generator(random_seed).spawn(chains) + random_seed_list = [rng.integers(2**30) for rng in rngs] if not discard_tuned_samples and not return_inferencedata: warnings.warn( @@ -832,11 +839,11 @@ def joined_blas_limiter(): if parallel: # For parallel sampling we can pass the list of random seeds directly, as # global seeding will only be called inside each process - sample_args["random_seed"] = random_seed_list + sample_args["rngs"] = rngs else: # We pass None if the original random seed was None. The single core sampler # methods will only set a global seed when it is not None. - sample_args["random_seed"] = random_seed if random_seed is None else random_seed_list + sample_args["rngs"] = rngs t_start = time.time() if parallel: @@ -987,7 +994,7 @@ def _sample_many( chains: int, traces: Sequence[IBaseTrace], start: Sequence[PointType], - random_seed: Sequence[RandomSeed] | None, + rngs: Sequence[np.random.Generator], step: Step, callback: SamplingIteratorCallback | None = None, **kwargs, @@ -1002,8 +1009,8 @@ def _sample_many( Total number of chains to sample. start: list Starting points for each chain - random_seed: list of random seeds, optional - A list of seeds, one for each chain + rngs: list of random Generators + A list of :py:class:`~numpy.random.Generator` objects, one for each chain step: function Step function """ @@ -1014,7 +1021,7 @@ def _sample_many( start=start[i], step=step, trace=traces[i], - random_seed=None if random_seed is None else random_seed[i], + rng=rngs[i], callback=callback, **kwargs, ) @@ -1025,7 +1032,7 @@ def _sample( *, chain: int, progressbar: bool, - random_seed: RandomSeed, + rng: np.random.Generator, start: PointType, draws: int, step: Step, @@ -1073,7 +1080,7 @@ def _sample( chain=chain, tune=tune, model=model, - random_seed=random_seed, + rng=rng, callback=callback, ) _pbar_data = {"chain": chain, "divergences": 0} @@ -1112,8 +1119,8 @@ def _iter_sample( trace: IBaseTrace, chain: int = 0, tune: int = 0, + rng: np.random.Generator, model: Model | None = None, - random_seed: RandomSeed = None, callback: SamplingIteratorCallback | None = None, ) -> Iterator[bool]: """Generator for sampling one chain. (Used in singleprocess sampling.) @@ -1147,8 +1154,7 @@ def _iter_sample( if draws < 1: raise ValueError("Argument `draws` must be greater than 0.") - if random_seed is not None: - np.random.seed(random_seed) + step.set_rng(rng) point = start @@ -1191,7 +1197,7 @@ def _mp_sample( step, chains: int, cores: int, - random_seed: Sequence[RandomSeed], + rngs: Sequence[np.random.Generator], start: Sequence[PointType], progressbar: bool = True, progressbar_theme: Theme | None = default_progress_theme, @@ -1216,8 +1222,8 @@ def _mp_sample( The number of chains to sample. cores : int The number of chains to run in parallel. 
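
NOTE (editorial, illustration only — not part of the patch): the seeding scheme above
relies on NumPy's Generator spawning. A minimal sketch of those semantics, using only
NumPy (``Generator.spawn`` needs NumPy >= 1.25; the variable names are illustrative):

    import numpy as np

    # One parent Generator is built from the user-facing seed, mirroring
    # `get_random_generator(random_seed)` in the hunk above.
    parent = np.random.default_rng(42)

    # One child Generator per chain; children have statistically independent
    # streams and are reproducible given the same parent seed.
    children = parent.spawn(4)

    # Per-chain integer seeds, as in `[rng.integers(2**30) for rng in rngs]`.
    print([int(rng.integers(2**30)) for rng in children])
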
- random_seed : list of random seeds - Random seeds for each chain. + rngs: list of random Generators + A list of :py:class:`~numpy.random.Generator` objects, one for each chain start : list Starting points for each chain. Dicts must contain numeric (transformed) initial values for all (transformed) free variables. @@ -1245,7 +1251,7 @@ def _mp_sample( tune=tune, chains=chains, cores=cores, - seeds=random_seed, + rngs=rngs, start_points=start, step_method=step, progressbar=progressbar, @@ -1444,12 +1450,12 @@ def init_nuts( mean = np.mean(apoints_data, axis=0) var = np.ones_like(mean) n = len(var) - potential = quadpotential.QuadPotentialDiagAdapt(n, mean, var, 10) + potential = quadpotential.QuadPotentialDiagAdapt(n, mean, var, 10, rng=random_seed_list[0]) elif init == "jitter+adapt_diag": mean = np.mean(apoints_data, axis=0) var = np.ones_like(mean) n = len(var) - potential = quadpotential.QuadPotentialDiagAdapt(n, mean, var, 10) + potential = quadpotential.QuadPotentialDiagAdapt(n, mean, var, 10, rng=random_seed_list[0]) elif init == "jitter+adapt_diag_grad": mean = np.mean(apoints_data, axis=0) var = np.ones_like(mean) @@ -1466,6 +1472,7 @@ def init_nuts( alpha=0.02, use_grads=True, stop_adaptation=stop_adaptation, + rng=random_seed_list[0], ) elif init == "advi+adapt_diag": approx = pm.fit( @@ -1486,7 +1493,9 @@ def init_nuts( mean = approx.mean.get_value() weight = 50 n = len(cov) - potential = quadpotential.QuadPotentialDiagAdapt(n, mean, cov, weight) + potential = quadpotential.QuadPotentialDiagAdapt( + n, mean, cov, weight, rng=random_seed_list[0] + ) elif init == "advi": approx = pm.fit( random_seed=random_seed_list[0], @@ -1502,7 +1511,7 @@ def init_nuts( ) initial_points = [approx_sample[i] for i in range(chains)] cov = approx.std.eval() ** 2 - potential = quadpotential.QuadPotentialDiag(cov) + potential = quadpotential.QuadPotentialDiag(cov, rng=random_seed_list[0]) elif init == "advi_map": start = pm.find_MAP(include_transformed=True, seed=random_seed_list[0]) approx = pm.MeanField(model=model, start=start) @@ -1519,28 +1528,32 @@ def init_nuts( ) initial_points = [approx_sample[i] for i in range(chains)] cov = approx.std.eval() ** 2 - potential = quadpotential.QuadPotentialDiag(cov) + potential = quadpotential.QuadPotentialDiag(cov, rng=random_seed_list[0]) elif init == "map": start = pm.find_MAP(include_transformed=True, seed=random_seed_list[0]) cov = -pm.find_hessian(point=start, negate_output=False) initial_points = [start] * chains - potential = quadpotential.QuadPotentialFull(cov) + potential = quadpotential.QuadPotentialFull(cov, rng=random_seed_list[0]) elif init == "adapt_full": mean = np.mean(apoints_data * chains, axis=0) initial_point = initial_points[0] initial_point_model_size = sum(initial_point[n.name].size for n in model.value_vars) cov = np.eye(initial_point_model_size) - potential = quadpotential.QuadPotentialFullAdapt(initial_point_model_size, mean, cov, 10) + potential = quadpotential.QuadPotentialFullAdapt( + initial_point_model_size, mean, cov, 10, rng=random_seed_list[0] + ) elif init == "jitter+adapt_full": mean = np.mean(apoints_data, axis=0) initial_point = initial_points[0] initial_point_model_size = sum(initial_point[n.name].size for n in model.value_vars) cov = np.eye(initial_point_model_size) - potential = quadpotential.QuadPotentialFullAdapt(initial_point_model_size, mean, cov, 10) + potential = quadpotential.QuadPotentialFullAdapt( + initial_point_model_size, mean, cov, 10, rng=random_seed_list[0] + ) else: raise ValueError(f"Unknown 
initializer: {init}.") - step = pm.NUTS(potential=potential, model=model, **kwargs) + step = pm.NUTS(potential=potential, model=model, rng=random_seed_list[0], **kwargs) # Filter deterministics from initial_points value_var_names = [var.name for var in model.value_vars] diff --git a/pymc/sampling/parallel.py b/pymc/sampling/parallel.py index a34947c706..4b76e53a97 100644 --- a/pymc/sampling/parallel.py +++ b/pymc/sampling/parallel.py @@ -33,7 +33,7 @@ from pymc.blocking import DictToArrayBijection from pymc.exceptions import SamplingError -from pymc.util import CustomProgress, RandomSeed, default_progress_theme +from pymc.util import CustomProgress, default_progress_theme logger = logging.getLogger(__name__) @@ -93,15 +93,18 @@ def __init__( shared_point, draws: int, tune: int, - seed, + rng: np.random.Generator, + seed_seq: np.random.SeedSequence, blas_cores, ): + # For some strange reason, spawn multiprocessing doesn't copy the rng + # seed sequence, so we have to rebuild it from scratch + rng = np.random.Generator(type(rng.bit_generator)(seed_seq)) self._msg_pipe = msg_pipe self._step_method = step_method self._step_method_is_pickled = step_method_is_pickled self._shared_point = shared_point - self._seed = seed - self._at_seed = seed + 1 + self._rng = rng self._draws = draws self._tune = tune self._blas_cores = blas_cores @@ -159,7 +162,7 @@ def _recv_msg(self): return self._msg_pipe.recv() def _start_loop(self): - np.random.seed(self._seed) + self._step_method.set_rng(self._rng) draw = 0 tuning = True @@ -210,7 +213,7 @@ def __init__( step_method, step_method_pickled, chain: int, - seed, + rng: np.random.Generator, start: dict[str, np.ndarray], blas_cores, mp_ctx, @@ -260,7 +263,8 @@ def __init__( self._shared_point, draws, tune, - seed, + rng, + rng.bit_generator.seed_seq, blas_cores, ), ) @@ -379,7 +383,7 @@ def __init__( tune: int, chains: int, cores: int, - seeds: Sequence["RandomSeed"], + rngs: Sequence[np.random.Generator], start_points: Sequence[dict[str, np.ndarray]], step_method, progressbar: bool = True, @@ -387,8 +391,8 @@ def __init__( blas_cores: int | None = None, mp_ctx=None, ): - if any(len(arg) != chains for arg in [seeds, start_points]): - raise ValueError(f"Number of seeds and start_points must be {chains}.") + if any(len(arg) != chains for arg in [rngs, start_points]): + raise ValueError(f"Number of rngs and start_points must be {chains}.") if mp_ctx is None or isinstance(mp_ctx, str): # Closes issue https://github.com/pymc-devs/pymc/issues/3849 @@ -416,12 +420,12 @@ def __init__( step_method, step_method_pickled, chain, - seed, + rng, start, blas_cores, mp_ctx, ) - for chain, seed, start in zip(range(chains), seeds, start_points) + for chain, rng, start in zip(range(chains), rngs, start_points) ] self._inactive = self._samplers.copy() diff --git a/pymc/sampling/population.py b/pymc/sampling/population.py index 4d5ced3f52..c0dc813b5c 100644 --- a/pymc/sampling/population.py +++ b/pymc/sampling/population.py @@ -37,7 +37,7 @@ StatsType, ) from pymc.step_methods.metropolis import DEMetropolis -from pymc.util import CustomProgress, RandomSeed +from pymc.util import CustomProgress __all__ = () @@ -53,7 +53,7 @@ def _sample_population( initial_points: Sequence[PointType], draws: int, start: Sequence[PointType], - random_seed: RandomSeed, + rngs: Sequence[np.random.Generator], step: BlockedStep | CompoundStep, tune: int, model: Model, @@ -70,7 +70,8 @@ def _sample_population( The number of samples to draw start : list Start points for each chain - random_seed : single 
random seed, optional
+    rngs: sequence of random Generators
+        A list of :py:class:`~numpy.random.Generator` objects, one for each chain
     step : function
         Step function (should be or contain a population step method)
     tune : int
@@ -96,7 +97,7 @@
         traces=traces,
         tune=tune,
         model=model,
-        random_seed=random_seed,
+        rngs=rngs,
         progressbar=progressbar,
     )
@@ -248,8 +249,6 @@ def _run_secondary(c, stepper_dumps, secondary_end, task, progress):
     progress : progress.Progress
         The progress bar
     """
-    # re-seed each child process to make them unique
-    np.random.seed(None)
     try:
         stepper = cloudpickle.loads(stepper_dumps)
         # the stepper is not necessarily a PopulationArraySharedStep itself,
@@ -317,8 +316,8 @@ def _prepare_iter_population(
     parallelize: bool,
     traces: Sequence[BaseTrace],
     tune: int,
+    rngs: Sequence[np.random.Generator],
     model=None,
-    random_seed: RandomSeed = None,
     progressbar=True,
 ) -> Iterator[int]:
     """Prepare a PopulationStepper and traces for population sampling.
@@ -335,8 +334,9 @@
         Setting for multiprocess parallelization
     tune : int
         Number of iterations to tune.
+    rngs: sequence of random Generators
+        A list of :py:class:`~numpy.random.Generator` objects, one for each chain
     model : Model (optional if in ``with`` context)
-    random_seed : single random seed, optional
     progressbar : bool
         ``progressbar`` argument for the ``PopulationStepper``, (defaults to True)
@@ -352,9 +352,6 @@
     if draws < 1:
         raise ValueError("Argument `draws` should be above 0.")
 
-    if random_seed is not None:
-        np.random.seed(random_seed)
-
     # The initialization of traces, samplers and points must happen in the right order:
     # 1. population of points is created
     # 2. steppers are initialized and linked to the points object
@@ -366,13 +363,17 @@
 
     # 2. Set up the steppers
     steppers: list[Step] = []
-    for c in range(nchains):
+    assert (
+        len(rngs) == nchains
+    ), f"There must be one random Generator per chain. Got {len(rngs)} instead of {nchains}"
+    for c, rng in enumerate(rngs):
         # need independent samplers for each chain
         # it is important to copy the actual steppers (but not the delta_logp)
         if isinstance(step, CompoundStep):
             chainstep = CompoundStep([copy(m) for m in step.methods])
         else:
             chainstep = copy(step)
+        chainstep.set_rng(rng)
         # link population samplers to the shared population state
         for sm in chainstep.methods if isinstance(step, CompoundStep) else [chainstep]:
             if isinstance(sm, PopulationArrayStepShared):
diff --git a/pymc/step_methods/arraystep.py b/pymc/step_methods/arraystep.py
index ca6036ecc6..602dfd6e51 100644
--- a/pymc/step_methods/arraystep.py
+++ b/pymc/step_methods/arraystep.py
@@ -18,12 +18,10 @@
 
 import numpy as np
 
-from numpy.random import uniform
-
 from pymc.blocking import DictToArrayBijection, PointType, RaveledVars, StatsType
 from pymc.model import modelcontext
 from pymc.step_methods.compound import BlockedStep
-from pymc.util import get_var_name
+from pymc.util import RandomGenerator, get_random_generator, get_var_name
 
 __all__ = ["ArrayStep", "ArrayStepShared", "metrop_select"]
 
@@ -39,13 +37,18 @@ class ArrayStep(BlockedStep):
     fs: list of logp PyTensor functions
     allvars: Boolean (default False)
     blocked: Boolean (default True)
+    rng: RandomGenerator
+        An object that can be used to produce the step method's
+        :py:class:`~numpy.random.Generator` object. Refer to
+        :py:func:`pymc.util.get_random_generator` for more information.
    
""" - def __init__(self, vars, fs, allvars=False, blocked=True): + def __init__(self, vars, fs, allvars=False, blocked=True, rng: RandomGenerator = None): self.vars = vars self.fs = fs self.allvars = allvars self.blocked = blocked + self.rng = get_random_generator(rng) def step(self, point: PointType) -> tuple[PointType, StatsType]: partial_funcs_and_point: list[Callable | PointType] = [ @@ -79,17 +82,22 @@ class ArrayStepShared(BlockedStep): and unmapping overhead as well as moving fewer variables around. """ - def __init__(self, vars, shared, blocked=True): + def __init__(self, vars, shared, blocked=True, rng: RandomGenerator = None): """ Parameters ---------- vars: list of sampling value variables shared: dict of PyTensor variable -> shared variable blocked: Boolean (default True) + rng: RandomGenerator + An object that can produce be used to produce the step method's + :py:class:`~numpy.random.Generator` object. Refer to + :py:func:`pymc.util.get_random_generator` for more information. """ self.vars = vars self.shared = {get_var_name(var): shared for var, shared in shared.items()} self.blocked = blocked + self.rng = get_random_generator(rng) def step(self, point: PointType) -> tuple[PointType, StatsType]: for name, shared_var in self.shared.items(): @@ -120,13 +128,17 @@ class PopulationArrayStepShared(ArrayStepShared): Works by linking a list of Points that is updated as the chains are iterated. """ - def __init__(self, vars, shared, blocked=True): + def __init__(self, vars, shared, blocked=True, rng: RandomGenerator = None): """ Parameters ---------- vars: list of sampling value variables shared: dict of PyTensor variable -> shared variable blocked: Boolean (default True) + rng: RandomGenerator + An object that can produce be used to produce the step method's + :py:class:`~numpy.random.Generator` object. Refer to + :py:func:`pymc.util.get_random_generator` for more information. """ self.population = None self.this_chain = None @@ -155,7 +167,14 @@ def link_population(self, population, chain_index): class GradientSharedStep(ArrayStepShared): def __init__( - self, vars, model=None, blocked=True, dtype=None, logp_dlogp_func=None, **pytensor_kwargs + self, + vars, + model=None, + blocked=True, + dtype=None, + logp_dlogp_func=None, + rng: RandomGenerator = None, + **pytensor_kwargs, ): model = modelcontext(model) @@ -166,14 +185,16 @@ def __init__( self._logp_dlogp_func = func - super().__init__(vars, func._extra_vars_shared, blocked) + super().__init__(vars, func._extra_vars_shared, blocked, rng=rng) def step(self, point) -> tuple[PointType, StatsType]: self._logp_dlogp_func._extra_are_set = True return super().step(point) -def metrop_select(mr: np.ndarray, q: np.ndarray, q0: np.ndarray) -> tuple[np.ndarray, bool]: +def metrop_select( + mr: np.ndarray, q: np.ndarray, q0: np.ndarray, rng: np.random.Generator +) -> tuple[np.ndarray, bool]: """Perform rejection/acceptance step for Metropolis class samplers. 
Returns the new sample q if a uniform random number is less than the
     exponentiated Metropolis acceptance rate (`mr`), and the current sample
     otherwise.
 
     Parameters
     ----------
     mr: float, Metropolis acceptance rate
     q: proposed sample
     q0: current sample
+    rng: numpy.random.Generator
+        A random number generator object
 
     Returns
     -------
     q or q0
     """
     # Compare acceptance ratio to uniform random number
     # TODO XXX: This `uniform` is not given a model-specific RNG state, which
     # means that sampler runs that use it will not be reproducible.
-    if np.isfinite(mr) and np.log(uniform()) < mr:
+    if np.isfinite(mr) and np.log(rng.uniform()) < mr:
         return q, True
     else:
         return q0, False
diff --git a/pymc/step_methods/compound.py b/pymc/step_methods/compound.py
index 7c0d8563ca..1c1d6fbb50 100644
--- a/pymc/step_methods/compound.py
+++ b/pymc/step_methods/compound.py
@@ -31,6 +31,7 @@
 
 from pymc.blocking import PointType, StatDtype, StatsDict, StatShape, StatsType
 from pymc.model import modelcontext
+from pymc.util import get_random_generator
 
 __all__ = ("Competence", "CompoundStep")
 
@@ -143,15 +144,18 @@ def __new__(cls, *args, **kwargs):
             # In this case we create a separate sampler for each var
             # and append them to a CompoundStep
             steps = []
-            for var in vars:
+            rngs = get_random_generator(kwargs.pop("rng", None)).spawn(len(vars))
+            for var, rng in zip(vars, rngs):
                 step = super().__new__(cls)
                 step.stats_dtypes = stats_dtypes
                 step.stats_dtypes_shapes = stats_dtypes_shapes
                 # If we don't return the instance we have to manually
                 # call __init__
-                step.__init__([var], *args, **kwargs)
+                _kwargs = kwargs.copy()
+                _kwargs["rng"] = rng
+                step.__init__([var], *args, **_kwargs)
                 # Hack for creating the class correctly when unpickling.
-                step.__newargs = ([var], *args), kwargs
+                step.__newargs = ([var], *args), _kwargs
                 steps.append(step)
 
             return CompoundStep(steps)
diff --git a/pymc/step_methods/hmc/base_hmc.py b/pymc/step_methods/hmc/base_hmc.py
index def6829d26..b320ed8194 100644
--- a/pymc/step_methods/hmc/base_hmc.py
+++ b/pymc/step_methods/hmc/base_hmc.py
@@ -29,6 +29,7 @@
 from pymc.stats.convergence import SamplerWarning, WarningType
 from pymc.step_methods import step_sizes
 from pymc.step_methods.arraystep import GradientSharedStep
+from pymc.step_methods.compound import StepMethodState
 from pymc.step_methods.hmc import integration
 from pymc.step_methods.hmc.integration import IntegrationError, State
 from pymc.step_methods.hmc.quadpotential import QuadPotentialDiagAdapt, quad_potential
@@ -75,6 +76,7 @@ def __init__(
         t0=10,
         adapt_step_size=True,
         step_rand=None,
+        rng=None,
         **pytensor_kwargs,
     ):
         """Set up Hamiltonian samplers with common structures.
@@ -98,6 +100,14 @@ def __init__(
         potential: Potential, optional
             An object that represents the Hamiltonian with methods `velocity`,
             `energy`, and `random` methods.
+        rng: RandomGenerator
+            An object that can be used to produce the step method's
+            :py:class:`~numpy.random.Generator` object. Refer to
+            :py:func:`pymc.util.get_random_generator` for more information. The
+            resulting ``Generator`` object will be stored in the step method
+            and used for accept/reject random selections. The step's ``Generator``
+            will also be used to spawn independent ``Generators`` that will be used
+            by the ``potential`` attribute.
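
NOTE (editorial, illustration only — not part of the patch): `metrop_select` above is
the accept/reject kernel that every Metropolis-family stepper now routes through its
own Generator. A self-contained sketch of the same rule (illustrative names):

    import numpy as np

    def metrop_select_sketch(mr, q, q0, rng):
        # Accept the proposal q when log(U) < mr for U ~ Uniform(0, 1);
        # draws come from the caller-supplied Generator, not np.random.
        if np.isfinite(mr) and np.log(rng.uniform()) < mr:
            return q, True   # proposal accepted
        return q0, False     # proposal rejected, keep current sample

    rng = np.random.default_rng(123)
    print(metrop_select_sketch(-0.5, 1.0, 0.0, rng))
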
         **pytensor_kwargs: passed to PyTensor functions
         """
         self._model = modelcontext(model)
 
         if vars is None:
             vars = self._model.continuous_value_vars
         else:
             vars = get_value_vars_from_user_vars(vars, self._model)
-        super().__init__(vars, blocked=blocked, model=self._model, dtype=dtype, **pytensor_kwargs)
+        super().__init__(
+            vars, blocked=blocked, model=self._model, dtype=dtype, rng=rng, **pytensor_kwargs
+        )
 
         self.adapt_step_size = adapt_step_size
         self.Emax = Emax
@@ -131,7 +143,7 @@ def __init__(
         if scaling is None and potential is None:
             mean = floatX(np.zeros(size))
             var = floatX(np.ones(size))
-            potential = QuadPotentialDiagAdapt(size, mean, var, 10)
+            potential = QuadPotentialDiagAdapt(size, mean, var, 10, rng=self.rng.spawn(1)[0])
 
         if isinstance(scaling, dict):
             point = Point(scaling, model=self._model)
@@ -143,7 +155,7 @@ def __init__(
         if potential is not None:
             self.potential = potential
         else:
-            self.potential = quad_potential(scaling, is_cov)
+            self.potential = quad_potential(scaling, is_cov, rng=self.rng.spawn(1)[0])
 
         self.integrator = integration.CpuLeapfrogIntegrator(self.potential, self._logp_dlogp_func)
 
@@ -193,7 +205,7 @@ def astep(self, q0: RaveledVars) -> tuple[RaveledVars, StatsType]:
             self.step_size = step_size
 
         if self._step_rand is not None:
-            step_size = self._step_rand(step_size)
+            step_size = self._step_rand(step_size, rng=self.rng)
 
         hmc_step = self._hamiltonian_step(start, p0.data, step_size)
 
diff --git a/pymc/step_methods/hmc/hmc.py b/pymc/step_methods/hmc/hmc.py
index 3c43509883..106faee501 100644
--- a/pymc/step_methods/hmc/hmc.py
+++ b/pymc/step_methods/hmc/hmc.py
@@ -27,8 +27,8 @@
 __all__ = ["HamiltonianMC"]
 
 
-def unif(step_size, elow=0.85, ehigh=1.15):
-    return np.random.uniform(elow, ehigh) * step_size
+def unif(step_size, elow=0.85, ehigh=1.15, rng: np.random.Generator | None = None):
+    return (rng or np.random).uniform(elow, ehigh) * step_size
 
 
 class HamiltonianMC(BaseHMC):
@@ -113,6 +113,14 @@ def __init__(self, vars=None, path_length=2.0, max_steps=1024, **kwargs):
             The maximum number of leapfrog steps.
         model: pymc.Model
             The model
+        rng : RandomGenerator
+            An object that can be used to produce the step method's
+            :py:class:`~numpy.random.Generator` object. Refer to
+            :py:func:`pymc.util.get_random_generator` for more information. The
+            resulting ``Generator`` object will be stored in the step method
+            and used for accept/reject random selections. The step's ``Generator``
+            will also be used to spawn independent ``Generators`` that will be used
+            by the ``potential`` attribute.
         **kwargs: passed to BaseHMC
         """
         kwargs.setdefault("step_rand", unif)
@@ -151,7 +159,7 @@ def _hamiltonian_step(self, start, p0, step_size: float) -> HMCStepData:
 
         accept_stat = min(1, np.exp(-energy_change))
 
-        if div_info is not None or np.random.rand() >= accept_stat:
+        if div_info is not None or self.rng.random() >= accept_stat:
             end = start
             accepted = False
         else:
diff --git a/pymc/step_methods/hmc/nuts.py b/pymc/step_methods/hmc/nuts.py
index 541303bdf3..3c4b4e6800 100644
--- a/pymc/step_methods/hmc/nuts.py
+++ b/pymc/step_methods/hmc/nuts.py
@@ -169,6 +169,14 @@ def __init__(self, vars=None, max_treedepth=10, early_max_treedepth=8, **kwargs)
             of the scaling matrix.
         model: pymc.Model
             The model
+        rng : RandomGenerator
+            An object that can be used to produce the step method's
+            :py:class:`~numpy.random.Generator` object. Refer to
+            :py:func:`pymc.util.get_random_generator` for more information. 
The
+            resulting ``Generator`` object will be stored in the step method
+            and used for accept/reject random selections. The step's ``Generator``
+            will also be used to spawn independent ``Generators`` that will be used
+            by the ``potential`` attribute.
         kwargs: passed to BaseHMC
 
     Notes
     -----
@@ -189,11 +197,11 @@ def _hamiltonian_step(self, start, p0, step_size):
         else:
             max_treedepth = self.max_treedepth
 
-        tree = _Tree(len(p0), self.integrator, start, step_size, self.Emax)
+        tree = _Tree(len(p0), self.integrator, start, step_size, self.Emax, rng=self.rng)
 
         reached_max_treedepth = False
         for _ in range(max_treedepth):
-            direction = logbern(np.log(0.5)) * 2 - 1
+            direction = logbern(np.log(0.5), rng=self.rng) * 2 - 1
             divergence_info, turning = tree.extend(direction)
 
             if divergence_info or turning:
@@ -233,6 +241,7 @@ def __init__(
         start: State,
         step_size: float,
         Emax: float,
+        rng: np.random.Generator,
     ):
         """Binary tree from the NUTS algorithm.
 
@@ -254,6 +263,7 @@ def __init__(
         self.step_size = step_size
         self.Emax = Emax
         self.start_energy = start.energy
+        self.rng = rng
 
         self.left = self.right = start
         self.proposal = Proposal(start.q.data, start.q_grad, start.energy, start.model_logp, 0)
@@ -302,7 +312,7 @@ def extend(self, direction):
             return diverging, turning
 
         size1, size2 = self.log_size, tree.log_size
-        if logbern(size2 - size1):
+        if logbern(size2 - size1, rng=self.rng):
             self.proposal = tree.proposal
 
         self.log_size = np.logaddexp(self.log_size, tree.log_size)
@@ -390,7 +400,7 @@ def _build_subtree(self, left, depth, epsilon):
         turning = turning | turning1 | turning2
 
         log_size = np.logaddexp(tree1.log_size, tree2.log_size)
-        if logbern(tree2.log_size - log_size):
+        if logbern(tree2.log_size - log_size, rng=self.rng):
             proposal = tree2.proposal
         else:
             proposal = tree1.proposal
diff --git a/pymc/step_methods/hmc/quadpotential.py b/pymc/step_methods/hmc/quadpotential.py
index 4f975ff95c..abddaaf35f 100644
--- a/pymc/step_methods/hmc/quadpotential.py
+++ b/pymc/step_methods/hmc/quadpotential.py
@@ -22,7 +22,6 @@
 import pytensor
 import scipy.linalg
 
-from numpy.random import normal
 from scipy.sparse import issparse
 
 from pymc.pytensorf import floatX
@@ -38,7 +37,7 @@
 ]
 
 
-def quad_potential(C, is_cov):
+def quad_potential(C, is_cov, rng=None):
     """
     Compute a QuadPotential object from a scaling matrix.
 
@@ -49,6 +48,10 @@ def quad_potential(C, is_cov):
         vector treated as diagonal matrix.
     is_cov: Boolean
         whether C is provided as a covariance matrix or hessian
+    rng: RandomGenerator
+        An object that can be used to produce the step method's
+        :py:class:`~numpy.random.Generator` object. Refer to
+        :py:func:`pymc.util.get_random_generator` for more information.
 
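NOTE (editorial, illustration only — not part of the patch): the NUTS tree above
draws its doubling direction via `logbern`, the log-space Bernoulli patched in
pymc/math.py. A sketch of that primitive (illustrative names):

    import numpy as np

    def logbern_sketch(log_p, rng):
        # True with probability exp(log_p); working in log space avoids
        # underflow for very small probabilities.
        if np.isnan(log_p):
            raise FloatingPointError("log_p can't be nan.")
        return np.log(rng.uniform()) < log_p

    rng = np.random.default_rng(0)
    direction = int(logbern_sketch(np.log(0.5), rng)) * 2 - 1  # +1 or -1
    print(direction)
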
     Returns
     -------
     q: QuadPotential
     """
     if issparse(C):
         if not chol_available:
             raise ImportError("Sparse mass matrices require scikits.sparse")
         elif is_cov:
-            return QuadPotentialSparse(C)
+            return QuadPotentialSparse(C, rng=rng)
         else:
             raise ValueError("Sparse precision matrices are not supported")
 
     partial_check_positive_definite(C)
     if C.ndim == 1:
         if is_cov:
-            return QuadPotentialDiag(C)
+            return QuadPotentialDiag(C, rng=rng)
         else:
-            return QuadPotentialDiag(1.0 / C)
+            return QuadPotentialDiag(1.0 / C, rng=rng)
     else:
         if is_cov:
-            return QuadPotentialFull(C)
+            return QuadPotentialFull(C, rng=rng)
         else:
-            return QuadPotentialFullInv(C)
+            return QuadPotentialFullInv(C, rng=rng)
 
 
 def partial_check_positive_definite(C):
@@ -100,6 +103,9 @@ def __str__(self):
 class QuadPotential:
     dtype: np.dtype
 
+    def __init__(self, rng=None):
+        self.rng = np.random.default_rng(rng)
+
     @overload
     def velocity(self, x: np.ndarray, out: None) -> np.ndarray: ...
 
@@ -172,6 +178,7 @@ def __init__(
         discard_window=50,
         early_update=False,
         store_mass_matrix_trace=False,
+        rng=None,
     ):
         """Set up a diagonal mass matrix.
 
@@ -202,6 +209,8 @@ def __init__(
         store_mass_matrix_trace : bool
             If true, store the mass matrix at each step of the adaptation. Only for
             debugging purposes.
+        rng : Generator | int | None
+            Numpy random number generator
         """
         if initial_diag is not None and initial_diag.ndim != 1:
             raise ValueError("Initial diagonal must be one-dimensional.")
@@ -234,6 +243,8 @@ def __init__(
         self._store_mass_matrix_trace = store_mass_matrix_trace
         self._mass_trace = []
 
+        super().__init__(rng=rng)
+
         self.reset()
 
     def reset(self):
@@ -264,7 +275,7 @@ def velocity_energy(self, x, v_out):
 
     def random(self):
         """Draw random value from QuadPotential."""
-        vals = normal(size=self._n).astype(self.dtype)
+        vals = self.rng.normal(size=self._n).astype(self.dtype)
         return self._inv_stds * vals
 
     def _update_from_weightvar(self, weightvar):
@@ -405,7 +416,7 @@ def current_mean(self, out=None):
 
 
 class QuadPotentialDiagAdaptExp(QuadPotentialDiagAdapt):
-    def __init__(self, *args, alpha, use_grads=False, stop_adaptation=None, **kwargs):
+    def __init__(self, *args, alpha, use_grads=False, stop_adaptation=None, rng=None, **kwargs):
         """Set up a diagonal mass matrix.
 
         Parameters
@@ -430,11 +441,15 @@ def __init__(self, *args, alpha, use_grads=False, stop_adaptation=None, **kwargs
         store_mass_matrix_trace : bool
             If true, store the mass matrix at each step of the adaptation. Only for
             debugging purposes.
+        rng: RandomGenerator
+            An object that can be used to produce the step method's
+            :py:class:`~numpy.random.Generator` object. Refer to
+            :py:func:`pymc.util.get_random_generator` for more information.
         """
         if len(args) > 3:
             raise ValueError("Unsupported arguments to QuadPotentialDiagAdaptExp")
 
-        super().__init__(*args, **kwargs)
+        super().__init__(*args, rng=rng, **kwargs)
         self._alpha = alpha
         self._use_grads = use_grads
 
@@ -488,13 +503,19 @@ def _update_from_variances(self, var_estimator, inv_var_estimator):
 class QuadPotentialDiag(QuadPotential):
     """Quad potential using a diagonal covariance matrix."""
 
-    def __init__(self, v, dtype=None):
+    def __init__(self, v, dtype=None, rng=None):
         """Use a vector to represent a diagonal matrix for a covariance matrix.
 
         
Parameters
         ----------
         v: vector, 0 <= ndim <= 1
            Diagonal of covariance matrix for the potential vector
+        dtype :
+            The dtype to assign to the resulting momentum
+        rng : RandomGenerator
+            An object that can be used to produce the step method's
+            :py:class:`~numpy.random.Generator` object. Refer to
+            :py:func:`pymc.util.get_random_generator` for more information.
         """
         if dtype is None:
             dtype = pytensor.config.floatX
@@ -505,6 +526,7 @@ def __init__(self, v, dtype=None):
         self.s = s
         self.inv_s = 1.0 / s
         self.v = v
+        self.rng = np.random.default_rng(rng)
 
     def velocity(self, x, out=None):
         """Compute the current velocity at a position in parameter space."""
@@ -515,7 +537,7 @@ def velocity(self, x, out=None):
 
     def random(self):
         """Draw random value from QuadPotential."""
-        return floatX(normal(size=self.s.shape)) * self.inv_s
+        return floatX(self.rng.normal(size=self.s.shape)) * self.inv_s
 
     def energy(self, x, velocity=None):
         """Compute kinetic energy at a position in parameter space."""
@@ -532,18 +554,25 @@ def velocity_energy(self, x, v_out):
 class QuadPotentialFullInv(QuadPotential):
     """QuadPotential object for Hamiltonian calculations using inverse of covariance matrix."""
 
-    def __init__(self, A, dtype=None):
+    def __init__(self, A, dtype=None, rng=None):
         """Compute the lower cholesky decomposition of the potential.
 
         Parameters
         ----------
         A: matrix, ndim = 2
             Inverse of covariance matrix for the potential vector
+        dtype :
+            The dtype to assign to the resulting momentum
+        rng : RandomGenerator
+            An object that can be used to produce the step method's
+            :py:class:`~numpy.random.Generator` object. Refer to
+            :py:func:`pymc.util.get_random_generator` for more information.
         """
         if dtype is None:
             dtype = pytensor.config.floatX
         self.dtype = dtype
         self.L = floatX(scipy.linalg.cholesky(A, lower=True))
+        self.rng = np.random.default_rng(rng)
 
     def velocity(self, x, out=None):
         """Compute the current velocity at a position in parameter space."""
@@ -554,7 +583,7 @@ def velocity(self, x, out=None):
 
     def random(self):
         """Draw random value from QuadPotential."""
-        n = floatX(normal(size=self.L.shape[0]))
+        n = floatX(self.rng.normal(size=self.L.shape[0]))
         return np.dot(self.L, n)
 
     def energy(self, x, velocity=None):
@@ -572,13 +601,19 @@ def velocity_energy(self, x, v_out):
 class QuadPotentialFull(QuadPotential):
     """Basic QuadPotential object for Hamiltonian calculations."""
 
-    def __init__(self, cov, dtype=None):
+    def __init__(self, cov, dtype=None, rng=None):
        """Compute the lower cholesky decomposition of the potential.
 
         Parameters
         ----------
         A: matrix, ndim = 2
             scaling matrix for the potential vector
+        dtype :
+            The dtype to assign to the resulting momentum
+        rng : RandomGenerator
+            An object that can be used to produce the step method's
+            :py:class:`~numpy.random.Generator` object. Refer to
+            :py:func:`pymc.util.get_random_generator` for more information.
         
""" if dtype is None: dtype = pytensor.config.floatX @@ -586,6 +621,7 @@ def __init__(self, cov, dtype=None): self._cov = np.array(cov, dtype=self.dtype, copy=True) self._chol = scipy.linalg.cholesky(self._cov, lower=True) self._n = len(self._cov) + self.rng = np.random.default_rng(rng) def velocity(self, x, out=None): """Compute the current velocity at a position in parameter space.""" @@ -593,7 +629,7 @@ def velocity(self, x, out=None): def random(self): """Draw random value from QuadPotential.""" - vals = np.random.normal(size=self._n).astype(self.dtype) + vals = self.rng.normal(size=self._n).astype(self.dtype) return scipy.linalg.solve_triangular(self._chol.T, vals, overwrite_b=True) def energy(self, x, velocity=None): @@ -623,6 +659,7 @@ def __init__( adaptation_window_multiplier=2, update_window=1, dtype=None, + rng=None, ): warnings.warn("QuadPotentialFullAdapt is an experimental feature") @@ -652,6 +689,8 @@ def __init__( self.adaptation_window_multiplier = float(adaptation_window_multiplier) self._update_window = int(update_window) + self.rng = np.random.default_rng(rng) + self.reset() def reset(self): @@ -772,18 +811,23 @@ def current_mean(self): import pytensor.sparse class QuadPotentialSparse(QuadPotential): - def __init__(self, A): + def __init__(self, A, rng=None): """Compute a sparse cholesky decomposition of the potential. Parameters ---------- A: matrix, ndim = 2 scaling matrix for the potential vector + rng : RandomGenerator + An object that can produce be used to produce the step method's + :py:class:`~numpy.random.Generator` object. Refer to + :py:func:`pymc.util.get_random_generator` for more information. """ self.A = A self.size = A.shape[0] self.factor = factor = cholmod.cholesky(A) self.d_sqrt = np.sqrt(factor.D()) + self.rng = np.random.default_rng(rng) def velocity(self, x): """Compute the current velocity at a position in parameter space.""" @@ -792,7 +836,7 @@ def velocity(self, x): def random(self): """Draw random value from QuadPotential.""" - n = floatX(normal(size=self.size)) + n = floatX(self.rng.normal(size=self.size)) n /= self.d_sqrt n = self.factor.solve_Lt(n) n = self.factor.apply_Pt(n) diff --git a/pymc/step_methods/metropolis.py b/pymc/step_methods/metropolis.py index d752999ec1..aa5101dbb0 100644 --- a/pymc/step_methods/metropolis.py +++ b/pymc/step_methods/metropolis.py @@ -116,7 +116,6 @@ class Metropolis(ArrayStepShared): name = "metropolis" - default_blocked = False stats_dtypes_shapes = { "accept": (np.float64, []), "accepted": (np.float64, []), @@ -134,6 +133,7 @@ def __init__( tune_interval=100, model=None, mode=None, + rng=None, **kwargs, ): """Create an instance of a Metropolis stepper @@ -157,6 +157,10 @@ def __init__( Optional model for sampling step. Defaults to None (taken from context). mode: string or `Mode` instance. compilation mode passed to PyTensor functions + rng: RandomGenerator + An object that can produce be used to produce the step method's + :py:class:`~numpy.random.Generator` object. Refer to + :py:func:`pymc.util.get_random_generator` for more information. 
""" model = pm.modelcontext(model) @@ -223,7 +227,7 @@ def __init__( shared = pm.make_shared_replacements(initial_values, vars, model) self.delta_logp = delta_logp(initial_values, model.logp(), vars, shared) - super().__init__(vars, shared) + super().__init__(vars, shared, rng=rng) def reset_tuning(self): """Resets the tuned sampler parameters to their initial values.""" @@ -243,7 +247,7 @@ def astep(self, q0: RaveledVars) -> tuple[RaveledVars, StatsType]: self.steps_until_tune = self.tune_interval self.accepted_sum[:] = 0 - delta = self.proposal_dist() * self.scaling + delta = self.proposal_dist(rng=self.rng) * self.scaling if self.any_discrete: if self.all_discrete: @@ -260,11 +264,11 @@ def astep(self, q0: RaveledVars) -> tuple[RaveledVars, StatsType]: q0d = q0d.copy() q_temp = q0d.copy() # Shuffle order of updates (probably we don't need to do this in every step) - np.random.shuffle(self.enum_dims) + self.rng.shuffle(self.enum_dims) for i in self.enum_dims: q_temp[i] = q[i] accept_rate_i = self.delta_logp(q_temp, q0d) - q_temp_, accepted_i = metrop_select(accept_rate_i, q_temp, q0d) + q_temp_, accepted_i = metrop_select(accept_rate_i, q_temp, q0d, rng=self.rng) q_temp[i] = q0d[i] = q_temp_[i] self.accept_rate_iter[i] = accept_rate_i self.accepted_iter[i] = accepted_i @@ -272,7 +276,7 @@ def astep(self, q0: RaveledVars) -> tuple[RaveledVars, StatsType]: q = q_temp else: accept_rate = self.delta_logp(q, q0d) - q, accepted = metrop_select(accept_rate, q, q0d) + q, accepted = metrop_select(accept_rate, q, q0d, rng=self.rng) self.accept_rate_iter = accept_rate self.accepted_iter = accepted self.accepted_sum += accepted @@ -357,7 +361,10 @@ class BinaryMetropolis(ArrayStep): The frequency of tuning. Defaults to 100 iterations. model: PyMC Model Optional model for sampling step. Defaults to None (taken from context). - + rng: RandomGenerator + An object that can produce be used to produce the step method's + :py:class:`~numpy.random.Generator` object. Refer to + :py:func:`pymc.util.get_random_generator` for more information. """ name = "binary_metropolis" @@ -393,7 +400,7 @@ def astep(self, apoint: RaveledVars, *args) -> tuple[RaveledVars, StatsType]: # Convert adaptive_scale_factor to a jump probability p_jump = 1.0 - 0.5**self.scaling - rand_array = nr.random(q0.shape) + rand_array = self.rng.random(q0.shape) q = np.copy(q0) # Locations where switches occur, according to p_jump switch_locs = rand_array < p_jump @@ -401,7 +408,7 @@ def astep(self, apoint: RaveledVars, *args) -> tuple[RaveledVars, StatsType]: logp_q = logp(RaveledVars(q, point_map_info)) accept = logp_q - logp_q0 - q_new, accepted = metrop_select(accept, q, q0) + q_new, accepted = metrop_select(accept, q, q0, rng=self.rng) self.accepted += accepted stats = { @@ -453,7 +460,10 @@ class BinaryGibbsMetropolis(ArrayStep): which resulting in more efficient antithetical sampling. Default is 0.8 model: PyMC Model Optional model for sampling step. Defaults to None (taken from context). - + rng: RandomGenerator + An object that can produce be used to produce the step method's + :py:class:`~numpy.random.Generator` object. Refer to + :py:func:`pymc.util.get_random_generator` for more information. 
""" name = "binary_gibbs_metropolis" @@ -498,7 +508,7 @@ def astep(self, apoint: RaveledVars, *args) -> tuple[RaveledVars, StatsType]: logp: Callable[[RaveledVars], np.ndarray] = args[0] order = self.order if self.shuffle_dims: - nr.shuffle(order) + self.rng.shuffle(order) q = RaveledVars(np.copy(apoint.data), apoint.point_map_info) @@ -507,10 +517,12 @@ def astep(self, apoint: RaveledVars, *args) -> tuple[RaveledVars, StatsType]: for idx in order: # No need to do metropolis update if the same value is proposed, # as you will get the same value regardless of accepted or reject - if nr.rand() < self.transit_p: + if self.rng.random() < self.transit_p: curr_val, q.data[idx] = q.data[idx], True - q.data[idx] logp_prop = logp(q) - q.data[idx], accepted = metrop_select(logp_prop - logp_curr, q.data[idx], curr_val) + q.data[idx], accepted = metrop_select( + logp_prop - logp_curr, q.data[idx], curr_val, rng=self.rng + ) if accepted: logp_curr = logp_prop @@ -561,7 +573,7 @@ class CategoricalGibbsMetropolis(ArrayStep): "tune": (bool, []), } - def __init__(self, vars, proposal="uniform", order="random", model=None): + def __init__(self, vars, proposal="uniform", order="random", model=None, rng=None): model = pm.modelcontext(model) vars = get_value_vars_from_user_vars(vars, model) @@ -615,7 +627,7 @@ def __init__(self, vars, proposal="uniform", order="random", model=None): # that indicates whether a draw was done in a tuning phase. self.tune = True - super().__init__(vars, [model.compile_logp()]) + super().__init__(vars, [model.compile_logp()], rng=rng) def reset_tuning(self): # There are no tuning parameters in this step method. @@ -628,15 +640,17 @@ def astep_unif(self, apoint: RaveledVars, *args) -> tuple[RaveledVars, StatsType dimcats = self.dimcats if self.shuffle_dims: - nr.shuffle(dimcats) + self.rng.shuffle(dimcats) q = RaveledVars(np.copy(q0), point_map_info) logp_curr = logp(q) for dim, k in dimcats: - curr_val, q.data[dim] = q.data[dim], sample_except(k, q.data[dim]) + curr_val, q.data[dim] = q.data[dim], sample_except(k, q.data[dim], rng=self.rng) logp_prop = logp(q) - q.data[dim], accepted = metrop_select(logp_prop - logp_curr, q.data[dim], curr_val) + q.data[dim], accepted = metrop_select( + logp_prop - logp_curr, q.data[dim], curr_val, rng=self.rng + ) if accepted: logp_curr = logp_prop @@ -652,7 +666,7 @@ def astep_prop(self, apoint: RaveledVars, *args) -> tuple[RaveledVars, StatsType dimcats = self.dimcats if self.shuffle_dims: - nr.shuffle(dimcats) + self.rng.shuffle(dimcats) q = RaveledVars(np.copy(q0), point_map_info) logp_curr = logp(q) @@ -677,9 +691,9 @@ def metropolis_proportional(self, q, logp, logp_curr, dim, k): probs = scipy.special.softmax(log_probs, axis=0) prob_curr, probs[given_cat] = probs[given_cat], 0.0 probs /= 1.0 - prob_curr - proposed_cat = nr.choice(candidates, p=probs) + proposed_cat = self.rng.choice(candidates, p=probs) accept_ratio = (1.0 - prob_curr) / (1.0 - probs[proposed_cat]) - if not np.isfinite(accept_ratio) or nr.uniform() >= accept_ratio: + if not np.isfinite(accept_ratio) or self.rng.uniform() >= accept_ratio: q.data[dim] = given_cat return logp_curr q.data[dim] = proposed_cat @@ -739,6 +753,10 @@ class DEMetropolis(PopulationArrayStepShared): Optional model for sampling step. Defaults to None (taken from context). mode: string or `Mode` instance. compilation mode passed to PyTensor functions + rng: RandomGenerator + An object that can produce be used to produce the step method's + :py:class:`~numpy.random.Generator` object. 
Refer to
+        :py:func:`pymc.util.get_random_generator` for more information.
 
     References
     ----------
@@ -821,7 +839,7 @@ def astep(self, q0: RaveledVars) -> tuple[RaveledVars, StatsType]:
             self.steps_until_tune = self.tune_interval
             self.accepted = 0
 
-        epsilon = self.proposal_dist() * self.scaling
+        epsilon = self.proposal_dist(rng=self.rng) * self.scaling
 
         # differential evolution proposal
         # select two other chains
@@ -832,7 +850,7 @@
         q = floatX(q0d + self.lamb * (r1.data - r2.data) + epsilon)
 
         accept = self.delta_logp(q, q0d)
-        q_new, accepted = metrop_select(accept, q, q0d)
+        q_new, accepted = metrop_select(accept, q, q0d, rng=self.rng)
         self.accepted += accepted
 
         self.steps_until_tune -= 1
@@ -883,6 +901,10 @@ class DEMetropolisZ(ArrayStepShared):
         Optional model for sampling step. Defaults to None (taken from context).
     mode: string or `Mode` instance.
         compilation mode passed to PyTensor functions
+    rng: RandomGenerator
+        An object that can be used to produce the step method's
+        :py:class:`~numpy.random.Generator` object. Refer to
+        :py:func:`pymc.util.get_random_generator` for more information.
 
     References
     ----------
@@ -986,17 +1008,17 @@ def astep(self, q0: RaveledVars) -> tuple[RaveledVars, StatsType]:
             self.steps_until_tune = self.tune_interval
             self.accepted = 0
 
-        epsilon = self.proposal_dist() * self.scaling
+        epsilon = self.proposal_dist(rng=self.rng) * self.scaling
 
         it = len(self._history)
         # use the DE-MCMC-Z proposal scheme as soon as the history has 2 entries
         if it > 1:
             # differential evolution proposal
             # select two other chains
-            iz1 = np.random.randint(it)
-            iz2 = np.random.randint(it)
+            iz1 = self.rng.integers(it)
+            iz2 = self.rng.integers(it)
             while iz2 == iz1:
-                iz2 = np.random.randint(it)
+                iz2 = self.rng.integers(it)
 
             z1 = self._history[iz1]
             z2 = self._history[iz2]
@@ -1007,7 +1029,7 @@ def astep(self, q0: RaveledVars) -> tuple[RaveledVars, StatsType]:
             q = floatX(q0d + epsilon)
 
         accept = self.delta_logp(q, q0d)
-        q_new, accepted = metrop_select(accept, q, q0d)
+        q_new, accepted = metrop_select(accept, q, q0d, rng=self.rng)
         self.accepted += accepted
 
         self._history.append(q_new)
@@ -1039,8 +1061,8 @@ def competence(var, has_grad):
         return Competence.COMPATIBLE
 
 
-def sample_except(limit, excluded):
-    candidate = nr.choice(limit - 1)
+def sample_except(limit, excluded, rng: np.random.Generator):
+    candidate = rng.choice(limit - 1)
     if candidate >= excluded:
         candidate += 1
     return candidate
diff --git a/pymc/step_methods/slicer.py b/pymc/step_methods/slicer.py
index 3a9d90800a..3e096aeb9f 100644
--- a/pymc/step_methods/slicer.py
+++ b/pymc/step_methods/slicer.py
@@ -16,7 +16,6 @@
 
 import numpy as np
-import numpy.random as nr
 
 from pymc.blocking import RaveledVars, StatsType
 from pymc.model import modelcontext
@@ -47,6 +46,10 @@ class Slice(ArrayStepShared):
         Optional model for sampling step. It will be taken from the context if not provided.
     iter_limit : int, default np.inf
         Maximum number of iterations for the slice sampler.
+    rng: RandomGenerator
+        An object that can be used to produce the step method's
+        :py:class:`~numpy.random.Generator` object. Refer to
+        :py:func:`pymc.util.get_random_generator` for more information.
     
""" @@ -58,7 +61,9 @@ class Slice(ArrayStepShared): "nstep_in": (int, []), } - def __init__(self, vars=None, w=1.0, tune=True, model=None, iter_limit=np.inf, **kwargs): + def __init__( + self, vars=None, w=1.0, tune=True, model=None, iter_limit=np.inf, rng=None, **kwargs + ): model = modelcontext(model) self.w = np.asarray(w).copy() self.tune = tune @@ -78,7 +83,7 @@ def __init__(self, vars=None, w=1.0, tune=True, model=None, iter_limit=np.inf, * self.logp = compile_pymc([raveled_inp], logp) self.logp.trust_input = True - super().__init__(vars, shared) + super().__init__(vars, shared, rng=rng) def astep(self, apoint: RaveledVars) -> tuple[RaveledVars, StatsType]: # The arguments are determined by the list passed via `super().__init__(..., fs, ...)` @@ -96,10 +101,10 @@ def astep(self, apoint: RaveledVars) -> tuple[RaveledVars, StatsType]: logp = self.logp for i, wi in enumerate(self.w): # uniformly sample from 0 to p(q), but in log space - y = logp(q) - nr.standard_exponential() + y = logp(q) - self.rng.standard_exponential() # Create initial interval - ql[i] = q[i] - nr.uniform() * wi # q[i] + r * w + ql[i] = q[i] - self.rng.uniform() * wi # q[i] + r * w qr[i] = ql[i] + wi # Equivalent to q[i] + (1-r) * w # Stepping out procedure @@ -120,14 +125,14 @@ def astep(self, apoint: RaveledVars) -> tuple[RaveledVars, StatsType]: nstep_out += cnt cnt = 0 - q[i] = nr.uniform(ql[i], qr[i]) + q[i] = self.rng.uniform(ql[i], qr[i]) while y > logp(q): # Changed leq to lt, to accommodate for locally flat posteriors # Sample uniformly from slice if q[i] > q0_val[i]: qr[i] = q[i] elif q[i] < q0_val[i]: ql[i] = q[i] - q[i] = nr.uniform(ql[i], qr[i]) + q[i] = self.rng.uniform(ql[i], qr[i]) cnt += 1 if cnt > self.iter_limit: raise RuntimeError(LOOP_ERR_MSG % self.iter_limit) diff --git a/pymc/util.py b/pymc/util.py index fe55813385..7733d41b60 100644 --- a/pymc/util.py +++ b/pymc/util.py @@ -16,6 +16,7 @@ import warnings from collections.abc import Sequence +from copy import deepcopy from typing import NewType, cast import arviz @@ -399,6 +400,7 @@ def wrapped(**kwargs): RandomSeed = None | int | Sequence[int] | np.ndarray RandomState = RandomSeed | np.random.RandomState | np.random.Generator +RandomGenerator = RandomSeed | np.random.Generator | np.random.BitGenerator def _get_seeds_per_chain( @@ -431,10 +433,15 @@ def _get_unique_seeds_per_chain(integers_fn): seeds = [int(seed) for seed in integers_fn(2**30, dtype=np.int64, size=chains)] return seeds - if random_state is None or isinstance(random_state, int): - if chains == 1 and isinstance(random_state, int): - return (random_state,) - return _get_unique_seeds_per_chain(np.random.default_rng(random_state).integers) + try: + int_random_state = int(random_state) # type: ignore + except Exception: + int_random_state = None + + if random_state is None or int_random_state is not None: + if chains == 1 and int_random_state is not None: + return (int_random_state,) + return _get_unique_seeds_per_chain(np.random.default_rng(int_random_state).integers) if isinstance(random_state, np.random.Generator): return _get_unique_seeds_per_chain(random_state.integers) if isinstance(random_state, np.random.RandomState): @@ -578,3 +585,52 @@ def update( **fields, ) return None + + +def get_random_generator( + seed: RandomGenerator | np.random.RandomState = None, copy: bool = True +) -> np.random.Generator: + """Build a :py:class:`~numpy.random.Generator` object from a suitable seed. 
+
+    Parameters
+    ----------
+    seed : None | int | Sequence[int] | numpy.random.Generator | numpy.random.BitGenerator | numpy.random.RandomState
+        A suitable seed to use to generate the :py:class:`~numpy.random.Generator` object.
+        For more details on suitable seeds, refer to :py:func:`numpy.random.default_rng`.
+    copy : bool
+        Boolean flag that indicates whether to copy the seed object before feeding
+        it to :py:func:`numpy.random.default_rng`. If `copy` is `False`, and the seed
+        object is a ``BitGenerator`` or ``Generator`` object, the returned
+        ``Generator`` will use the ``seed`` object where possible. This means that it
+        will return the ``seed`` input object if it is a ``Generator`` or that it
+        will return a new ``Generator`` whose ``bit_generator`` attribute will be the
+        input ``seed`` object. To avoid this potential object sharing, you must set
+        ``copy`` to ``True``.
+
+    Returns
+    -------
+    rng : numpy.random.Generator
+        The result of passing the input ``seed`` (or a copy of it) through
+        :py:func:`numpy.random.default_rng`.
+
+    Raises
+    ------
+    TypeError:
+        If the supplied ``seed`` is a :py:class:`~numpy.random.RandomState` object. We
+        do not support using these legacy objects because their seeding strategy is not
+        amenable to spawning new independent random streams.
+    """
+    if isinstance(seed, np.random.RandomState):
+        raise TypeError(
+            "Cannot create a random Generator from a RandomState object. "
+            "Please provide a random seed, BitGenerator or Generator instead."
+        )
+    if copy:
+        # If seed is a numpy.random.Generator or numpy.random.BitGenerator,
+        # numpy.random.default_rng will use the exact same object to return.
+        # In the former case, it will return seed, in the latter it will return
+        # a new Generator object that has the same BitGenerator. This would potentially
+        # make the new generator be shared across many users. To avoid this, we
+        # deepcopy by default.
+        
+ seed = deepcopy(seed) + return np.random.default_rng(seed) diff --git a/tests/sampling/test_forward.py b/tests/sampling/test_forward.py index 24579bae02..dd408e5f86 100644 --- a/tests/sampling/test_forward.py +++ b/tests/sampling/test_forward.py @@ -497,7 +497,7 @@ def test_normal_scalar(self): assert ppc["a"].shape == (nchains, ndraws) # test default case - random_state = np.random.RandomState(20160911) + random_state = np.random.default_rng(20160911) idata_ppc = pm.sample_posterior_predictive( trace, var_names=["a"], random_seed=random_state ) @@ -623,9 +623,9 @@ def test_model_not_drawable_prior(self, seeded_test): assert samples["foo"].shape == (1, 40, 200) def test_model_shared_variable(self): - rng = np.random.RandomState(9832) + rng = np.random.default_rng(9832) - x = rng.randn(100) + x = rng.normal(size=100) y = x > 0 x_shared = pytensor.shared(x) y_shared = pytensor.shared(y) @@ -656,10 +656,10 @@ def test_model_shared_variable(self): npt.assert_allclose(post_pred["p"], expected_p) def test_deterministic_of_observed(self): - rng = np.random.RandomState(8442) + rng = np.random.default_rng(8442) - meas_in_1 = pm.pytensorf.floatX(2 + 4 * rng.randn(10)) - meas_in_2 = pm.pytensorf.floatX(5 + 4 * rng.randn(10)) + meas_in_1 = pm.pytensorf.floatX(2 + 4 * rng.normal(size=10)) + meas_in_2 = pm.pytensorf.floatX(5 + 4 * rng.normal(size=10)) nchains = 2 with pm.Model() as model: mu_in_1 = pm.Normal("mu_in_1", 0, 2) @@ -696,10 +696,10 @@ def test_deterministic_of_observed(self): npt.assert_allclose(ppc["in_1"] + ppc["in_2"], ppc["out"], rtol=rtol) def test_deterministic_of_observed_modified_interface(self): - rng = np.random.RandomState(4982) + rng = np.random.default_rng(4982) - meas_in_1 = pm.pytensorf.floatX(2 + 4 * rng.randn(100)) - meas_in_2 = pm.pytensorf.floatX(5 + 4 * rng.randn(100)) + meas_in_1 = pm.pytensorf.floatX(2 + 4 * rng.normal(size=100)) + meas_in_2 = pm.pytensorf.floatX(5 + 4 * rng.normal(size=100)) with pm.Model() as model: mu_in_1 = pm.Normal("mu_in_1", 0, 1, initval=0) sigma_in_1 = pm.HalfNormal("sd_in_1", 1, initval=1) @@ -1408,7 +1408,7 @@ def test_distinct_rvs(): Y_rv = pm.Normal("y") pp_samples = pm.sample_prior_predictive( - draws=2, return_inferencedata=False, random_seed=npr.RandomState(2023532) + draws=2, return_inferencedata=False, random_seed=npr.default_rng(2023532) ) assert X_rv.owner.inputs[0] != Y_rv.owner.inputs[0] @@ -1418,7 +1418,7 @@ def test_distinct_rvs(): Y_rv = pm.Normal("y") pp_samples_2 = pm.sample_prior_predictive( - draws=2, return_inferencedata=False, random_seed=npr.RandomState(2023532) + draws=2, return_inferencedata=False, random_seed=npr.default_rng(2023532) ) assert np.array_equal(pp_samples["y"], pp_samples_2["y"]) diff --git a/tests/sampling/test_parallel.py b/tests/sampling/test_parallel.py index c69c75fabc..8c71bcac00 100644 --- a/tests/sampling/test_parallel.py +++ b/tests/sampling/test_parallel.py @@ -157,7 +157,7 @@ def test_explicit_sample(mp_start_method): 10, step, chain=3, - seed=1, + rng=np.random.default_rng(1), mp_ctx=ctx, start={"a": floatX(np.array([1.0])), "b_log__": floatX(np.array(2.0))}, step_method_pickled=step_method_pickled, @@ -190,7 +190,7 @@ def test_iterator(): tune=10, chains=3, cores=2, - seeds=[2, 3, 4], + rngs=np.random.default_rng(1).spawn(3), start_points=[start] * 3, step_method=step, progressbar=False, diff --git a/tests/step_methods/hmc/test_nuts.py b/tests/step_methods/hmc/test_nuts.py index 1bec2d2f46..2bb71b893e 100644 --- a/tests/step_methods/hmc/test_nuts.py +++ b/tests/step_methods/hmc/test_nuts.py 
@@ -36,14 +36,15 @@ class TestNUTSUniform(sf.NutsFixture, sf.UniformFixture): min_n_eff = 9000 rtol = 0.1 atol = 0.05 + step_args = {"random_seed": 202010} class TestNUTSUniform2(TestNUTSUniform): - step_args = {"target_accept": 0.95} + step_args = {"target_accept": 0.95, "random_seed": 202010} class TestNUTSUniform3(TestNUTSUniform): - step_args = {"target_accept": 0.80} + step_args = {"target_accept": 0.80, "random_seed": 202010} class TestNUTSNormal(sf.NutsFixture, sf.NormalFixture): @@ -54,6 +55,7 @@ class TestNUTSNormal(sf.NutsFixture, sf.NormalFixture): min_n_eff = 10000 rtol = 0.1 atol = 0.05 + step_args = {"random_seed": 123456} class TestNUTSBetaBinomial(sf.NutsFixture, sf.BetaBinomialFixture): @@ -63,6 +65,7 @@ class TestNUTSBetaBinomial(sf.NutsFixture, sf.BetaBinomialFixture): burn = 0 chains = 2 min_n_eff = 400 + step_args = {"random_seed": 202010} class TestNUTSStudentT(sf.NutsFixture, sf.StudentTFixture): @@ -73,6 +76,7 @@ class TestNUTSStudentT(sf.NutsFixture, sf.StudentTFixture): min_n_eff = 1000 rtol = 0.1 atol = 0.05 + step_args = {"random_seed": 202010} @pytest.mark.skip("Takes too long to run") @@ -92,6 +96,7 @@ class TestNUTSLKJCholeskyCov(sf.NutsFixture, sf.LKJCholeskyCovFixture): burn = 0 chains = 2 min_n_eff = 200 + step_args = {"random_seed": 202010} class TestNutsCheckTrace: diff --git a/tests/step_methods/test_metropolis.py b/tests/step_methods/test_metropolis.py index 7bfdb645c7..f414a534e8 100644 --- a/tests/step_methods/test_metropolis.py +++ b/tests/step_methods/test_metropolis.py @@ -36,6 +36,8 @@ from tests.helpers import RVsAssignmentStepsTester, StepMethodTester from tests.models import mv_simple, mv_simple_discrete, simple_categorical +SEED = sum(ord(c) for c in "test_metropolis") + class TestMetropolisUniform(sf.MetropolisFixture, sf.UniformFixture): n_samples = 50000 @@ -45,6 +47,7 @@ class TestMetropolisUniform(sf.MetropolisFixture, sf.UniformFixture): min_n_eff = 10000 rtol = 0.1 atol = 0.05 + step_args = {"rng": np.random.default_rng(SEED)} class TestMetropolis: @@ -81,7 +84,7 @@ def test_tuning_reset(self): idata = pm.sample( tune=600, draws=500, - step=Metropolis(tune=True, scaling=0.1), + step=Metropolis(tune=True, scaling=0.1, rng=SEED), cores=1, chains=3, discard_tuned_samples=False, @@ -113,7 +116,7 @@ def test_tuning_reset(self): def test_elemwise_update(self, batched_dist): with pm.Model() as m: m.register_rv(batched_dist, name="batched_dist") - step = pm.Metropolis([batched_dist]) + step = pm.Metropolis([batched_dist], rng=SEED) assert step.elemwise_update == (batched_dist.ndim > 0) trace = pm.sample(draws=1000, chains=2, step=step, random_seed=428) @@ -124,7 +127,7 @@ def test_elemwise_update_different_scales(self): mu = [1, 2, 3, 4, 5, 100, 1_000, 10_000] with pm.Model() as m: x = pm.Poisson("x", mu=mu) - step = pm.Metropolis([x]) + step = pm.Metropolis([x], rng=SEED) trace = pm.sample(draws=1000, chains=2, step=step, random_seed=128).posterior np.testing.assert_allclose(trace["x"].mean(("draw", "chain")), mu, rtol=0.1) @@ -134,7 +137,7 @@ def test_multinomial_no_elemwise_update(self): with pm.Model() as m: batched_dist = pm.Multinomial("batched_dist", n=5, p=np.ones(4) / 4, shape=(10, 4)) with pytensor.config.change_flags(mode=fast_unstable_sampling_mode): - step = pm.Metropolis([batched_dist]) + step = pm.Metropolis([batched_dist], rng=SEED) assert not step.elemwise_update @@ -167,7 +170,7 @@ def test_tuning_lambda_sequential(self): idata = pm.sample( tune=1000, draws=500, - step=DEMetropolisZ(tune="lambda", lamb=0.92), + 
step=DEMetropolisZ(tune="lambda", lamb=0.92, rng=SEED), cores=1, chains=3, discard_tuned_samples=False, @@ -185,7 +188,7 @@ def test_tuning_epsilon_parallel(self): idata = pm.sample( tune=1000, draws=500, - step=DEMetropolisZ(tune="scaling", scaling=0.002), + step=DEMetropolisZ(tune="scaling", scaling=0.002, rng=SEED), cores=2, chains=2, discard_tuned_samples=False, @@ -203,7 +206,7 @@ def test_tuning_none(self): idata = pm.sample( tune=1000, draws=500, - step=DEMetropolisZ(tune=None), + step=DEMetropolisZ(tune=None, rng=SEED), cores=1, chains=2, discard_tuned_samples=False, @@ -221,7 +224,7 @@ def test_tuning_reset(self): idata = pm.sample( tune=1000, draws=500, - step=DEMetropolisZ(tune="scaling", scaling=0.002), + step=DEMetropolisZ(tune="scaling", scaling=0.002, rng=SEED), cores=1, chains=3, discard_tuned_samples=False, @@ -245,7 +248,7 @@ def test_tune_drop_fraction(self): draws = 200 with pm.Model() as pmodel: pm.Normal("n", 0, 2, size=(3,)) - step = DEMetropolisZ(tune_drop_fraction=tune_drop_fraction) + step = DEMetropolisZ(tune_drop_fraction=tune_drop_fraction, rng=SEED) idata = pm.sample( tune=tune, draws=draws, step=step, cores=1, chains=1, discard_tuned_samples=False ) @@ -292,7 +295,7 @@ def test_step_discrete(self): unc = np.diag(C) ** 0.5 check = (("x", np.mean, mu, unc / 10.0), ("x", np.std, unc, unc / 10.0)) with model: - step = Metropolis(S=C, proposal_dist=MultivariateNormalProposal) + step = Metropolis(S=C, proposal_dist=MultivariateNormalProposal, rng=123456) idata = pm.sample( tune=1000, draws=2000, @@ -311,7 +314,7 @@ def test_step_categorical(self, proposal): unc = C**0.5 check = (("x", np.mean, mu, unc / 10.0), ("x", np.std, unc, unc / 10.0)) with model: - step = CategoricalGibbsMetropolis([model.x], proposal=proposal) + step = CategoricalGibbsMetropolis([model.x], proposal=proposal, rng=SEED) idata = pm.sample( tune=1000, draws=2000, @@ -329,7 +332,7 @@ def test_step_categorical(self, proposal): [ ( lambda C, _: Metropolis( - S=C, proposal_dist=MultivariateNormalProposal, blocked=True + S=C, proposal_dist=MultivariateNormalProposal, blocked=True, rng=SEED ), 4000, ), diff --git a/tests/step_methods/test_slicer.py b/tests/step_methods/test_slicer.py index 80435573c0..899d4ec9ec 100644 --- a/tests/step_methods/test_slicer.py +++ b/tests/step_methods/test_slicer.py @@ -12,12 +12,15 @@ # See the License for the specific language governing permissions and # limitations under the License. 
+import numpy as np
 import pytest
 
 from pymc.step_methods.slicer import Slice
 from tests import sampler_fixtures as sf
 from tests.helpers import RVsAssignmentStepsTester, StepMethodTester
 
+SEED = 20240920
+
 
 class TestSliceUniform(sf.SliceFixture, sf.UniformFixture):
     n_samples = 10000
@@ -27,6 +30,7 @@ class TestSliceUniform(sf.SliceFixture, sf.UniformFixture):
     min_n_eff = 5000
     rtol = 0.1
     atol = 0.05
+    step_args = {"rng": np.random.default_rng(SEED)}
 
 
 class TestStepSlicer(StepMethodTester):
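
The seeding scheme exercised by these tests rests on one NumPy primitive: generators spawned from a single seeded parent are mutually independent yet fully reproducible. A standalone sketch of that behavior, outside of PyMC (the numbers are arbitrary; Generator.spawn requires NumPy >= 1.25):

import numpy as np

# One generator per chain, all derived from one user-facing seed.
rngs = np.random.default_rng(1).spawn(3)
chain_draws = [rng.random(4) for rng in rngs]

# The per-chain streams do not collide with each other...
assert not np.array_equal(chain_draws[0], chain_draws[1])

# ...and re-spawning from an identically seeded parent replays them all.
for old, fresh in zip(chain_draws, np.random.default_rng(1).spawn(3)):
    assert np.array_equal(old, fresh.random(4))
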

From c399241d73e5f7eed193faebd24c6f3e0b979f54 Mon Sep 17 00:00:00 2001
From: Luciano Paz
Date: Thu, 19 Sep 2024 09:53:06 +0200
Subject: [PATCH 3/7] Add sampling state base classes

---
 .github/workflows/tests.yml      |   3 +-
 pymc/step_methods/state.py       |  99 ++++++++++++++++++
 tests/helpers.py                 |  16 ++++
 tests/step_methods/test_state.py | 158 ++++++++++++++++++++++++++++
 4 files changed, 275 insertions(+), 1 deletion(-)
 create mode 100644 pymc/step_methods/state.py
 create mode 100644 tests/step_methods/test_state.py

diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml
index 9ac9eff143..0956f17b60 100644
--- a/.github/workflows/tests.yml
+++ b/.github/workflows/tests.yml
@@ -103,6 +103,7 @@ jobs:
             tests/ode/test_ode.py
             tests/ode/test_utils.py
             tests/step_methods/hmc/test_quadpotential.py
+            tests/step_methods/test_state.py
           - |
             tests/backends/test_mcbackend.py
 
@@ -197,7 +198,7 @@ jobs:
           - tests/variational/test_approximations.py tests/variational/test_callbacks.py tests/variational/test_inference.py tests/variational/test_opvi.py tests/test_initial_point.py
          - tests/model/test_core.py tests/sampling/test_mcmc.py
          - tests/gp/test_cov.py tests/gp/test_gp.py tests/gp/test_mean.py tests/gp/test_util.py tests/ode/test_ode.py tests/ode/test_utils.py tests/smc/test_smc.py tests/sampling/test_parallel.py
-          - tests/step_methods/test_metropolis.py tests/step_methods/test_slicer.py tests/step_methods/hmc/test_nuts.py tests/step_methods/test_compound.py tests/step_methods/hmc/test_hmc.py
+          - tests/step_methods/test_metropolis.py tests/step_methods/test_slicer.py tests/step_methods/hmc/test_nuts.py tests/step_methods/test_compound.py tests/step_methods/hmc/test_hmc.py tests/step_methods/test_state.py
       fail-fast: false
     runs-on: ${{ matrix.os }}
diff --git a/pymc/step_methods/state.py b/pymc/step_methods/state.py
new file mode 100644
index 0000000000..9b85d7784b
--- /dev/null
+++ b/pymc/step_methods/state.py
@@ -0,0 +1,99 @@
+# Copyright 2024 The PyMC Developers
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from copy import deepcopy
+from dataclasses import Field, dataclass, fields
+from typing import Any, ClassVar
+
+import numpy as np
+
+dataclass_state = dataclass(kw_only=True)
+
+
+@dataclass_state
+class DataClassState:
+    __dataclass_fields__: ClassVar[dict[str, Field[Any]]] = {}
+
+
+def equal_dataclass_values(v1, v2):
+    if v1.__class__ != v2.__class__:
+        return False
+    if isinstance(v1, (list, tuple)):  # noqa: UP038
+        return len(v1) == len(v2) and all(
+            equal_dataclass_values(v1i, v2i) for v1i, v2i in zip(v1, v2, strict=True)
+        )
+    elif isinstance(v1, dict):
+        if set(v1) != set(v2):
+            return False
+        return all(equal_dataclass_values(v1[k], v2[k]) for k in v1)
+    elif isinstance(v1, np.ndarray):
+        return bool(np.array_equal(v1, v2, equal_nan=True))
+    elif isinstance(v1, np.random.Generator):
+        return equal_dataclass_values(v1.bit_generator.state, v2.bit_generator.state)
+    elif isinstance(v1, DataClassState):
+        return set(fields(v1)) == set(fields(v2)) and all(
+            equal_dataclass_values(getattr(v1, f1.name), getattr(v2, f2.name))
+            for f1, f2 in zip(fields(v1), fields(v2), strict=True)
+        )
+    else:
+        return v1 == v2
+
+
+class WithSamplingState:
+    """Mixin class that adds the ``sampling_state`` property to an object.
+
+    The object's type must define ``_state_class`` as a valid
+    :py:class:`~pymc.step_methods.state.DataClassState`. Once that happens,
+    the object's ``sampling_state`` property can be read or assigned to get
+    or set the state, represented as an ``_state_class`` instance.
+    """
+
+    _state_class: type[DataClassState] = DataClassState
+
+    @property
+    def sampling_state(self) -> DataClassState:
+        state_class = self._state_class
+        kwargs = {}
+        for field in fields(state_class):
+            val = getattr(self, field.name)
+            if isinstance(val, WithSamplingState):
+                _val = val.sampling_state
+            else:
+                _val = val
+            kwargs[field.name] = deepcopy(_val)
+        return state_class(**kwargs)
+
+    @sampling_state.setter
+    def sampling_state(self, state: DataClassState):
+        state_class = self._state_class
+        assert isinstance(
+            state, state_class
+        ), f"Encountered invalid state class '{state.__class__}'. State must be '{state_class}'"
+        for field in fields(state_class):
+            state_val = deepcopy(getattr(state, field.name))
+            self_val = getattr(self, field.name)
+            is_frozen = field.metadata.get("frozen", False)
+            if is_frozen:
+                if not equal_dataclass_values(state_val, self_val):
+                    raise ValueError(
+                        "The received sampling state must have the same values for the "
+                        f"frozen fields. Field {field.name!r} has different values. "
+                        f"Expected {self_val} but got {state_val}"
+                    )
+            else:
+                if isinstance(state_val, DataClassState):
+                    assert isinstance(self_val, WithSamplingState)
+                    self_val.sampling_state = state_val
+                    setattr(self, field.name, self_val)
+                else:
+                    setattr(self, field.name, state_val)
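
To make the intended usage concrete before the tests below, here is a minimal sketch; the WalkState and RandomWalk names are illustrative, not part of the patch. A class lists its mutable quantities in a DataClassState subclass, and sampling_state then round-trips all of them, including the internal state of a numpy Generator:

import numpy as np

from pymc.step_methods.state import DataClassState, WithSamplingState, dataclass_state


@dataclass_state
class WalkState(DataClassState):
    # Everything that mutates while sampling lives in the state object.
    position: np.ndarray
    rng: np.random.Generator


class RandomWalk(WithSamplingState):
    _state_class = WalkState

    def __init__(self, ndim, seed=None):
        self.position = np.zeros(ndim)
        self.rng = np.random.default_rng(seed)

    def step(self):
        self.position = self.position + self.rng.normal(size=self.position.shape)
        return self.position


walker = RandomWalk(ndim=2, seed=123)
snapshot = walker.sampling_state  # deep copy of the position and of the rng state
first = walker.step()

walker.sampling_state = snapshot  # rewind both the position and the rng
assert np.array_equal(first, walker.step())
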
" + f"Expected {self_val} but got {state_val}" + ) + else: + if isinstance(state_val, DataClassState): + assert isinstance(self_val, WithSamplingState) + self_val.sampling_state = state_val + setattr(self, field.name, self_val) + else: + setattr(self, field.name, state_val) diff --git a/tests/helpers.py b/tests/helpers.py index c0f210bf8c..c14433711b 100644 --- a/tests/helpers.py +++ b/tests/helpers.py @@ -17,6 +17,11 @@ import tempfile import warnings +<<<<<<< HEAD +======= +from copy import deepcopy +from dataclasses import fields +>>>>>>> 741b38626 (Fixup state) from logging.handlers import BufferingHandler import numpy as np @@ -28,6 +33,7 @@ import pymc as pm +from pymc.step_methods.state import equal_dataclass_values from pymc.testing import fast_unstable_sampling_mode from tests.models import mv_simple, mv_simple_coarse @@ -177,3 +183,16 @@ def continuous_steps(self, step, step_kwargs): assert {m.rvs_to_values[c1], m.rvs_to_values[c2]} == set( step([c1, c2], **step_kwargs).vars ) + + +def equal_sampling_states(this, other): + if this.__class__ != other.__class__: + return False + this_fields = set([f.name for f in fields(this)]) + other_fields = set([f.name for f in fields(other)]) + for field in this_fields: + this_val = getattr(this, field) + other_val = getattr(other, field) + if not equal_dataclass_values(this_val, other_val): + return False + return this_fields == other_fields diff --git a/tests/step_methods/test_state.py b/tests/step_methods/test_state.py new file mode 100644 index 0000000000..e6a39264db --- /dev/null +++ b/tests/step_methods/test_state.py @@ -0,0 +1,158 @@ +# Copyright 2024 The PyMC Developers +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+from dataclasses import field + +import numpy as np +import pytest + +from pymc.step_methods.state import DataClassState, WithSamplingState, dataclass_state +from tests.helpers import equal_sampling_states + + +@dataclass_state +class State1(DataClassState): + a: int + b: float + c: str + d: np.ndarray + e: list + f: dict + + +@dataclass_state +class State2(DataClassState): + mutable_field: float + state1: State1 + extra_info1: np.ndarray = field(metadata={"frozen": True}) + extra_info2: list = field(metadata={"frozen": True}) + extra_info3: dict = field(metadata={"frozen": True}) + + +class A(WithSamplingState): + _state_class = State1 + + def __init__(self, a=1, b=2.0, c="c", d=None, e=None, f=None): + self.a = a + self.b = b + self.c = c + if d is None: + d = np.array([1, 2]) + if e is None: + e = [1, 2, 3] + if f is None: + f = {"a": 1, "b": "c"} + self.d = d + self.e = e + self.f = f + + +class B(WithSamplingState): + _state_class = State2 + + def __init__( + self, + a=1, + b=2.0, + c="c", + d=None, + e=None, + f=None, + mutable_field=1.0, + extra_info1=None, + extra_info2=None, + extra_info3=None, + ): + self.state1 = A(a=a, b=b, c=c, d=d, e=e, f=f) + self.mutable_field = mutable_field + if extra_info1 is None: + extra_info1 = np.array([3, 4, 5]) + if extra_info2 is None: + extra_info2 = [5, 6, 7] + if extra_info3 is None: + extra_info3 = {"foo": "bar"} + self.extra_info1 = extra_info1 + self.extra_info2 = extra_info2 + self.extra_info3 = extra_info3 + + +@dataclass_state +class RngState(DataClassState): + rng: np.random.Generator + + +class Step(WithSamplingState): + _state_class = RngState + + def __init__(self, rng=None): + self.rng = np.random.default_rng(rng) + + +def test_sampling_state(): + b1 = B() + b2 = B(mutable_field=2.0) + b3 = B(c=1, extra_info1=np.array([10, 20])) + b4 = B(a=2, b=3.0, c="d") + b5 = B(c=1) + b6 = B(f={"a": 1, "b": "c", "d": None}) + + b1_state = b1.sampling_state + b2_state = b2.sampling_state + b3_state = b3.sampling_state + b4_state = b4.sampling_state + + assert equal_sampling_states(b1_state.state1, b2_state.state1) + assert not equal_sampling_states(b1_state, b2_state) + assert not equal_sampling_states(b1_state, b3_state) + assert not equal_sampling_states(b1_state, b4_state) + + b1.sampling_state = b2_state + assert equal_sampling_states(b1.sampling_state, b2_state) + + expected_error_message = ( + "The received sampling state must have the same values for the " + "frozen fields. Field 'extra_info1' has different values. 
" + r"Expected \[3 4 5\] but got \[10 20\]" + ) + with pytest.raises(ValueError, match=expected_error_message): + b1.sampling_state = b3_state + + with pytest.raises(AssertionError, match="Encountered invalid state class"): + b1.sampling_state = b1_state.state1 + + b1.sampling_state = b4_state + assert equal_sampling_states(b1.sampling_state, b4_state) + assert not equal_sampling_states(b1.sampling_state, b5.sampling_state) + assert not equal_sampling_states(b1.sampling_state, b6.sampling_state) + + +@pytest.mark.parametrize( + "step", + [ + Step(), + Step(1), + Step(np.random.Generator(np.random.Philox(1))), + ], + ids=["default_rng", "default_rng(1)", "philox"], +) +def test_sampling_state_rng(step): + original_state = step.sampling_state + values1 = step.rng.random(100) + + final_state = step.sampling_state + assert not equal_sampling_states(original_state, final_state) + + step.sampling_state = original_state + values2 = step.rng.random(100) + assert np.array_equal(values1, values2, equal_nan=True) + assert equal_sampling_states(step.sampling_state, final_state) From 5f6ac334d11902c5678f5d8287bfe06666145bf7 Mon Sep 17 00:00:00 2001 From: Luciano Paz Date: Thu, 19 Sep 2024 09:54:29 +0200 Subject: [PATCH 4/7] Add step method state --- pymc/step_methods/arraystep.py | 4 ++-- pymc/step_methods/compound.py | 42 +++++++++++++++++++++++++++++++--- tests/helpers.py | 26 ++++++++++++++++++--- 3 files changed, 64 insertions(+), 8 deletions(-) diff --git a/pymc/step_methods/arraystep.py b/pymc/step_methods/arraystep.py index 602dfd6e51..bddf02f155 100644 --- a/pymc/step_methods/arraystep.py +++ b/pymc/step_methods/arraystep.py @@ -142,8 +142,8 @@ def __init__(self, vars, shared, blocked=True, rng: RandomGenerator = None): """ self.population = None self.this_chain = None - self.other_chains = None - return super().__init__(vars, shared, blocked) + self.other_chains: list[int] | None = None + return super().__init__(vars, shared, blocked, rng=rng) def link_population(self, population, chain_index): """Links the sampler to the population. diff --git a/pymc/step_methods/compound.py b/pymc/step_methods/compound.py index 1c1d6fbb50..87dd30420a 100644 --- a/pymc/step_methods/compound.py +++ b/pymc/step_methods/compound.py @@ -31,7 +31,8 @@ from pymc.blocking import PointType, StatDtype, StatsDict, StatShape, StatsType from pymc.model import modelcontext -from pymc.util import get_random_generator +from pymc.step_methods.state import DataClassState, WithSamplingState, dataclass_state +from pymc.util import RandomGenerator, get_random_generator __all__ = ("Competence", "CompoundStep") @@ -87,7 +88,12 @@ def infer_warn_stats_info( return stats_dtypes, sds -class BlockedStep(ABC): +@dataclass_state +class StepMethodState(DataClassState): + rng: np.random.Generator + + +class BlockedStep(ABC, WithSamplingState): stats_dtypes: list[dict[str, type]] = [] """A list containing <=1 dictionary that maps stat names to dtypes. 
@@ -195,6 +201,9 @@ def stop_tuning(self):
         if hasattr(self, "tune"):
             self.tune = False
 
+    def set_rng(self, rng: RandomGenerator):
+        self.rng = get_random_generator(rng, copy=False)
+
 
 def flat_statname(sampler_idx: int, sname: str) -> str:
     """Get the flat-stats name for a samplers stat."""
@@ -215,10 +224,20 @@ def get_stats_dtypes_shapes_from_steps(
     return result
 
 
-class CompoundStep:
+@dataclass_state
+class CompoundStepState(DataClassState):
+    methods: list[StepMethodState]
+
+    def __init__(self, methods: list[StepMethodState]):
+        self.methods = methods
+
+
+class CompoundStep(WithSamplingState):
     """Step method composed of a list of several other step methods applied in sequence."""
 
+    _state_class = CompoundStepState
+
     def __init__(self, methods):
         self.methods = list(methods)
         self.stats_dtypes = []
@@ -250,10 +269,27 @@ def reset_tuning(self):
         if hasattr(method, "reset_tuning"):
             method.reset_tuning()
 
+    @property
+    def sampling_state(self) -> DataClassState:
+        return CompoundStepState(methods=[method.sampling_state for method in self.methods])
+
+    @sampling_state.setter
+    def sampling_state(self, state: DataClassState):
+        assert isinstance(
+            state, self._state_class
+        ), f"Invalid sampling state class {type(state)}. Expected {self._state_class}"
+        for method, state_method in zip(self.methods, state.methods):
+            method.sampling_state = state_method
+
     @property
     def vars(self) -> list[Variable]:
         return [var for method in self.methods for var in method.vars]
 
+    def set_rng(self, rng: RandomGenerator):
+        _rngs = get_random_generator(rng, copy=False).spawn(len(self.methods))
+        for method, _rng in zip(self.methods, _rngs):
+            method.set_rng(_rng)
+
 
 def flatten_steps(step: BlockedStep | CompoundStep) -> list[BlockedStep]:
     """Flatten a hierarchy of step methods to a list."""
diff --git a/tests/helpers.py b/tests/helpers.py
index c14433711b..b9d5c6d019 100644
--- a/tests/helpers.py
+++ b/tests/helpers.py
@@ -143,6 +143,21 @@ def step_continuous(self, step_fn, draws, chains=1, tune=1000):
         _, model_coarse, _ = mv_simple_coarse()
         with model:
             step = step_fn(C, model_coarse)
+            orig_step = deepcopy(step)
+            orig_state = step.sampling_state
+            assert equal_sampling_states(step.sampling_state, orig_state)
+
+            ip = model.initial_point()
+            value1, _ = step.step(ip)
+            final_state = step.sampling_state
+            step.sampling_state = orig_state
+
+            value2, _ = step.step(ip)
+
+            assert equal_sampling_states(step.sampling_state, final_state)
+            assert equal_dataclass_values(value1, value2)
+
+            step.sampling_state = orig_state
             with warnings.catch_warnings():
                 warnings.filterwarnings("ignore", "More chains .* than draws .*", UserWarning)
                 idata = pm.sample(
@@ -162,6 +177,14 @@
             self.check_stat(check, idata)
             self.check_stat_dtype(idata, step)
 
+            curr_state = step.sampling_state
+            assert not equal_sampling_states(orig_state, curr_state)
+
+            orig_step.sampling_state = curr_state
+
+            assert equal_sampling_states(orig_step.sampling_state, curr_state)
+            assert orig_step.sampling_state is not curr_state
+
 
 class RVsAssignmentStepsTester:
     """

From ca2c60b44fa19fcdede6d51d69a78e9be1d3c449 Mon Sep 17 00:00:00 2001
From: Luciano Paz
Date: Thu, 19 Sep 2024 09:54:59 +0200
Subject: [PATCH 5/7] Add metropolis sampling state

---
 pymc/step_methods/metropolis.py       
| 110 +++++++++++++++++++++++--- tests/models.py | 11 +++ tests/step_methods/test_metropolis.py | 57 ++++++++++++- 3 files changed, 166 insertions(+), 12 deletions(-) diff --git a/pymc/step_methods/metropolis.py b/pymc/step_methods/metropolis.py index aa5101dbb0..e6f9d9dc77 100644 --- a/pymc/step_methods/metropolis.py +++ b/pymc/step_methods/metropolis.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. from collections.abc import Callable +from dataclasses import field +from typing import Any import numpy as np import numpy.random as nr @@ -40,7 +42,8 @@ StatsType, metrop_select, ) -from pymc.step_methods.compound import Competence +from pymc.step_methods.compound import Competence, StepMethodState +from pymc.step_methods.state import dataclass_state __all__ = [ "Metropolis", @@ -111,11 +114,31 @@ def __call__(self, num_draws=None, rng: np.random.Generator | None = None): return np.dot(self.chol, b) +@dataclass_state +class MetropolisState(StepMethodState): + scaling: np.ndarray + tune: bool + steps_until_tune: float + tune_interval: float + accepted_sum: np.ndarray + accept_rate_iter: np.ndarray + accepted_iter: np.ndarray + enum_dims: np.ndarray + + discrete: np.ndarray = field(metadata={"frozen": True}) + any_discrete: bool = field(metadata={"frozen": True}) + all_discrete: bool = field(metadata={"frozen": True}) + elemwise_update: bool = field(metadata={"frozen": True}) + _untuned_settings: dict[str, np.ndarray | float] = field(metadata={"frozen": True}) + mode: Any = field(metadata={"frozen": True}) + + class Metropolis(ArrayStepShared): """Metropolis-Hastings sampling step""" name = "metropolis" + default_blocked = False stats_dtypes_shapes = { "accept": (np.float64, []), "accepted": (np.float64, []), @@ -123,6 +146,8 @@ class Metropolis(ArrayStepShared): "scaling": (np.float64, []), } + _state_class = MetropolisState + def __init__( self, vars=None, @@ -346,6 +371,15 @@ def tune(scale, acc_rate): ) +@dataclass_state +class BinaryMetropolisState(StepMethodState): + tune: bool + accepted: int + scaling: float + tune_interval: int + steps_until_tune: int + + class BinaryMetropolis(ArrayStep): """Metropolis-Hastings optimized for binary variables @@ -375,7 +409,9 @@ class BinaryMetropolis(ArrayStep): "p_jump": (np.float64, []), } - def __init__(self, vars, scaling=1.0, tune=True, tune_interval=100, model=None): + _state_class = BinaryMetropolisState + + def __init__(self, vars, scaling=1.0, tune=True, tune_interval=100, model=None, rng=None): model = pm.modelcontext(model) self.scaling = scaling @@ -389,7 +425,7 @@ def __init__(self, vars, scaling=1.0, tune=True, tune_interval=100, model=None): if not all([v.dtype in pm.discrete_types for v in vars]): raise ValueError("All variables must be Bernoulli for BinaryMetropolis") - super().__init__(vars, [model.compile_logp()]) + super().__init__(vars, [model.compile_logp()], rng=rng) def astep(self, apoint: RaveledVars, *args) -> tuple[RaveledVars, StatsType]: logp = args[0] @@ -445,6 +481,14 @@ def competence(var): return Competence.INCOMPATIBLE +@dataclass_state +class BinaryGibbsMetropolisState(StepMethodState): + tune: bool + transit_p: int + shuffle_dims: bool + order: list + + class BinaryGibbsMetropolis(ArrayStep): """A Metropolis-within-Gibbs step method optimized for binary variables @@ -472,7 +516,9 @@ class BinaryGibbsMetropolis(ArrayStep): "tune": (bool, []), } - def __init__(self, vars, order="random", transit_p=0.8, model=None): + _state_class = 
BinaryGibbsMetropolisState + + def __init__(self, vars, order="random", transit_p=0.8, model=None, rng=None): model = pm.modelcontext(model) # Doesn't actually tune, but it's required to emit a sampler stat @@ -498,7 +544,7 @@ def __init__(self, vars, order="random", transit_p=0.8, model=None): if not all([v.dtype in pm.discrete_types for v in vars]): raise ValueError("All variables must be binary for BinaryGibbsMetropolis") - super().__init__(vars, [model.compile_logp()]) + super().__init__(vars, [model.compile_logp()], rng=rng) def reset_tuning(self): # There are no tuning parameters in this step method. @@ -557,6 +603,13 @@ def competence(var): return Competence.INCOMPATIBLE +@dataclass_state +class CategoricalGibbsMetropolisState(StepMethodState): + shuffle_dims: bool + dimcats: list[tuple] + tune: bool + + class CategoricalGibbsMetropolis(ArrayStep): """A Metropolis-within-Gibbs step method optimized for categorical variables. @@ -573,6 +626,8 @@ class CategoricalGibbsMetropolis(ArrayStep): "tune": (bool, []), } + _state_class = CategoricalGibbsMetropolisState + def __init__(self, vars, proposal="uniform", order="random", model=None, rng=None): model = pm.modelcontext(model) @@ -728,6 +783,18 @@ def competence(var): return Competence.INCOMPATIBLE +@dataclass_state +class DEMetropolisState(StepMethodState): + scaling: np.ndarray + lamb: float + tune: str | None + tune_interval: int + steps_until_tune: int + accepted: int + + mode: Any = field(metadata={"frozen": True}) + + class DEMetropolis(PopulationArrayStepShared): """ Differential Evolution Metropolis sampling step. @@ -778,6 +845,8 @@ class DEMetropolis(PopulationArrayStepShared): "lambda": (np.float64, []), } + _state_class = DEMetropolisState + def __init__( self, vars=None, @@ -789,6 +858,7 @@ def __init__( tune_interval=100, model=None, mode=None, + rng=None, **kwargs, ): model = pm.modelcontext(model) @@ -824,7 +894,7 @@ def __init__( shared = pm.make_shared_replacements(initial_values, vars, model) self.delta_logp = delta_logp(initial_values, model.logp(), vars, shared) - super().__init__(vars, shared) + super().__init__(vars, shared, rng=rng) def astep(self, q0: RaveledVars) -> tuple[RaveledVars, StatsType]: point_map_info = q0.point_map_info @@ -843,9 +913,11 @@ def astep(self, q0: RaveledVars) -> tuple[RaveledVars, StatsType]: # differential evolution proposal # select two other chains - ir1, ir2 = np.random.choice(self.other_chains, 2, replace=False) - r1 = DictToArrayBijection.map(self.population[ir1]) - r2 = DictToArrayBijection.map(self.population[ir2]) + if self.other_chains is None: # pragma: no cover + raise RuntimeError("Population sampler has not been linked to the other chains") + ir1, ir2 = self.rng.choice(self.other_chains, 2, replace=False) + r1 = DictToArrayBijection.map(self.population[ir1]) # type: ignore + r2 = DictToArrayBijection.map(self.population[ir2]) # type: ignore # propose a jump q = floatX(q0d + self.lamb * (r1.data - r2.data) + epsilon) @@ -872,6 +944,21 @@ def competence(var, has_grad): return Competence.COMPATIBLE +@dataclass_state +class DEMetropolisZState(StepMethodState): + scaling: np.ndarray + lamb: float + tune: bool + tune_target: str | None + tune_interval: int + steps_until_tune: int + accepted: int + _history: list + + _untuned_settings: dict[str, np.ndarray | float] = field(metadata={"frozen": True}) + mode: Any = field(metadata={"frozen": True}) + + class DEMetropolisZ(ArrayStepShared): """ Adaptive Differential Evolution Metropolis sampling step that uses the past to inform 
jumps. @@ -925,6 +1012,8 @@ class DEMetropolisZ(ArrayStepShared): "lambda": (np.float64, []), } + _state_class = DEMetropolisZState + def __init__( self, vars=None, @@ -937,6 +1026,7 @@ def __init__( tune_drop_fraction: float = 0.9, model=None, mode=None, + rng=None, **kwargs, ): model = pm.modelcontext(model) @@ -984,7 +1074,7 @@ def __init__( shared = pm.make_shared_replacements(initial_values, vars, model) self.delta_logp = delta_logp(initial_values, model.logp(), vars, shared) - super().__init__(vars, shared) + super().__init__(vars, shared, rng=rng) def reset_tuning(self): """Resets the tuned sampler parameters and history to their initial values.""" diff --git a/tests/models.py b/tests/models.py index 24f80c7c0b..b66c1dc67d 100644 --- a/tests/models.py +++ b/tests/models.py @@ -186,3 +186,14 @@ def simple_normal(bounded_prior=False): pm.Normal("X_obs", mu=mu_i, sigma=sigma, observed=x0) return model.initial_point(), model, None + + +def simple_binary(): + p1 = 0.5 + p2 = 0.5 + + with pm.Model() as model: + pm.Bernoulli("d1", p=p1) + pm.Bernoulli("d2", p=p2) + + return model.initial_point(), model, (p1, p2) diff --git a/tests/step_methods/test_metropolis.py b/tests/step_methods/test_metropolis.py index f414a534e8..a73538a61b 100644 --- a/tests/step_methods/test_metropolis.py +++ b/tests/step_methods/test_metropolis.py @@ -14,6 +14,8 @@ import warnings +from copy import deepcopy + import arviz as az import numpy as np import numpy.testing as npt @@ -24,6 +26,7 @@ from pymc.step_methods.metropolis import ( BinaryGibbsMetropolis, + BinaryMetropolis, CategoricalGibbsMetropolis, DEMetropolis, DEMetropolisZ, @@ -31,10 +34,17 @@ MultivariateNormalProposal, NormalProposal, ) +from pymc.step_methods.state import equal_dataclass_values from pymc.testing import fast_unstable_sampling_mode from tests import sampler_fixtures as sf -from tests.helpers import RVsAssignmentStepsTester, StepMethodTester -from tests.models import mv_simple, mv_simple_discrete, simple_categorical +from tests.helpers import RVsAssignmentStepsTester, StepMethodTester, equal_sampling_states +from tests.models import ( + mv_simple, + mv_simple_discrete, + simple_binary, + simple_categorical, + simple_model, +) SEED = sum(ord(c) for c in "test_metropolis") @@ -47,6 +57,7 @@ class TestMetropolisUniform(sf.MetropolisFixture, sf.UniformFixture): min_n_eff = 10000 rtol = 0.1 atol = 0.05 + ks_thin = 10 step_args = {"rng": np.random.default_rng(SEED)} @@ -367,3 +378,45 @@ def test_discrete_steps(self, step, step_kwargs): ) def test_continuous_steps(self, step, step_kwargs): self.continuous_steps(step, step_kwargs) + + +@pytest.mark.parametrize( + ["step_method", "model_fn"], + [ + [Metropolis, simple_model], + [BinaryMetropolis, simple_binary], + [BinaryGibbsMetropolis, simple_binary], + [CategoricalGibbsMetropolis, simple_categorical], + [DEMetropolis, simple_model], + [DEMetropolisZ, simple_model], + ], +) +def test_sampling_state(step_method, model_fn): + with pytensor.config.change_flags(mode=fast_unstable_sampling_mode): + initial_point, model, _ = model_fn() + with model: + sampler = step_method(model.value_vars) + if hasattr(sampler, "link_population"): + sampler.link_population([initial_point] * 100, 0) + sampler_orig = deepcopy(sampler) + state_orig = sampler_orig.sampling_state + + sample1, stat1 = sampler.step(initial_point) + sampler.tune = False + + final_state1 = sampler.sampling_state + + assert not equal_sampling_states(final_state1, state_orig) + + sampler.sampling_state = state_orig + + assert 
equal_sampling_states(sampler.sampling_state, state_orig)
+
+            sample2, stat2 = sampler.step(initial_point)
+            sampler.tune = False
+
+            final_state2 = sampler.sampling_state
+
+            assert equal_sampling_states(final_state1, final_state2)
+            assert equal_dataclass_values(sample1, sample2)
+            assert equal_dataclass_values(stat1, stat2)

From 04fbe64fb2a602290fb088bcfe69c1d7c620fcea Mon Sep 17 00:00:00 2001
From: Luciano Paz
Date: Thu, 19 Sep 2024 09:55:22 +0200
Subject: [PATCH 6/7] Add slice sampling state

---
 pymc/step_methods/slicer.py | 13 ++++++++++++-
 1 file changed, 12 insertions(+), 1 deletion(-)

diff --git a/pymc/step_methods/slicer.py b/pymc/step_methods/slicer.py
index 3e096aeb9f..2ea4b1f55f 100644
--- a/pymc/step_methods/slicer.py
+++ b/pymc/step_methods/slicer.py
@@ -21,7 +21,8 @@
 from pymc.model import modelcontext
 from pymc.pytensorf import compile_pymc, join_nonshared_inputs, make_shared_replacements
 from pymc.step_methods.arraystep import ArrayStepShared
-from pymc.step_methods.compound import Competence
+from pymc.step_methods.compound import Competence, StepMethodState
+from pymc.step_methods.state import dataclass_state
 from pymc.util import get_value_vars_from_user_vars
 from pymc.vartypes import continuous_types
 
@@ -30,6 +31,14 @@
 
 LOOP_ERR_MSG = "max slicer iters %d exceeded"
 
+
+@dataclass_state
+class SliceState(StepMethodState):
+    w: np.ndarray
+    tune: bool
+    n_tunes: float
+    iter_limit: float
+
 
 class Slice(ArrayStepShared):
     """
     Univariate slice sampler step method.
@@ -61,6 +70,8 @@ class Slice(ArrayStepShared):
         "nstep_in": (int, []),
     }
 
+    _state_class = SliceState
+
     def __init__(
         self, vars=None, w=1.0, tune=True, model=None, iter_limit=np.inf, rng=None, **kwargs
     ):

From af74f2cd46c04a66f3e8e7cad7c3444b44a725c1 Mon Sep 17 00:00:00 2001
From: Luciano Paz
Date: Thu, 19 Sep 2024 09:55:46 +0200
Subject: [PATCH 7/7] Add HMC sampling state

---
 pymc/step_methods/hmc/base_hmc.py      |  34 +++++--
 pymc/step_methods/hmc/hmc.py           |  10 +-
 pymc/step_methods/hmc/nuts.py          |  10 +-
 pymc/step_methods/hmc/quadpotential.py | 123 ++++++++++++++++++++++---
 pymc/step_methods/step_sizes.py        |  21 ++++-
 5 files changed, 178 insertions(+), 20 deletions(-)

diff --git a/pymc/step_methods/hmc/base_hmc.py b/pymc/step_methods/hmc/base_hmc.py
index b320ed8194..87daff649c 100644
--- a/pymc/step_methods/hmc/base_hmc.py
+++ b/pymc/step_methods/hmc/base_hmc.py
@@ -27,14 +27,19 @@
 from pymc.model import Point, modelcontext
 from pymc.pytensorf import floatX
 from pymc.stats.convergence import SamplerWarning, WarningType
-from pymc.step_methods import step_sizes
 from pymc.step_methods.arraystep import GradientSharedStep
 from pymc.step_methods.compound import StepMethodState
 from pymc.step_methods.hmc import integration
 from pymc.step_methods.hmc.integration import IntegrationError, State
-from pymc.step_methods.hmc.quadpotential import QuadPotentialDiagAdapt, quad_potential
+from pymc.step_methods.hmc.quadpotential import (
+    PotentialState,
+    QuadPotentialDiagAdapt,
+    quad_potential,
+)
+from pymc.step_methods.state import dataclass_state
+from pymc.step_methods.step_sizes import DualAverageAdaptation, StepSizeState
 from pymc.tuning import guess_scaling
-from pymc.util import get_value_vars_from_user_vars
+from pymc.util import RandomGenerator, get_random_generator, get_value_vars_from_user_vars
 
 logger = logging.getLogger(__name__)
 
@@ -53,12 +58,27 @@ class HMCStepData(NamedTuple):
     stats: dict[str, Any]
 
 
+@dataclass_state
+class BaseHMCState(StepMethodState):
+    adapt_step_size: bool
+    Emax: float
+    
iter_count: int + step_size: np.ndarray + step_adapt: StepSizeState + target_accept: float + tune: bool + potential: PotentialState + _num_divs_sample: int + + class BaseHMC(GradientSharedStep): """Superclass to implement Hamiltonian/hybrid monte carlo.""" integrator: integration.CpuLeapfrogIntegrator default_blocked = True + _state_class = BaseHMCState + def __init__( self, vars=None, @@ -134,9 +154,7 @@ def __init__( size = sum(v.size for v in nuts_vars) self.step_size = step_scale / (size**0.25) - self.step_adapt = step_sizes.DualAverageAdaptation( - self.step_size, target_accept, gamma, k, t0 - ) + self.step_adapt = DualAverageAdaptation(self.step_size, target_accept, gamma, k, t0) self.target_accept = target_accept self.tune = True @@ -268,3 +286,7 @@ def reset_tuning(self, start=None): def reset(self, start=None): self.tune = True self.potential.reset() + + def set_rng(self, rng: RandomGenerator): + self.rng = get_random_generator(rng, copy=False) + self.potential.set_rng(self.rng.spawn(1)[0]) diff --git a/pymc/step_methods/hmc/hmc.py b/pymc/step_methods/hmc/hmc.py index 106faee501..a5ebbd7a8c 100644 --- a/pymc/step_methods/hmc/hmc.py +++ b/pymc/step_methods/hmc/hmc.py @@ -14,14 +14,16 @@ from __future__ import annotations +from dataclasses import field from typing import Any import numpy as np from pymc.stats.convergence import SamplerWarning from pymc.step_methods.compound import Competence -from pymc.step_methods.hmc.base_hmc import BaseHMC, DivergenceInfo, HMCStepData +from pymc.step_methods.hmc.base_hmc import BaseHMC, BaseHMCState, DivergenceInfo, HMCStepData from pymc.step_methods.hmc.integration import IntegrationError, State +from pymc.step_methods.state import dataclass_state from pymc.vartypes import discrete_types __all__ = ["HamiltonianMC"] @@ -31,6 +33,12 @@ def unif(step_size, elow=0.85, ehigh=1.15, rng: np.random.Generator | None = Non return (rng or np.random).uniform(elow, ehigh) * step_size +@dataclass_state +class HamiltonianMCState(BaseHMCState): + path_length: float = field(metadata={"frozen": True}) + max_steps: int = field(metadata={"frozen": True}) + + class HamiltonianMC(BaseHMC): R"""A sampler for continuous variables based on Hamiltonian mechanics. diff --git a/pymc/step_methods/hmc/nuts.py b/pymc/step_methods/hmc/nuts.py index 3c4b4e6800..9bcde95104 100644 --- a/pymc/step_methods/hmc/nuts.py +++ b/pymc/step_methods/hmc/nuts.py @@ -15,6 +15,7 @@ from __future__ import annotations from collections import namedtuple +from dataclasses import field import numpy as np @@ -23,13 +24,20 @@ from pymc.stats.convergence import SamplerWarning from pymc.step_methods.compound import Competence from pymc.step_methods.hmc import integration -from pymc.step_methods.hmc.base_hmc import BaseHMC, DivergenceInfo, HMCStepData +from pymc.step_methods.hmc.base_hmc import BaseHMC, BaseHMCState, DivergenceInfo, HMCStepData from pymc.step_methods.hmc.integration import IntegrationError, State +from pymc.step_methods.state import dataclass_state from pymc.vartypes import continuous_types __all__ = ["NUTS"] +@dataclass_state +class NUTSState(BaseHMCState): + max_treedepth: int = field(metadata={"frozen": True}) + early_max_treedepth: int = field(metadata={"frozen": True}) + + class NUTS(BaseHMC): r"""A sampler for continuous variables based on Hamiltonian mechanics. 
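
The frozen metadata on fields like max_treedepth and early_max_treedepth above buys a concrete guarantee: a sampling state can only be restored onto a step method whose fixed configuration matches. A hypothetical sketch of the failure mode (the toy model and tree depths are illustrative, not taken from the patch):

import pymc as pm

with pm.Model():
    pm.Normal("x")
    nuts_a = pm.NUTS(max_treedepth=10)
    nuts_b = pm.NUTS(max_treedepth=8)

# Mutable fields (rng, step size, mass-matrix adaptation, ...) would carry
# over, but the mismatch in the frozen field is rejected.
try:
    nuts_b.sampling_state = nuts_a.sampling_state
except ValueError as err:
    print(err)  # ... Field 'max_treedepth' has different values ...
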
diff --git a/pymc/step_methods/hmc/quadpotential.py b/pymc/step_methods/hmc/quadpotential.py index abddaaf35f..05da188f9b 100644 --- a/pymc/step_methods/hmc/quadpotential.py +++ b/pymc/step_methods/hmc/quadpotential.py @@ -16,7 +16,8 @@ import warnings -from typing import overload +from dataclasses import field +from typing import Any, overload import numpy as np import pytensor @@ -25,6 +26,8 @@ from scipy.sparse import issparse from pymc.pytensorf import floatX +from pymc.step_methods.state import DataClassState, WithSamplingState, dataclass_state +from pymc.util import RandomGenerator, get_random_generator __all__ = [ "quad_potential", @@ -100,11 +103,18 @@ def __str__(self): return f"Scaling is not positive definite: {self.msg}. Check indexes {self.idx}." -class QuadPotential: +@dataclass_state +class PotentialState(DataClassState): + rng: np.random.Generator + + +class QuadPotential(WithSamplingState): dtype: np.dtype + _state_class = PotentialState + def __init__(self, rng=None): - self.rng = np.random.default_rng(rng) + self.rng = get_random_generator(rng) @overload def velocity(self, x: np.ndarray, out: None) -> np.ndarray: ... @@ -157,15 +167,42 @@ def reset(self): def stats(self): return {"largest_eigval": np.nan, "smallest_eigval": np.nan} + def set_rng(self, rng: RandomGenerator): + self.rng = get_random_generator(rng, copy=False) + def isquadpotential(value): """Check whether an object might be a QuadPotential object.""" return isinstance(value, QuadPotential) +@dataclass_state +class QuadPotentialDiagAdaptState(PotentialState): + _var: np.ndarray + _stds: np.ndarray + _inv_stds: np.ndarray + _foreground_var: WeightedVarianceState + _background_var: WeightedVarianceState + _n_samples: int + adaptation_window: int + _mass_trace: list[np.ndarray] | None + + dtype: Any = field(metadata={"frozen": True}) + _n: int = field(metadata={"frozen": True}) + _discard_window: int = field(metadata={"frozen": True}) + _early_update: int = field(metadata={"frozen": True}) + _initial_mean: np.ndarray = field(metadata={"frozen": True}) + _initial_diag: np.ndarray = field(metadata={"frozen": True}) + _initial_weight: np.ndarray = field(metadata={"frozen": True}) + adaptation_window_multiplier: float = field(metadata={"frozen": True}) + _store_mass_matrix_trace: bool = field(metadata={"frozen": True}) + + class QuadPotentialDiagAdapt(QuadPotential): """Adapt a diagonal mass matrix from the sample variances.""" + _state_class = QuadPotentialDiagAdaptState + def __init__( self, n, @@ -346,9 +383,20 @@ def raise_ok(self, map_info): raise ValueError("\n".join(errmsg)) -class _WeightedVariance: +@dataclass_state +class WeightedVarianceState(DataClassState): + n_samples: int + mean: np.ndarray + raw_var: np.ndarray + + _dtype: Any = field(metadata={"frozen": True}) + + +class _WeightedVariance(WithSamplingState): """Online algorithm for computing mean of variance.""" + _state_class = WeightedVarianceState + def __init__( self, nelem, initial_mean=None, initial_variance=None, initial_weight=0, dtype="d" ): @@ -390,7 +438,16 @@ def current_mean(self): return self.mean.copy(dtype=self._dtype) -class _ExpWeightedVariance: +@dataclass_state +class ExpWeightedVarianceState(DataClassState): + _alpha: float + _mean: np.ndarray + _var: np.ndarray + + +class _ExpWeightedVariance(WithSamplingState): + _state_class = ExpWeightedVarianceState + def __init__(self, n_vars, *, init_mean, init_var, alpha): self._variance = init_var self._mean = init_mean @@ -415,7 +472,18 @@ def current_mean(self, out=None): return out 
+@dataclass_state +class QuadPotentialDiagAdaptExpState(QuadPotentialDiagAdaptState): + _alpha: float + _stop_adaptation: float + _variance_estimator: ExpWeightedVarianceState + + _variance_estimator_grad: ExpWeightedVarianceState | None = None + + class QuadPotentialDiagAdaptExp(QuadPotentialDiagAdapt): + _state_class = QuadPotentialDiagAdaptExpState + def __init__(self, *args, alpha, use_grads=False, stop_adaptation=None, rng=None, **kwargs): """Set up a diagonal mass matrix. @@ -526,7 +594,7 @@ def __init__(self, v, dtype=None, rng=None): self.s = s self.inv_s = 1.0 / s self.v = v - self.rng = np.random.default_rng(rng) + self.rng = get_random_generator(rng) def velocity(self, x, out=None): """Compute the current velocity at a position in parameter space.""" @@ -572,7 +640,7 @@ def __init__(self, A, dtype=None, rng=None): dtype = pytensor.config.floatX self.dtype = dtype self.L = floatX(scipy.linalg.cholesky(A, lower=True)) - self.rng = np.random.default_rng(rng) + self.rng = get_random_generator(rng) def velocity(self, x, out=None): """Compute the current velocity at a position in parameter space.""" @@ -621,7 +689,7 @@ def __init__(self, cov, dtype=None, rng=None): self._cov = np.array(cov, dtype=self.dtype, copy=True) self._chol = scipy.linalg.cholesky(self._cov, lower=True) self._n = len(self._cov) - self.rng = np.random.default_rng(rng) + self.rng = get_random_generator(rng) def velocity(self, x, out=None): """Compute the current velocity at a position in parameter space.""" @@ -646,9 +714,31 @@ def velocity_energy(self, x, v_out): __call__ = random +@dataclass_state +class QuadPotentialFullAdaptState(PotentialState): + _previous_update: int + _cov: np.ndarray + _chol: np.ndarray + _chol_error: scipy.linalg.LinAlgError | ValueError | None = None + _foreground_cov: WeightedCovarianceState + _background_cov: WeightedCovarianceState + _n_samples: int + adaptation_window: int + + dtype: Any = field(metadata={"frozen": True}) + _n: int = field(metadata={"frozen": True}) + _update_window: int = field(metadata={"frozen": True}) + _initial_mean: np.ndarray = field(metadata={"frozen": True}) + _initial_cov: np.ndarray = field(metadata={"frozen": True}) + _initial_weight: np.ndarray = field(metadata={"frozen": True}) + adaptation_window_multiplier: float = field(metadata={"frozen": True}) + + class QuadPotentialFullAdapt(QuadPotentialFull): """Adapt a dense mass matrix using the sample covariances.""" + _state_class = QuadPotentialFullAdaptState + def __init__( self, n, @@ -689,7 +779,7 @@ def __init__( self.adaptation_window_multiplier = float(adaptation_window_multiplier) self._update_window = int(update_window) - self.rng = np.random.default_rng(rng) + self.rng = get_random_generator(rng) self.reset() @@ -742,7 +832,16 @@ def raise_ok(self, vmap): raise ValueError(str(self._chol_error)) -class _WeightedCovariance: +@dataclass_state +class WeightedCovarianceState(DataClassState): + n_samples: float + mean: np.ndarray + raw_cov: np.ndarray + + _dtype: Any = field(metadata={"frozen": True}) + + +class _WeightedCovariance(WithSamplingState): """Online algorithm for computing mean and covariance This implements the `Welford's algorithm @@ -752,6 +851,8 @@ class _WeightedCovariance: """ + _state_class = WeightedCovarianceState + def __init__( self, nelem, @@ -827,7 +928,7 @@ def __init__(self, A, rng=None): self.size = A.shape[0] self.factor = factor = cholmod.cholesky(A) self.d_sqrt = np.sqrt(factor.D()) - self.rng = np.random.default_rng(rng) + self.rng = get_random_generator(rng) def 
velocity(self, x): """Compute the current velocity at a position in parameter space.""" diff --git a/pymc/step_methods/step_sizes.py b/pymc/step_methods/step_sizes.py index 6c2b7340fd..c0fdb934a3 100644 --- a/pymc/step_methods/step_sizes.py +++ b/pymc/step_methods/step_sizes.py @@ -12,14 +12,33 @@ # See the License for the specific language governing permissions and # limitations under the License. + import numpy as np from scipy import stats from pymc.stats.convergence import SamplerWarning, WarningType +from pymc.step_methods.state import DataClassState, WithSamplingState, dataclass_state + + +@dataclass_state +class StepSizeState(DataClassState): + _log_step: np.ndarray + _log_bar: np.ndarray + _hbar: float + _count: int + _mu: np.ndarray + _tuned_stats: list + _initial_step: np.ndarray + _target: float + _k: float + _t0: float + _gamma: float + +class DualAverageAdaptation(WithSamplingState): + _state_class = StepSizeState -class DualAverageAdaptation: def __init__(self, initial_step, target, gamma, k, t0): self._initial_step = initial_step self._target = target
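
Taken together, the series makes a step method's progress snapshot-and-replayable end to end. A rough usage sketch of the workflow these patches enable (the toy model and seed are illustrative; the round trip mirrors what the tests above assert):

import numpy as np

import pymc as pm

with pm.Model() as model:
    x = pm.Normal("x")
    step = pm.Metropolis([x], rng=np.random.default_rng(42))

    point = model.initial_point()
    snapshot = step.sampling_state  # scaling, tuning counters, rng state, ...

    draw1, stats1 = step.step(point)

    step.sampling_state = snapshot  # rewind the sampler
    draw2, stats2 = step.step(point)  # replays the same proposal and decision

assert np.array_equal(draw1["x"], draw2["x"])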