Skip to content

Commit 84f6207

Browse files
Convert RandomVariables to in-place during graph optimization
1 parent 45801f5 commit 84f6207

File tree

7 files changed

+53
-27
lines changed

7 files changed

+53
-27
lines changed

pymc3/aesaraf.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131
import scipy.sparse as sps
3232

3333
from aesara import config, scalar
34+
from aesara.compile.mode import Mode, get_mode
3435
from aesara.gradient import grad
3536
from aesara.graph.basic import (
3637
Apply,
@@ -861,3 +862,16 @@ def take_along_axis(arr, indices, axis=0):
861862

862863
# use the fancy index
863864
return arr[_make_along_axis_idx(arr_shape, indices, _axis)]
865+
866+
867+
def compile_rv_inplace(inputs, outputs, mode=None, **kwargs):
    """Compile a graph via ``aesara.function`` with ``random_make_inplace`` forced on.

    Using this function ensures that compiled functions containing random
    variables will produce new samples on each call.

    Parameters
    ----------
    inputs
        Forwarded as the ``inputs`` argument of ``aesara.function``.
    outputs
        Forwarded as the ``outputs`` argument of ``aesara.function``.
    mode
        Anything ``get_mode`` accepts. Only its linker and its provided
        optimizer query are carried over into the mode used for compilation.
    **kwargs
        Passed through unchanged to ``aesara.function``.

    Returns
    -------
    The object returned by ``aesara.function``.
    """
    base_mode = get_mode(mode)
    # Extend the resolved mode's optimizer query so that it always includes
    # the ``random_make_inplace`` rewrite, whatever the caller asked for.
    inplace_query = base_mode.provided_optimizer.including("random_make_inplace")
    return aesara.function(
        inputs,
        outputs,
        mode=Mode(linker=base_mode.linker, optimizer=inplace_query),
        **kwargs,
    )

pymc3/distributions/distribution.py

Lines changed: 19 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -19,12 +19,12 @@
1919
import warnings
2020

2121
from abc import ABCMeta
22-
from copy import copy
2322
from typing import TYPE_CHECKING
2423

2524
import dill
2625

2726
from aesara.tensor.random.op import RandomVariable
27+
from aesara.tensor.random.var import RandomStateSharedVariable
2828

2929
from pymc3.distributions import _logcdf, _logp
3030

@@ -77,14 +77,6 @@ def _random(*args, **kwargs):
7777
rv_type = None
7878

7979
if isinstance(rv_op, RandomVariable):
80-
if not rv_op.inplace:
81-
# TODO: This is a temporary work-around.
82-
# Remove this once we know what we want regarding RNG states
83-
# and their propagation.
84-
rv_op = copy(rv_op)
85-
rv_op.inplace = True
86-
clsdict["rv_op"] = rv_op
87-
8880
rv_type = type(rv_op)
8981

9082
new_cls = super().__new__(cls, name, bases, clsdict)
@@ -158,15 +150,31 @@ def __new__(cls, name, *args, **kwargs):
158150
return model.register_rv(rv_out, name, data, total_size, dims=dims, transform=transform)
159151

160152
@classmethod
161-
def dist(cls, dist_params, **kwargs):
153+
def dist(cls, dist_params, rng=None, **kwargs):
162154

163155
testval = kwargs.pop("testval", None)
164156

165-
rv_var = cls.rv_op(*dist_params, **kwargs)
157+
rv_var = cls.rv_op(*dist_params, rng=rng, **kwargs)
166158

167159
if testval is not None:
168160
rv_var.tag.test_value = testval
169161

162+
if (
163+
rv_var.owner
164+
and isinstance(rv_var.owner.op, RandomVariable)
165+
and isinstance(rng, RandomStateSharedVariable)
166+
and not getattr(rng, "default_update", None)
167+
):
168+
# This tells `aesara.function` that the shared RNG variable
169+
# is mutable, which--in turn--tells the `FunctionGraph`
170+
# `Supervisor` feature to allow in-place updates on the variable.
171+
# Without it, the `RandomVariable`s could not be optimized to allow
172+
# in-place RNG updates, forcing all sample results from compiled
173+
# functions to be the same on repeated evaluations.
174+
new_rng = rv_var.owner.outputs[0]
175+
rv_var.update = (rng, new_rng)
176+
rng.default_update = new_rng
177+
170178
return rv_var
171179

172180
def _distr_parameters_for_repr(self):

pymc3/model.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@
4949

5050
from pymc3.aesaraf import (
5151
change_rv_size,
52+
compile_rv_inplace,
5253
gradient,
5354
hessian,
5455
inputvars,
@@ -455,7 +456,7 @@ def __init__(
455456

456457
inputs = grad_vars
457458

458-
self._aesara_function = aesara.function(inputs, outputs, givens=givens, **kwargs)
459+
self._aesara_function = compile_rv_inplace(inputs, outputs, givens=givens, **kwargs)
459460

460461
def set_weights(self, values):
461462
if values.shape != (self._n_costs - 1,):
@@ -1378,7 +1379,7 @@ def makefn(self, outs, mode=None, *args, **kwargs):
13781379
Compiled Aesara function
13791380
"""
13801381
with self:
1381-
return aesara.function(
1382+
return compile_rv_inplace(
13821383
self.value_vars,
13831384
outs,
13841385
allow_input_downcast=True,

pymc3/sampling.py

Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -25,19 +25,19 @@
2525
from copy import copy, deepcopy
2626
from typing import Any, Dict, Iterable, List, Optional, Set, Union, cast
2727

28-
import aesara
2928
import aesara.gradient as tg
3029
import numpy as np
3130
import packaging
3231
import xarray
3332

33+
from aesara.compile.mode import Mode
3434
from aesara.tensor.sharedvar import SharedVariable
3535
from arviz import InferenceData
3636
from fastprogress.fastprogress import progress_bar
3737

3838
import pymc3 as pm
3939

40-
from pymc3.aesaraf import change_rv_size, inputvars, walk_model
40+
from pymc3.aesaraf import change_rv_size, compile_rv_inplace, inputvars, walk_model
4141
from pymc3.backends.arviz import _DefaultTrace
4242
from pymc3.backends.base import BaseTrace, MultiTrace
4343
from pymc3.backends.ndarray import NDArray
@@ -1584,6 +1584,7 @@ def sample_posterior_predictive(
15841584
keep_size: Optional[bool] = False,
15851585
random_seed=None,
15861586
progressbar: bool = True,
1587+
mode: Optional[Union[str, Mode]] = None,
15871588
) -> Dict[str, np.ndarray]:
15881589
"""Generate posterior predictive samples from a model given a trace.
15891590
@@ -1617,6 +1618,8 @@ def sample_posterior_predictive(
16171618
Whether or not to display a progress bar in the command line. The bar shows the percentage
16181619
of completion, the sampling speed in samples per second (SPS), and the estimated remaining
16191620
time until completion ("expected time of arrival"; ETA).
1621+
mode:
1622+
The mode used by ``aesara.function`` to compile the graph.
16201623
16211624
Returns
16221625
-------
@@ -1727,12 +1730,13 @@ def sample_posterior_predictive(
17271730
if size is not None:
17281731
vars_to_sample = [change_rv_size(v, size, expand=True) for v in vars_to_sample]
17291732

1730-
sampler_fn = aesara.function(
1733+
sampler_fn = compile_rv_inplace(
17311734
inputs,
17321735
vars_to_sample,
17331736
allow_input_downcast=True,
17341737
accept_inplace=True,
17351738
on_unused_input="ignore",
1739+
mode=mode,
17361740
)
17371741

17381742
ppc_trace_t = _DefaultTrace(samples)
@@ -1992,12 +1996,11 @@ def sample_prior_predictive(
19921996

19931997
vars_to_sample = [model[name] for name in names]
19941998
inputs = [i for i in inputvars(vars_to_sample) if not isinstance(i, SharedVariable)]
1995-
sampler_fn = aesara.function(
1996-
inputs,
1997-
vars_to_sample,
1998-
allow_input_downcast=True,
1999-
accept_inplace=True,
1999+
2000+
sampler_fn = compile_rv_inplace(
2001+
inputs, vars_to_sample, allow_input_downcast=True, accept_inplace=True, mode=mode
20002002
)
2003+
20012004
values = zip(*[sampler_fn() for i in range(samples)])
20022005

20032006
data = {k: np.stack(v) for k, v in zip(names, values)}

pymc3/sampling_jax.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@
77
xla_flags = re.sub(r"xla_force_host_platform_device_count=.+\s", "", xla_flags).split()
88
os.environ["XLA_FLAGS"] = " ".join(["--xla_force_host_platform_device_count={}".format(100)])
99

10-
import aesara.graph.fg
1110
import aesara.tensor as at
1211
import arviz as az
1312
import jax
@@ -23,6 +22,7 @@
2322
from aesara.tensor.type import TensorType
2423

2524
from pymc3 import modelcontext
25+
from pymc3.aesaraf import compile_rv_inplace
2626

2727
warnings.warn("This module is experimental.")
2828

@@ -209,7 +209,7 @@ def sample_numpyro_nuts(
209209
print("Compiling...")
210210

211211
tic1 = pd.Timestamp.now()
212-
_sample = aesara.function(
212+
_sample = compile_rv_inplace(
213213
[],
214214
sample_outputs + [numpyro_samples[-1]],
215215
allow_input_downcast=True,

pymc3/step_methods/metropolis.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,6 @@
1313
# limitations under the License.
1414
from typing import Any, Callable, Dict, List, Tuple
1515

16-
import aesara
1716
import numpy as np
1817
import numpy.random as nr
1918
import scipy.linalg
@@ -23,7 +22,7 @@
2322

2423
import pymc3 as pm
2524

26-
from pymc3.aesaraf import floatX, rvs_to_value_vars
25+
from pymc3.aesaraf import compile_rv_inplace, floatX, rvs_to_value_vars
2726
from pymc3.blocking import DictToArrayBijection, RaveledVars
2827
from pymc3.step_methods.arraystep import (
2928
ArrayStep,
@@ -985,6 +984,6 @@ def delta_logp(point, logp, vars, shared):
985984

986985
logp1 = pm.CallableTensor(logp0)(inarray1)
987986

988-
f = aesara.function([inarray1, inarray0], logp1 - logp0)
987+
f = compile_rv_inplace([inarray1, inarray0], logp1 - logp0)
989988
f.trust_input = True
990989
return f

pymc3/tests/test_sampling.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131

3232
import pymc3 as pm
3333

34+
from pymc3.aesaraf import compile_rv_inplace
3435
from pymc3.backends.ndarray import NDArray
3536
from pymc3.exceptions import IncorrectArgumentsError, SamplingError
3637
from pymc3.tests.helpers import SeededTest
@@ -973,7 +974,7 @@ def test_layers(self):
973974
a = pm.Uniform("a", lower=0, upper=1, size=10)
974975
b = pm.Binomial("b", n=1, p=a, size=10)
975976

976-
b_sampler = aesara.function([], b)
977+
b_sampler = compile_rv_inplace([], b, mode="FAST_RUN")
977978
avg = np.stack([b_sampler() for i in range(10000)]).mean(0)
978979
npt.assert_array_almost_equal(avg, 0.5 * np.ones((10,)), decimal=2)
979980

0 commit comments

Comments
 (0)