diff --git a/pymc/aesaraf.py b/pymc/aesaraf.py
index b263a20b9b..e2de6d9fe2 100644
--- a/pymc/aesaraf.py
+++ b/pymc/aesaraf.py
@@ -30,6 +30,7 @@
 import numpy as np
 import scipy.sparse as sps
 
+from aeppl.abstract import MeasurableVariable
 from aeppl.logprob import CheckParameterValue
 from aesara import config, scalar
 from aesara.compile.mode import Mode, get_mode
@@ -978,14 +979,21 @@ def compile_pymc(
     # TODO: This won't work for variables with InnerGraphs (Scan and OpFromGraph)
     rng_updates = {}
     output_to_list = outputs if isinstance(outputs, (list, tuple)) else [outputs]
-    for rv in (
-        node
-        for node in vars_between(inputs, output_to_list)
-        if node.owner and isinstance(node.owner.op, RandomVariable) and node not in inputs
+    for random_var in (
+        var
+        for var in vars_between(inputs, output_to_list)
+        if var.owner
+        and isinstance(var.owner.op, (RandomVariable, MeasurableVariable))
+        and var not in inputs
     ):
-        rng = rv.owner.inputs[0]
-        if not hasattr(rng, "default_update"):
-            rng_updates[rng] = rv.owner.outputs[0]
+        if isinstance(random_var.owner.op, RandomVariable):
+            rng = random_var.owner.inputs[0]
+            if not hasattr(rng, "default_update"):
+                rng_updates[rng] = random_var.owner.outputs[0]
+        else:
+            update_fn = getattr(random_var.owner.op, "update", None)
+            if update_fn is not None:
+                rng_updates.update(update_fn(random_var.owner))
 
     # If called inside a model context, see if check_bounds flag is set to False
     try:
diff --git a/pymc/distributions/bound.py b/pymc/distributions/bound.py
index 90986be776..3df0ff7fda 100644
--- a/pymc/distributions/bound.py
+++ b/pymc/distributions/bound.py
@@ -23,7 +23,7 @@
 from pymc.distributions.continuous import BoundedContinuous, bounded_cont_transform
 from pymc.distributions.dist_math import check_parameters
 from pymc.distributions.distribution import Continuous, Discrete
-from pymc.distributions.logprob import logp
+from pymc.distributions.logprob import ignore_logprob, logp
 from pymc.distributions.shape_utils import to_tuple
 from pymc.distributions.transforms import _default_transform
 from pymc.model import modelcontext
@@ -193,7 +193,7 @@ def __new__(
             raise ValueError("Given dims do not exist in model coordinates.")
 
         lower, upper, initval = cls._set_values(lower, upper, size, shape, initval)
-        dist.tag.ignore_logprob = True
+        dist = ignore_logprob(dist)
 
         if isinstance(dist.owner.op, Continuous):
             res = _ContinuousBounded(
@@ -228,7 +228,7 @@ def dist(
         cls._argument_checks(dist, **kwargs)
 
         lower, upper, initval = cls._set_values(lower, upper, size, shape, initval=None)
-        dist.tag.ignore_logprob = True
+        dist = ignore_logprob(dist)
         if isinstance(dist.owner.op, Continuous):
             res = _ContinuousBounded.dist(
                 [dist, lower, upper],
diff --git a/pymc/distributions/censored.py b/pymc/distributions/censored.py
index c0790dee3e..8aaad5c106 100644
--- a/pymc/distributions/censored.py
+++ b/pymc/distributions/censored.py
@@ -42,9 +42,12 @@ class Censored(SymbolicDistribution):
 
     Parameters
     ----------
-    dist: PyMC unnamed distribution
-        PyMC distribution created via the `.dist()` API, which will be censored. This
-        distribution must be univariate and have a logcdf method implemented.
+    dist: unnamed distribution
+        Univariate distribution created via the `.dist()` API, which will be censored.
+        This distribution must have a logcdf method implemented for sampling.
+
+        .. warning:: dist will be cloned, rendering it independent of the one passed as input.
+
     lower: float or None
         Lower (left) censoring point. If `None` the distribution will not be left censored
     upper: float or None
diff --git a/pymc/distributions/logprob.py b/pymc/distributions/logprob.py
index bb76f444e4..526ebe2548 100644
--- a/pymc/distributions/logprob.py
+++ b/pymc/distributions/logprob.py
@@ -20,6 +20,7 @@
 import numpy as np
 
 from aeppl import factorized_joint_logprob
+from aeppl.abstract import assign_custom_measurable_outputs
 from aeppl.logprob import logcdf as logcdf_aeppl
 from aeppl.logprob import logprob as logp_aeppl
 from aeppl.transforms import TransformValuesOpt
@@ -221,7 +222,11 @@ def joint_logpt(
     transform_opt = TransformValuesOpt(transform_map)
 
     temp_logp_var_dict = factorized_joint_logprob(
-        tmp_rvs_to_values, extra_rewrites=transform_opt, use_jacobian=jacobian, **kwargs
+        tmp_rvs_to_values,
+        extra_rewrites=transform_opt,
+        use_jacobian=jacobian,
+        warn_missing_rvs=False,
+        **kwargs,
     )
 
     # Raise if there are unexpected RandomVariables in the logp graph
@@ -276,3 +281,20 @@ def logcdf(rv, value):
     value = at.as_tensor_variable(value, dtype=rv.dtype)
 
     return logcdf_aeppl(rv, value)
+
+
+def ignore_logprob(rv):
+    """Return a duplicated variable that is ignored when creating Aeppl logprob graphs
+
+    This is used in SymbolicDistributions that use other RVs as inputs but account
+    for their logp terms explicitly.
+
+    If the variable is already ignored, it is returned directly.
+    """
+    prefix = "Unmeasurable"
+    node = rv.owner
+    op_type = type(node.op)
+    if op_type.__name__.startswith(prefix):
+        return rv
+    new_node = assign_custom_measurable_outputs(node, type_prefix=prefix)
+    return new_node.outputs[node.outputs.index(rv)]
diff --git a/pymc/distributions/mixture.py b/pymc/distributions/mixture.py
index 0d51915a08..b613f90bac 100644
--- a/pymc/distributions/mixture.py
+++ b/pymc/distributions/mixture.py
@@ -21,7 +21,7 @@
 from aeppl.logprob import _logcdf, _logprob
 from aeppl.transforms import IntervalTransform
 from aesara.compile.builders import OpFromGraph
-from aesara.graph.basic import equal_computations
+from aesara.graph.basic import Node, equal_computations
 from aesara.tensor import TensorVariable
 from aesara.tensor.random.op import RandomVariable
 
@@ -30,7 +30,7 @@
 from pymc.distributions.continuous import Normal, get_tau_sigma
 from pymc.distributions.dist_math import check_parameters
 from pymc.distributions.distribution import SymbolicDistribution, _moment, moment
-from pymc.distributions.logprob import logcdf, logp
+from pymc.distributions.logprob import ignore_logprob, logcdf, logp
 from pymc.distributions.shape_utils import to_tuple
 from pymc.distributions.transforms import _default_transform
 from pymc.util import check_dist_not_registered
@@ -44,6 +44,10 @@ class MarginalMixtureRV(OpFromGraph):
 
     default_output = 1
 
+    def update(self, node: Node):
+        # Update for the internal mix_indexes RV
+        return {node.inputs[0]: node.outputs[0]}
+
 
 MeasurableVariable.register(MarginalMixtureRV)
 
@@ -66,12 +70,15 @@ class Mixture(SymbolicDistribution):
     w : tensor_like of float
         w >= 0 and w <= 1
         the mixture weights
-    comp_dists : iterable of PyMC distributions or single batched distribution
-        Distributions should be created via the `.dist()` API. If single distribution is
-        passed, the last size dimension (not shape) determines the number of mixture
+    comp_dists : iterable of unnamed distributions or single batched distribution
+        Distributions should be created via the `.dist()` API. If a single distribution
+        is passed, the last size dimension (not shape) determines the number of mixture
         components (e.g. `pm.Poisson.dist(..., size=components)`)
        :math:`f_1, \ldots, f_n`
 
+        .. warning:: comp_dists will be cloned, rendering them independent of the ones passed as input.
+
+
     Examples
     --------
     .. code-block:: python
@@ -249,6 +256,10 @@ def rv_op(cls, weights, *components, size=None, rngs=None):
 
         assert weights_ndim_batch == 0
 
+        # Component RV terms are accounted for by the Mixture logprob, so they can be
+        # safely ignored by Aeppl
+        components = [ignore_logprob(component) for component in components]
+
         # Create a OpFromGraph that encapsulates the random generating process
         # Create dummy input variables with the same type as the ones provided
         weights_ = weights.type()
@@ -287,20 +298,11 @@ def rv_op(cls, weights, *components, size=None, rngs=None):
         # Create the actual MarginalMixture variable
         mix_out = mix_op(mix_indexes_rng, weights, *components)
 
-        # We need to set_default_updates ourselves, because the choices RV is hidden
-        # inside OpFromGraph and PyMC will never find it otherwise
-        mix_indexes_rng.default_update = mix_out.owner.outputs[0]
-
         # Reference nodes to facilitate identification in other classmethods
         mix_out.tag.weights = weights
         mix_out.tag.components = components
         mix_out.tag.choices_rng = mix_indexes_rng
 
-        # Component RVs terms are accounted by the Mixture logprob, so they can be
-        # safely ignore by Aeppl (this tag prevents UserWarning)
-        for component in components:
-            component.tag.ignore_logprob = True
-
         return mix_out
 
     @classmethod
diff --git a/pymc/distributions/multivariate.py b/pymc/distributions/multivariate.py
index 0632720cb2..dae01e805e 100644
--- a/pymc/distributions/multivariate.py
+++ b/pymc/distributions/multivariate.py
@@ -57,6 +57,7 @@
     multigammaln,
 )
 from pymc.distributions.distribution import Continuous, Discrete, moment
+from pymc.distributions.logprob import ignore_logprob
 from pymc.distributions.shape_utils import (
     broadcast_dist_samples_to,
     rv_size_is_none,
@@ -1182,11 +1183,9 @@ def dist(cls, eta, n, sd_dist, **kwargs):
 
         # sd_dist is part of the generative graph, but should be completely ignored
         # by the logp graph, since the LKJ logp explicitly includes these terms.
-        # Setting sd_dist.tag.ignore_logprob to True, will prevent Aeppl warning about
-        # an unnacounted RandomVariable in the graph
        # TODO: Things could be simplified a bit if we managed to extract the
         # sd_dist prior components from the logp expression.
-        sd_dist.tag.ignore_logprob = True
+        sd_dist = ignore_logprob(sd_dist)
 
         return super().dist([n, eta, sd_dist], **kwargs)
 
@@ -1271,10 +1270,13 @@ class LKJCholeskyCov:
         larger values put more weight on matrices with few correlations.
     n: int
         Dimension of the covariance matrix (n > 1).
-    sd_dist: pm.Distribution
+    sd_dist: unnamed distribution
         A positive scalar or vector distribution for the standard deviations, created
         with the `.dist()` API. Should have `shape[-1]=n`. Scalar distributions will be
         automatically resized to ensure this.
+
+        .. warning:: sd_dist will be cloned, rendering it independent of the one passed as input.
+
     compute_corr: bool, default=True
         If `True`, returns three values: the Cholesky decomposition, the correlations
         and the standard deviations of the covariance matrix. Otherwise, only returns
diff --git a/pymc/distributions/simulator.py b/pymc/distributions/simulator.py
index 8d20bfd982..22f83f12f7 100644
--- a/pymc/distributions/simulator.py
+++ b/pymc/distributions/simulator.py
@@ -247,8 +247,6 @@ def logp(cls, value, sim_op, sim_inputs):
         # in which case this would not be needed. However, that would have to be
         # done for every sampler that may accomodate Simulators
         rng = aesara.shared(np.random.default_rng())
-        rng.tag.is_rng = True
-
         # Create a new simulatorRV with identical inputs as the original one
         sim_value = sim_op.make_node(rng, *sim_inputs[1:]).default_output()
         sim_value.name = "sim_value"
diff --git a/pymc/distributions/timeseries.py b/pymc/distributions/timeseries.py
index a8fdcedec7..cc4d744e5a 100644
--- a/pymc/distributions/timeseries.py
+++ b/pymc/distributions/timeseries.py
@@ -26,6 +26,7 @@
 from pymc.distributions import distribution, logprob, multivariate
 from pymc.distributions.continuous import Flat, Normal, get_tau_sigma
 from pymc.distributions.dist_math import check_parameters
+from pymc.distributions.logprob import ignore_logprob
 from pymc.distributions.shape_utils import rv_size_is_none, to_tuple
 from pymc.util import check_dist_not_registered
 
@@ -147,9 +148,12 @@ class GaussianRandomWalk(distribution.Continuous):
         innovation drift, defaults to 0.0
     sigma : tensor_like of float, optional
         sigma > 0, innovation standard deviation, defaults to 1.0
-    init : Univariate PyMC distribution
+    init : unnamed distribution
         Univariate distribution of the initial value, created with the `.dist()` API.
         Defaults to Normal with same `mu` and `sigma` as the GaussianRandomWalk
+
+        .. warning:: init will be cloned, rendering it independent of the one passed as input.
+
     steps : int
         Number of steps in Gaussian Random Walks (steps > 0).
     """
@@ -203,7 +207,7 @@ def dist(
             raise TypeError("init must be a univariate distribution variable")
 
         # Ignores logprob of init var because that's accounted for in the logp method
-        init.tag.ignore_logprob = True
+        init = ignore_logprob(init)
 
         return super().dist([mu, sigma, init, steps], size=size, **kwargs)
diff --git a/pymc/model.py b/pymc/model.py
index 26b8ed6600..d13091be69 100644
--- a/pymc/model.py
+++ b/pymc/model.py
@@ -1363,13 +1363,9 @@ def make_obs_var(
             # size of the masked and unmasked array happened to coincide
             _, size, _, *inps = observed_rv_var.owner.inputs
             rng = self.model.next_rng()
-            observed_rv_var = observed_rv_var.owner.op(*inps, size=size, rng=rng)
-            # Add default_update to new rng
-            new_rng = observed_rv_var.owner.outputs[0]
-            observed_rv_var.update = (rng, new_rng)
-            rng.default_update = new_rng
-            observed_rv_var.name = f"{name}_observed"
-
+            observed_rv_var = observed_rv_var.owner.op(
+                *inps, size=size, rng=rng, name=f"{name}_observed"
+            )
             observed_rv_var.tag.observations = nonmissing_data
 
             self.create_value_var(observed_rv_var, transform=None, value_var=nonmissing_data)
diff --git a/pymc/tests/test_aesaraf.py b/pymc/tests/test_aesaraf.py
index f1d6638286..709f21c6ef 100644
--- a/pymc/tests/test_aesaraf.py
+++ b/pymc/tests/test_aesaraf.py
@@ -22,7 +22,9 @@
 import pytest
 import scipy.sparse as sps
 
+from aeppl.abstract import MeasurableVariable
 from aeppl.logprob import ParameterValueError
+from aesara.compile.builders import OpFromGraph
 from aesara.graph.basic import Constant, Variable, ancestors, equal_computations
 from aesara.tensor.random.basic import normal, uniform
 from aesara.tensor.random.op import RandomVariable
@@ -596,91 +598,109 @@ def test_rvs_to_value_vars_nested():
         assert equal_computations(before, after)
 
 
-def test_check_bounds_flag():
-    """Test that CheckParameterValue Ops are replaced or removed when using compile_pymc"""
-    logp = at.ones(3)
-    cond = np.array([1, 0, 1])
-    bound = check_parameters(logp, cond)
-
-    with pm.Model() as m:
-        pass
-
-    with pytest.raises(ParameterValueError):
-        aesara.function([], bound)()
-
-    m.check_bounds = False
-    with m:
-        assert np.all(compile_pymc([], bound)() == 1)
-
-    m.check_bounds = True
-    with m:
-        assert np.all(compile_pymc([], bound)() == -np.inf)
-
-
-def test_compile_pymc_sets_rng_updates():
-    rng = aesara.shared(np.random.default_rng(0))
-    x = pm.Normal.dist(rng=rng)
-    assert x.owner.inputs[0] is rng
-    f = compile_pymc([], x)
-    assert not np.isclose(f(), f())
-
-    # Check that update was not done inplace
-    assert not hasattr(rng, "default_update")
-    f = aesara.function([], x)
-    assert f() == f()
-
-
-def test_compile_pymc_with_updates():
-    x = aesara.shared(0)
-    f = compile_pymc([], x, updates={x: x + 1})
-    assert f() == 0
-    assert f() == 1
-
-
-def test_compile_pymc_missing_default_explicit_updates():
-    rng = aesara.shared(np.random.default_rng(0))
-    x = pm.Normal.dist(rng=rng)
-
-    # By default, compile_pymc should update the rng of x
-    f = compile_pymc([], x)
-    assert f() != f()
-
-    # An explicit update should override the default_update, like aesara.function does
-    # For testing purposes, we use an update that leaves the rng unchanged
-    f = compile_pymc([], x, updates={rng: rng})
-    assert f() == f()
-
-    # If we specify a custom default_update directly it should use that instead.
-    rng.default_update = rng
-    f = compile_pymc([], x)
-    assert f() == f()
-
-    # And again, it should be overridden by an explicit update
-    f = compile_pymc([], x, updates={rng: x.owner.outputs[0]})
-    assert f() != f()
-
-
-def test_compile_pymc_updates_inputs():
-    """Test that compile_pymc does not include rngs updates of variables that are inputs
-    or ancestors to inputs
-    """
-    x = at.random.normal()
-    y = at.random.normal(x)
-    z = at.random.normal(y)
-
-    for inputs, rvs_in_graph in (
-        ([], 3),
-        ([x], 2),
-        ([y], 1),
-        ([z], 0),
-        ([x, y], 1),
-        ([x, y, z], 0),
-    ):
-        fn = compile_pymc(inputs, z, on_unused_input="ignore")
-        fn_fgraph = fn.maker.fgraph
-        # Each RV adds a shared input for its rng
-        assert len(fn_fgraph.inputs) == len(inputs) + rvs_in_graph
-        # If the output is an input, the graph has a DeepCopyOp
-        assert len(fn_fgraph.apply_nodes) == max(rvs_in_graph, 1)
-        # Each RV adds a shared output for its rng
-        assert len(fn_fgraph.outputs) == 1 + rvs_in_graph
+class TestCompilePyMC:
+    def test_check_bounds_flag(self):
+        """Test that CheckParameterValue Ops are replaced or removed when using compile_pymc"""
+        logp = at.ones(3)
+        cond = np.array([1, 0, 1])
+        bound = check_parameters(logp, cond)
+
+        with pm.Model() as m:
+            pass
+
+        with pytest.raises(ParameterValueError):
+            aesara.function([], bound)()
+
+        m.check_bounds = False
+        with m:
+            assert np.all(compile_pymc([], bound)() == 1)
+
+        m.check_bounds = True
+        with m:
+            assert np.all(compile_pymc([], bound)() == -np.inf)
+
+    def test_compile_pymc_sets_rng_updates(self):
+        rng = aesara.shared(np.random.default_rng(0))
+        x = pm.Normal.dist(rng=rng)
+        assert x.owner.inputs[0] is rng
+        f = compile_pymc([], x)
+        assert not np.isclose(f(), f())
+
+        # Check that update was not done inplace
+        assert not hasattr(rng, "default_update")
+        f = aesara.function([], x)
+        assert f() == f()
+
+    def test_compile_pymc_with_updates(self):
+        x = aesara.shared(0)
+        f = compile_pymc([], x, updates={x: x + 1})
+        assert f() == 0
+        assert f() == 1
+
+    def test_compile_pymc_missing_default_explicit_updates(self):
+        rng = aesara.shared(np.random.default_rng(0))
+        x = pm.Normal.dist(rng=rng)
+
+        # By default, compile_pymc should update the rng of x
+        f = compile_pymc([], x)
+        assert f() != f()
+
+        # An explicit update should override the default_update, like aesara.function does
+        # For testing purposes, we use an update that leaves the rng unchanged
+        f = compile_pymc([], x, updates={rng: rng})
+        assert f() == f()
+
+        # If we specify a custom default_update directly it should use that instead.
+        rng.default_update = rng
+        f = compile_pymc([], x)
+        assert f() == f()
+
+        # And again, it should be overridden by an explicit update
+        f = compile_pymc([], x, updates={rng: x.owner.outputs[0]})
+        assert f() != f()
+
+    def test_compile_pymc_updates_inputs(self):
+        """Test that compile_pymc does not include rngs updates of variables that are inputs
+        or ancestors to inputs
+        """
+        x = at.random.normal()
+        y = at.random.normal(x)
+        z = at.random.normal(y)
+
+        for inputs, rvs_in_graph in (
+            ([], 3),
+            ([x], 2),
+            ([y], 1),
+            ([z], 0),
+            ([x, y], 1),
+            ([x, y, z], 0),
+        ):
+            fn = compile_pymc(inputs, z, on_unused_input="ignore")
+            fn_fgraph = fn.maker.fgraph
+            # Each RV adds a shared input for its rng
+            assert len(fn_fgraph.inputs) == len(inputs) + rvs_in_graph
+            # If the output is an input, the graph has a DeepCopyOp
+            assert len(fn_fgraph.apply_nodes) == max(rvs_in_graph, 1)
+            # Each RV adds a shared output for its rng
+            assert len(fn_fgraph.outputs) == 1 + rvs_in_graph
+
+    def test_compile_pymc_custom_update_op(self):
+        """Test that custom MeasurableVariable Op updates are used by compile_pymc"""
+
+        class UnmeasurableOp(OpFromGraph):
+            def update(self, node):
+                return {node.inputs[0]: node.inputs[0] + 1}
+
+        dummy_inputs = [at.scalar(), at.scalar()]
+        dummy_outputs = [at.add(*dummy_inputs)]
+        dummy_x = UnmeasurableOp(dummy_inputs, dummy_outputs)(aesara.shared(1.0), 1.0)
+
+        # Check that there are no updates at first
+        fn = compile_pymc(inputs=[], outputs=dummy_x)
+        assert fn() == fn() == 2.0
+
+        # And they are enabled once the Op is registered as Measurable
+        MeasurableVariable.register(UnmeasurableOp)
+        fn = compile_pymc(inputs=[], outputs=dummy_x)
+        assert fn() == 2.0
+        assert fn() == 3.0
diff --git a/pymc/tests/test_logprob.py b/pymc/tests/test_logprob.py
index f6ff39a2cd..74952c86a3 100644
--- a/pymc/tests/test_logprob.py
+++ b/pymc/tests/test_logprob.py
@@ -17,6 +17,7 @@
 import pytest
 import scipy.stats.distributions as sp
 
+from aeppl.abstract import get_measurable_outputs
 from aesara.graph.basic import ancestors
 from aesara.tensor.random.op import RandomVariable
 from aesara.tensor.subtensor import (
@@ -32,7 +33,7 @@
 from pymc.aesaraf import floatX, walk_model
 from pymc.distributions.continuous import HalfFlat, Normal, TruncatedNormal, Uniform
 from pymc.distributions.discrete import Bernoulli
-from pymc.distributions.logprob import joint_logpt, logcdf, logp
+from pymc.distributions.logprob import ignore_logprob, joint_logpt, logcdf, logp
 from pymc.model import Model, Potential
 from pymc.tests.helpers import select_by_precision
 
@@ -227,3 +228,38 @@ def test_unexpected_rvs():
 
     with pytest.raises(ValueError, match="^Random variables detected in the logp graph"):
         model.logpt()
+
+
+def test_ignore_logprob_basic():
+    x = Normal.dist()
+    (measurable_x_out,) = get_measurable_outputs(x.owner.op, x.owner)
+    assert measurable_x_out is x.owner.outputs[1]
+
+    new_x = ignore_logprob(x)
+    assert new_x is not x
+    assert isinstance(new_x.owner.op, Normal)
+    assert type(new_x.owner.op).__name__ == "UnmeasurableNormalRV"
+    # Confirm that it does not have measurable output
+    assert get_measurable_outputs(new_x.owner.op, new_x.owner) is None
+
+    # Test that it will not clone a variable that is already unmeasurable
+    new_new_x = ignore_logprob(new_x)
+    assert new_new_x is new_x
+
+
+def test_ignore_logprob_model():
+    # logp that does not depend on input
+    def logp(value, x):
+        return value
+
+    with Model() as m:
+        x = Normal.dist()
+        y = DensityDist("y", x, logp=logp)
+    # Aeppl raises a KeyError when it finds an unexpected RV
+    with pytest.raises(KeyError):
+        joint_logpt([y], {y: y.type()})
+
+    with Model() as m:
+        x = ignore_logprob(Normal.dist())
+        y = DensityDist("y", x, logp=logp)
+    assert joint_logpt([y], {y: y.type()})
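
A minimal sketch of how the ignore_logprob helper introduced in pymc/distributions/logprob.py
is expected to behave, based on the new test_ignore_logprob_basic test above; this is an
illustration, not part of the patch:

    import pymc as pm
    from pymc.distributions.logprob import ignore_logprob

    # A regular unnamed RV is measurable: Aeppl will include its logp term.
    x = pm.Normal.dist()

    # ignore_logprob returns a clone whose Op type is prefixed with "Unmeasurable",
    # so Aeppl skips it when building logprob graphs.
    new_x = ignore_logprob(x)
    assert new_x is not x
    assert type(new_x.owner.op).__name__ == "UnmeasurableNormalRV"

    # Calling it again is a no-op: already-unmeasurable variables are returned as-is.
    assert ignore_logprob(new_x) is new_x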
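
Similarly, a sketch of the update protocol that compile_pymc now honors for MeasurableVariable
Ops, modeled on the new test_compile_pymc_custom_update_op; CountingOp and counter are
hypothetical names used only for illustration:

    import aesara
    import aesara.tensor as at
    from aeppl.abstract import MeasurableVariable
    from aesara.compile.builders import OpFromGraph
    from pymc.aesaraf import compile_pymc

    class CountingOp(OpFromGraph):
        def update(self, node):
            # Tell compile_pymc to advance the first (shared) input on every call
            return {node.inputs[0]: node.inputs[0] + 1}

    counter = aesara.shared(0.0)
    inputs = [at.scalar(), at.scalar()]
    out = CountingOp(inputs, [at.add(*inputs)])(counter, 1.0)

    # Updates are only collected once the Op is registered as a MeasurableVariable
    MeasurableVariable.register(CountingOp)
    fn = compile_pymc(inputs=[], outputs=out)
    assert fn() == 1.0
    assert fn() == 2.0  # the shared counter advanced between calls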
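
The .. warning:: notes added to the Censored, Mixture, LKJCholeskyCov, and GaussianRandomWalk
docstrings all describe the same behaviour: the input distribution is cloned inside the new
variable. A short illustration with hypothetical variable names:

    import pymc as pm

    dist = pm.Normal.dist()
    censored = pm.Censored.dist(dist, lower=-1, upper=1)
    # `dist` was cloned inside Censored.dist, so the variable named `dist` here
    # remains independent of the component used internally by `censored`.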