Skip to content

Commit 45801f5

Browse files
Incrementally update the RNG state in Model
1 parent d95827f commit 45801f5

File tree

7 files changed

+118
-53
lines changed

7 files changed

+118
-53
lines changed

pymc3/distributions/distribution.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -137,7 +137,7 @@ def __new__(cls, name, *args, **kwargs):
137137
rng = kwargs.pop("rng", None)
138138

139139
if rng is None:
140-
rng = model.default_rng
140+
rng = model.next_rng()
141141

142142
if not isinstance(name, string_types):
143143
raise TypeError(f"Name needs to be a string but got: {name}")

pymc3/model.py

Lines changed: 40 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@
4242
from aesara.graph.basic import Constant, Variable, graph_inputs
4343
from aesara.graph.fg import FunctionGraph, MissingInputError
4444
from aesara.tensor.random.opt import local_subtensor_rv_lift
45+
from aesara.tensor.random.var import RandomStateSharedVariable
4546
from aesara.tensor.sharedvar import ScalarSharedVariable
4647
from aesara.tensor.var import TensorVariable
4748
from pandas import Series
@@ -532,6 +533,13 @@ class Model(Factor, WithMemoization, metaclass=ContextMeta):
532533
parameters can only take on valid values you can set this to
533534
False for increased speed. This should not be used if your model
534535
contains discrete variables.
536+
rng_seeder: int or numpy.random.RandomState
537+
The ``numpy.random.RandomState`` used to seed the
538+
``RandomStateSharedVariable`` sequence used by a model's
539+
``RandomVariable``s, or an int used to seed a new
540+
``numpy.random.RandomState``. If ``None``, a
541+
new ``numpy.random.RandomState`` will be generated and used. Incremental
542+
access to the state sequence is provided by ``Model.next_rng``.
535543
536544
Examples
537545
--------
@@ -615,17 +623,31 @@ def __new__(cls, *args, **kwargs):
615623
instance._aesara_config = kwargs.get("aesara_config", {})
616624
return instance
617625

618-
def __init__(self, name="", model=None, aesara_config=None, coords=None, check_bounds=True):
626+
def __init__(
627+
self,
628+
name="",
629+
model=None,
630+
aesara_config=None,
631+
coords=None,
632+
check_bounds=True,
633+
rng_seeder: Optional[Union[int, np.random.RandomState]] = None,
634+
):
619635
self.name = name
620636
self._coords = {}
621637
self._RV_dims = {}
622638
self._dim_lengths = {}
623639
self.add_coords(coords)
624640
self.check_bounds = check_bounds
625641

626-
self.default_rng = aesara.shared(np.random.RandomState(), name="default_rng", borrow=True)
627-
self.default_rng.tag.is_rng = True
628-
self.default_rng.default_update = self.default_rng
642+
if rng_seeder is None:
643+
self.rng_seeder = np.random.RandomState()
644+
elif isinstance(rng_seeder, int):
645+
self.rng_seeder = np.random.RandomState(rng_seeder)
646+
else:
647+
self.rng_seeder = rng_seeder
648+
649+
# The sequence of model-generated RNGs
650+
self.rng_seq = []
629651

630652
if self.parent is not None:
631653
self.named_vars = treedict(parent=self.parent.named_vars)
@@ -931,6 +953,20 @@ def cont_vars(self):
931953
"""All the continuous variables in the model"""
932954
return list(typefilter(self.value_vars, continuous_types))
933955

956+
def next_rng(self) -> RandomStateSharedVariable:
957+
"""Generate a new ``RandomStateSharedVariable``.
958+
959+
The new ``RandomStateSharedVariable`` is also added to
960+
``Model.rng_seq``.
961+
"""
962+
new_seed = self.rng_seeder.randint(2 ** 30, dtype=np.int64)
963+
next_rng = aesara.shared(np.random.RandomState(new_seed), borrow=True)
964+
next_rng.tag.is_rng = True
965+
966+
self.rng_seq.append(next_rng)
967+
968+
return next_rng
969+
934970
def shape_from_dims(self, dims):
935971
shape = []
936972
if len(set(dims)) != len(dims):

pymc3/sampling.py

Lines changed: 27 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -339,7 +339,8 @@ def sample(
339339
time until completion ("expected time of arrival"; ETA).
340340
model : Model (optional if in ``with`` context)
341341
random_seed : int or list of ints
342-
A list is accepted if ``cores`` is greater than one.
342+
Random seed(s) used by the sampling steps. A list is accepted if
343+
``cores`` is greater than one.
343344
discard_tuned_samples : bool
344345
Whether to discard posterior samples of the tune interval.
345346
compute_convergence_checks : bool, default=True
@@ -467,10 +468,6 @@ def sample(
467468
np.random.seed(random_seed)
468469
random_seed = [np.random.randint(2 ** 30) for _ in range(chains)]
469470

470-
# TODO: We need to do something about multiple seeds and this single,
471-
# shared RNG state.
472-
model.default_rng.get_value(borrow=True).seed(random_seed)
473-
474471
if not isinstance(random_seed, abc.Iterable):
475472
raise TypeError("Invalid value for `random_seed`. Must be tuple, list or int")
476473

@@ -1004,9 +1001,7 @@ def _iter_sample(
10041001
"""
10051002
model = modelcontext(model)
10061003
draws = int(draws)
1007-
if random_seed is not None:
1008-
# np.random.seed(random_seed)
1009-
model.default_rng.get_value(borrow=True).seed(random_seed)
1004+
10101005
if draws < 1:
10111006
raise ValueError("Argument `draws` must be greater than 0.")
10121007

@@ -1273,9 +1268,7 @@ def _prepare_iter_population(
12731268
nchains = len(chains)
12741269
model = modelcontext(model)
12751270
draws = int(draws)
1276-
if random_seed is not None:
1277-
# np.random.seed(random_seed)
1278-
model.default_rng.get_value(borrow=True).seed(random_seed)
1271+
12791272
if draws < 1:
12801273
raise ValueError("Argument `draws` should be above 0.")
12811274

@@ -1693,8 +1686,12 @@ def sample_posterior_predictive(
16931686
vars_ = model.observed_RVs + model.auto_deterministics
16941687

16951688
if random_seed is not None:
1696-
# np.random.seed(random_seed)
1697-
model.default_rng.get_value(borrow=True).seed(random_seed)
1689+
warnings.warn(
1690+
"In this version, RNG seeding is managed by the Model objects. "
1691+
"See the `rng_seeder` argument in Model's constructor.",
1692+
DeprecationWarning,
1693+
stacklevel=2,
1694+
)
16981695

16991696
indices = np.arange(samples)
17001697

@@ -1816,8 +1813,6 @@ def sample_posterior_predictive_w(
18161813
Dictionary with the variables as keys. The values corresponding to the
18171814
posterior predictive samples from the weighted models.
18181815
"""
1819-
# np.random.seed(random_seed)
1820-
18211816
if isinstance(traces[0], InferenceData):
18221817
n_samples = [
18231818
trace.posterior.sizes["chain"] * trace.posterior.sizes["draw"] for trace in traces
@@ -1832,9 +1827,15 @@ def sample_posterior_predictive_w(
18321827
if models is None:
18331828
models = [modelcontext(models)] * len(traces)
18341829

1830+
if random_seed:
1831+
warnings.warn(
1832+
"In this version, RNG seeding is managed by the Model objects. "
1833+
"See the `rng_seeder` argument in Model's constructor.",
1834+
DeprecationWarning,
1835+
stacklevel=2,
1836+
)
1837+
18351838
for model in models:
1836-
if random_seed:
1837-
model.default_rng.get_value(borrow=True).seed(random_seed)
18381839
if model.potentials:
18391840
warnings.warn(
18401841
"The effect of Potentials on other parameters is ignored during posterior predictive sampling. "
@@ -1937,6 +1938,7 @@ def sample_prior_predictive(
19371938
model: Optional[Model] = None,
19381939
var_names: Optional[Iterable[str]] = None,
19391940
random_seed=None,
1941+
mode: Optional[Union[str, Mode]] = None,
19401942
) -> Dict[str, np.ndarray]:
19411943
"""Generate samples from the prior predictive distribution.
19421944
@@ -1950,6 +1952,8 @@ def sample_prior_predictive(
19501952
samples. Defaults to both observed and unobserved RVs.
19511953
random_seed : int
19521954
Seed for the random number generator.
1955+
mode : str or Mode, optional
1956+
The mode used by ``aesara.function`` to compile the graph.
19531957
19541958
Returns
19551959
-------
@@ -1977,8 +1981,12 @@ def sample_prior_predictive(
19771981
vars_ = set(var_names)
19781982

19791983
if random_seed is not None:
1980-
# np.random.seed(random_seed)
1981-
model.default_rng.get_value(borrow=True).seed(random_seed)
1984+
warnings.warn(
1985+
"In this version, RNG seeding is managed by the Model objects. "
1986+
"See the `rng_seeder` argument in Model's constructor.",
1987+
DeprecationWarning,
1988+
stacklevel=2,
1989+
)
19821990

19831991
names = get_default_varnames(vars_, include_transformed=False)
19841992

@@ -2127,8 +2135,6 @@ def init_nuts(
21272135

21282136
if random_seed is not None:
21292137
random_seed = int(np.atleast_1d(random_seed)[0])
2130-
# np.random.seed(random_seed)
2131-
model.default_rng.get_value(borrow=True).seed(random_seed)
21322138

21332139
cb = [
21342140
pm.callbacks.CheckParametersConvergence(tolerance=1e-2, diff="absolute"),

pymc3/tests/test_distributions.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2944,3 +2944,25 @@ def func(x):
29442944
import pickle
29452945

29462946
pickle.loads(pickle.dumps(y))
2947+
2948+
2949+
def test_distinct_rvs():
2950+
"""Make sure `RandomVariable`s generated using a `Model`'s default RNG state all have distinct states."""
2951+
2952+
with pm.Model(rng_seeder=np.random.RandomState(2023532)) as model:
2953+
X_rv = pm.Normal("x")
2954+
Y_rv = pm.Normal("y")
2955+
2956+
pp_samples = pm.sample_prior_predictive(samples=2)
2957+
2958+
assert X_rv.owner.inputs[0] != Y_rv.owner.inputs[0]
2959+
2960+
assert len(model.rng_seq) == 2
2961+
2962+
with pm.Model(rng_seeder=np.random.RandomState(2023532)):
2963+
X_rv = pm.Normal("x")
2964+
Y_rv = pm.Normal("y")
2965+
2966+
pp_samples_2 = pm.sample_prior_predictive(samples=2)
2967+
2968+
assert np.array_equal(pp_samples["y"], pp_samples_2["y"])

pymc3/tests/test_ndarray_backend.py

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -209,8 +209,8 @@ def test_combine_true_squeeze_true(self):
209209

210210
class TestSaveLoad:
211211
@staticmethod
212-
def model():
213-
with pm.Model() as model:
212+
def model(rng_seeder=None):
213+
with pm.Model(rng_seeder=rng_seeder) as model:
214214
x = pm.Normal("x", 0, 1)
215215
y = pm.Normal("y", x, 1, observed=2)
216216
z = pm.Normal("z", x + y, 1)
@@ -267,15 +267,16 @@ def test_sample_posterior_predictive(self, tmpdir_factory):
267267

268268
assert save_dir == directory
269269

270-
with TestSaveLoad.model() as model:
271-
model.default_rng.get_value(borrow=True).seed(10)
270+
rng = np.random.RandomState(10)
271+
272+
with TestSaveLoad.model(rng_seeder=rng):
272273
ppc = pm.sample_posterior_predictive(self.trace)
273274

274-
with TestSaveLoad.model() as model:
275+
rng = np.random.RandomState(10)
276+
277+
with TestSaveLoad.model(rng_seeder=rng):
275278
trace2 = pm.load_trace(directory)
276-
model.default_rng.get_value(borrow=True).seed(10)
277279
ppc2 = pm.sample_posterior_predictive(trace2)
278-
ppc2f = pm.sample_posterior_predictive(trace2)
279280

280281
for key, value in ppc.items():
281282
assert (value == ppc2[key]).all()

pymc3/tests/test_sampling.py

Lines changed: 13 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -619,11 +619,13 @@ def test_model_not_drawable_prior(self):
619619
assert samples["foo"].shape == (40, 200)
620620

621621
def test_model_shared_variable(self):
622-
x = np.random.randn(100)
622+
rng = np.random.RandomState(9832)
623+
624+
x = rng.randn(100)
623625
y = x > 0
624626
x_shared = aesara.shared(x)
625627
y_shared = aesara.shared(y)
626-
with pm.Model() as model:
628+
with pm.Model(rng_seeder=rng) as model:
627629
coeff = pm.Normal("x", mu=0, sd=1)
628630
logistic = pm.Deterministic("p", pm.math.sigmoid(coeff * x_shared))
629631

@@ -644,12 +646,12 @@ def test_model_shared_variable(self):
644646
npt.assert_allclose(post_pred["p"], expected_p)
645647

646648
def test_deterministic_of_observed(self):
647-
np.random.seed(8442)
649+
rng = np.random.RandomState(8442)
648650

649-
meas_in_1 = pm.aesaraf.floatX(2 + 4 * np.random.randn(10))
650-
meas_in_2 = pm.aesaraf.floatX(5 + 4 * np.random.randn(10))
651+
meas_in_1 = pm.aesaraf.floatX(2 + 4 * rng.randn(10))
652+
meas_in_2 = pm.aesaraf.floatX(5 + 4 * rng.randn(10))
651653
nchains = 2
652-
with pm.Model() as model:
654+
with pm.Model(rng_seeder=rng) as model:
653655
mu_in_1 = pm.Normal("mu_in_1", 0, 1)
654656
sigma_in_1 = pm.HalfNormal("sd_in_1", 1)
655657
mu_in_2 = pm.Normal("mu_in_2", 0, 1)
@@ -660,7 +662,6 @@ def test_deterministic_of_observed(self):
660662
out_diff = in_1 + in_2
661663
pm.Deterministic("out", out_diff)
662664

663-
model.default_rng.get_value(borrow=True).seed(0)
664665
trace = pm.sample(
665666
100,
666667
chains=nchains,
@@ -670,7 +671,6 @@ def test_deterministic_of_observed(self):
670671

671672
rtol = 1e-5 if aesara.config.floatX == "float64" else 1e-4
672673

673-
model.default_rng.get_value(borrow=True).seed(0)
674674
ppc = pm.sample_posterior_predictive(
675675
model=model,
676676
trace=trace,
@@ -682,11 +682,11 @@ def test_deterministic_of_observed(self):
682682
npt.assert_allclose(ppc["in_1"] + ppc["in_2"], ppc["out"], rtol=rtol)
683683

684684
def test_deterministic_of_observed_modified_interface(self):
685-
np.random.seed(4982)
685+
rng = np.random.RandomState(4982)
686686

687-
meas_in_1 = pm.aesaraf.floatX(2 + 4 * np.random.randn(100))
688-
meas_in_2 = pm.aesaraf.floatX(5 + 4 * np.random.randn(100))
689-
with pm.Model() as model:
687+
meas_in_1 = pm.aesaraf.floatX(2 + 4 * rng.randn(100))
688+
meas_in_2 = pm.aesaraf.floatX(5 + 4 * rng.randn(100))
689+
with pm.Model(rng_seeder=rng) as model:
690690
mu_in_1 = pm.Normal("mu_in_1", 0, 1, testval=0)
691691
sigma_in_1 = pm.HalfNormal("sd_in_1", 1, testval=1)
692692
mu_in_2 = pm.Normal("mu_in_2", 0, 1, testval=0)
@@ -969,12 +969,10 @@ def test_multivariate2(self):
969969
assert sim_ppc["obs"].shape == (20,) + mn_data.shape
970970

971971
def test_layers(self):
972-
with pm.Model() as model:
972+
with pm.Model(rng_seeder=232093) as model:
973973
a = pm.Uniform("a", lower=0, upper=1, size=10)
974974
b = pm.Binomial("b", n=1, p=a, size=10)
975975

976-
model.default_rng.get_value(borrow=True).seed(232093)
977-
978976
b_sampler = aesara.function([], b)
979977
avg = np.stack([b_sampler() for i in range(10000)]).mean(0)
980978
npt.assert_array_almost_equal(avg, 0.5 * np.ones((10,)), decimal=2)

pymc3/tests/test_step.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1657,7 +1657,9 @@ def perform(self, node, inputs, outputs):
16571657
mout = []
16581658
coarse_models = []
16591659

1660-
with Model() as coarse_model_0:
1660+
rng = np.random.RandomState(seed)
1661+
1662+
with Model(rng_seeder=rng) as coarse_model_0:
16611663
if aesara.config.floatX == "float32":
16621664
Q = Data("Q", np.float32(0.0))
16631665
else:
@@ -1674,9 +1676,9 @@ def perform(self, node, inputs, outputs):
16741676

16751677
coarse_models.append(coarse_model_0)
16761678

1677-
coarse_model_0.default_rng.get_value(borrow=True).seed(seed)
1679+
rng = np.random.RandomState(seed)
16781680

1679-
with Model() as coarse_model_1:
1681+
with Model(rng_seeder=rng) as coarse_model_1:
16801682
if aesara.config.floatX == "float32":
16811683
Q = Data("Q", np.float32(0.0))
16821684
else:
@@ -1693,9 +1695,9 @@ def perform(self, node, inputs, outputs):
16931695

16941696
coarse_models.append(coarse_model_1)
16951697

1696-
coarse_model_1.default_rng.get_value(borrow=True).seed(seed)
1698+
rng = np.random.RandomState(seed)
16971699

1698-
with Model() as model:
1700+
with Model(rng_seeder=rng) as model:
16991701
if aesara.config.floatX == "float32":
17001702
Q = Data("Q", np.float32(0.0))
17011703
else:

0 commit comments

Comments
 (0)