diff --git a/pymc/aesaraf.py b/pymc/aesaraf.py index a10301fd8f..37aba7402e 100644 --- a/pymc/aesaraf.py +++ b/pymc/aesaraf.py @@ -11,6 +11,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import warnings + from typing import ( Callable, Dict, @@ -31,6 +33,7 @@ import scipy.sparse as sps from aeppl.logprob import CheckParameterValue +from aeppl.transforms import RVTransform from aesara import scalar from aesara.compile.mode import Mode, get_mode from aesara.gradient import grad @@ -205,10 +208,9 @@ def expand(var): yield from walk(graphs, expand, bfs=False) -def replace_rvs_in_graphs( +def _replace_rvs_in_graphs( graphs: Iterable[TensorVariable], replacement_fn: Callable[[TensorVariable], Dict[TensorVariable, TensorVariable]], - initial_replacements: Optional[Dict[TensorVariable, TensorVariable]] = None, **kwargs, ) -> Tuple[List[TensorVariable], Dict[TensorVariable, TensorVariable]]: """Replace random variables in graphs @@ -226,8 +228,6 @@ def replace_rvs_in_graphs( that were made. """ replacements = {} - if initial_replacements: - replacements.update(initial_replacements) def expand_replace(var): new_nodes = [] @@ -239,6 +239,7 @@ def expand_replace(var): new_nodes.extend(replacement_fn(var, replacements)) return new_nodes + # This iteration populates the replacements for var in walk_model(graphs, expand_fn=expand_replace, **kwargs): pass @@ -253,7 +254,15 @@ def expand_replace(var): clone=False, ) - fg.replace_all(replacements.items(), import_missing=True) + # replacements have to be done in reverse topological order so that nested + # expressions get recursively replaced correctly + toposort = fg.toposort() + sorted_replacements = sorted( + tuple(replacements.items()), + key=lambda pair: toposort.index(pair[0].owner), + reverse=True, + ) + fg.replace_all(sorted_replacements, import_missing=True) graphs = list(fg.outputs) @@ -263,7 +272,6 @@ def expand_replace(var): def rvs_to_value_vars( graphs: Iterable[Variable], apply_transforms: bool = True, - initial_replacements: Optional[Dict[Variable, Variable]] = None, **kwargs, ) -> List[Variable]: """Clone and replace random variables in graphs with their value variables. @@ -276,10 +284,11 @@ def rvs_to_value_vars( The graphs in which to perform the replacements. apply_transforms If ``True``, apply each value variable's transform. - initial_replacements - A ``dict`` containing the initial replacements to be made. - """ + warnings.warn( + "rvs_to_value_vars is deprecated. Use model.replace_rvs_by_values instead", + FutureWarning, + ) def populate_replacements( random_var: TensorVariable, replacements: Dict[TensorVariable, TensorVariable] @@ -311,15 +320,72 @@ def populate_replacements( equiv = clone_get_equiv(inputs, graphs, False, False, {}) graphs = [equiv[n] for n in graphs] - if initial_replacements: - initial_replacements = { - equiv.get(k, k): equiv.get(v, v) for k, v in initial_replacements.items() - } - - graphs, _ = replace_rvs_in_graphs( + graphs, _ = _replace_rvs_in_graphs( graphs, replacement_fn=populate_replacements, - initial_replacements=initial_replacements, + **kwargs, + ) + + return graphs + + +def replace_rvs_by_values( + graphs: Sequence[TensorVariable], + *, + rvs_to_values: Dict[TensorVariable, TensorVariable], + rvs_to_transforms: Dict[TensorVariable, RVTransform], + **kwargs, +) -> List[TensorVariable]: + """Clone and replace random variables in graphs with their value variables. 
+ + This will *not* recompute test values in the resulting graphs. + + Parameters + ---------- + graphs + The graphs in which to perform the replacements. + rvs_to_values + Mapping between the original graph RVs and their respective value variables + rvs_to_transforms + Mapping between the original graph RVs and their respective value transforms + """ + + # Clone original graphs so that we don't modify variables in place + inputs = [i for i in graph_inputs(graphs) if not isinstance(i, Constant)] + equiv = clone_get_equiv(inputs, graphs, False, False, {}) + graphs = [equiv[n] for n in graphs] + + # Get needed mappings for equivalent cloned variables + equiv_rvs_to_values = {} + equiv_rvs_to_transforms = {} + for rv, value in rvs_to_values.items(): + equiv_rv = equiv.get(rv, rv) + equiv_rvs_to_values[equiv_rv] = equiv.get(value, value) + equiv_rvs_to_transforms[equiv_rv] = rvs_to_transforms[rv] + + def populate_replacements(rv, replacements): + # Populate replacements dict with {rv: value} pairs indicating which graph + # RVs should be replaced by what value variables. + + # No value variable to replace RV with + value = equiv_rvs_to_values.get(rv, None) + if value is None: + return [] + + transform = equiv_rvs_to_transforms.get(rv, None) + if transform is not None: + # We want to replace uses of the RV by the back-transformation of its value + value = transform.backward(value, *rv.owner.inputs) + value.name = rv.name + + replacements[rv] = value + # Also walk the graph of the value variable to make any additional + # replacements if that is not a simple input variable + return [value] + + graphs, _ = _replace_rvs_in_graphs( + graphs, + replacement_fn=populate_replacements, **kwargs, ) diff --git a/pymc/backends/arviz.py b/pymc/backends/arviz.py index 1ff2c66e24..19864a5281 100644 --- a/pymc/backends/arviz.py +++ b/pymc/backends/arviz.py @@ -47,7 +47,7 @@ def find_observations(model: "Model") -> Dict[str, Var]: """If there are observations available, return them as a dictionary.""" observations = {} for obs in model.observed_RVs: - aux_obs = getattr(obs.tag, "observations", None) + aux_obs = model.rvs_to_values.get(obs, None) if aux_obs is not None: try: obs_data = extract_obs_data(aux_obs) @@ -261,7 +261,7 @@ def log_likelihood_vals_point(self, point, var, log_like_fun): if isinstance(var.owner.op, (AdvancedIncSubtensor, AdvancedIncSubtensor1)): try: - obs_data = extract_obs_data(var.tag.observations) + obs_data = extract_obs_data(self.model.rvs_to_values[var]) except TypeError: warnings.warn(f"Could not extract data from symbolic observation {var}") diff --git a/pymc/distributions/__init__.py b/pymc/distributions/__init__.py index 5ada4d67ce..0f362b3e67 100644 --- a/pymc/distributions/__init__.py +++ b/pymc/distributions/__init__.py @@ -16,7 +16,6 @@ logcdf, logp, joint_logp, - joint_logpt, ) from pymc.distributions.bound import Bound @@ -199,7 +198,6 @@ "Censored", "CAR", "PolyaGamma", - "joint_logpt", "joint_logp", "logp", "logcdf", diff --git a/pymc/distributions/distribution.py b/pymc/distributions/distribution.py index b75bcaaa74..13cad0cf5d 100644 --- a/pymc/distributions/distribution.py +++ b/pymc/distributions/distribution.py @@ -49,7 +49,7 @@ shape_from_dims, ) from pymc.printing import str_for_dist -from pymc.util import UNSET +from pymc.util import UNSET, _add_future_warning_tag from pymc.vartypes import string_types __all__ = [ @@ -371,6 +371,7 @@ def dist( rv_out.logp = _make_nice_attr_error("rv.logp(x)", "pm.logp(rv, x)") rv_out.logcdf = _make_nice_attr_error("rv.logcdf(x)", "pm.logcdf(rv, 
x)") rv_out.random = _make_nice_attr_error("rv.random()", "pm.draw(rv)") + _add_future_warning_tag(rv_out) return rv_out diff --git a/pymc/distributions/logprob.py b/pymc/distributions/logprob.py index 95c31a7c93..fa2ba09ace 100644 --- a/pymc/distributions/logprob.py +++ b/pymc/distributions/logprob.py @@ -25,26 +25,18 @@ from aeppl.logprob import logcdf as logcdf_aeppl from aeppl.logprob import logprob as logp_aeppl from aeppl.tensor import MeasurableJoin -from aeppl.transforms import TransformValuesRewrite +from aeppl.transforms import RVTransform, TransformValuesRewrite from aesara import tensor as at from aesara.graph.basic import graph_inputs, io_toposort from aesara.tensor.random.op import RandomVariable -from aesara.tensor.subtensor import ( - AdvancedIncSubtensor, - AdvancedIncSubtensor1, - AdvancedSubtensor, - AdvancedSubtensor1, - IncSubtensor, - Subtensor, -) from aesara.tensor.var import TensorVariable from pymc.aesaraf import constant_fold, floatX +TOTAL_SIZE = Union[int, Sequence[int], None] -def _get_scaling( - total_size: Optional[Union[int, Sequence[int]]], shape, ndim: int -) -> TensorVariable: + +def _get_scaling(total_size: TOTAL_SIZE, shape, ndim: int) -> TensorVariable: """ Gets scaling constant for logp. @@ -112,22 +104,26 @@ def _get_scaling( return at.as_tensor(coef, dtype=aesara.config.floatX) -subtensor_types = ( - AdvancedIncSubtensor, - AdvancedIncSubtensor1, - AdvancedSubtensor, - AdvancedSubtensor1, - IncSubtensor, - Subtensor, -) - +def _check_no_rvs(logp_terms: Sequence[TensorVariable]): + # Raise if there are unexpected RandomVariables in the logp graph + # Only SimulatorRVs are allowed + from pymc.distributions.simulator import SimulatorRV -def joint_logpt(*args, **kwargs): - warnings.warn( - "joint_logpt has been deprecated. Use joint_logp instead.", - FutureWarning, - ) - return joint_logp(*args, **kwargs) + unexpected_rv_nodes = [ + node + for node in aesara.graph.ancestors(logp_terms) + if ( + node.owner + and isinstance(node.owner.op, RandomVariable) + and not isinstance(node.owner.op, SimulatorRV) + ) + ] + if unexpected_rv_nodes: + raise ValueError( + f"Random variables detected in the logp graph: {unexpected_rv_nodes}.\n" + "This can happen when DensityDist logp or Interval transform functions " + "reference nonlocal variables." + ) def joint_logp( @@ -169,6 +165,10 @@ def joint_logp( Sum the log-likelihood or return each term as a separate list item. """ + warnings.warn( + "joint_logp has been deprecated, use model.logp instead", + FutureWarning, + ) # TODO: In future when we drop support for tag.value_var most of the following # logic can be removed and logp can just be a wrapper function that calls aeppl's # joint_logprob directly. @@ -241,26 +241,6 @@ def joint_logp( **kwargs, ) - # Raise if there are unexpected RandomVariables in the logp graph - # Only SimulatorRVs are allowed - from pymc.distributions.simulator import SimulatorRV - - unexpected_rv_nodes = [ - node - for node in aesara.graph.ancestors(list(temp_logp_var_dict.values())) - if ( - node.owner - and isinstance(node.owner.op, RandomVariable) - and not isinstance(node.owner.op, SimulatorRV) - ) - ] - if unexpected_rv_nodes: - raise ValueError( - f"Random variables detected in the logp graph: {unexpected_rv_nodes}.\n" - "This can happen when DensityDist logp or Interval transform functions " - "reference nonlocal variables." - ) - # aeppl returns the logp for every single value term we provided to it. 
This includes # the extra values we plugged in above, so we filter those we actually wanted in the # same order they were given in. @@ -268,6 +248,8 @@ def joint_logp( for value_var in rv_values.values(): logp_var_dict[value_var] = temp_logp_var_dict[value_var] + _check_no_rvs(list(logp_var_dict.values())) + if scaling: for value_var in logp_var_dict.keys(): if value_var in rv_scalings: @@ -281,6 +263,52 @@ def joint_logp( return logp_var +def _joint_logp( + rvs: Sequence[TensorVariable], + *, + rvs_to_values: Dict[TensorVariable, TensorVariable], + rvs_to_transforms: Dict[TensorVariable, RVTransform], + jacobian: bool = True, + rvs_to_total_sizes: Dict[TensorVariable, TOTAL_SIZE], + **kwargs, +) -> List[TensorVariable]: + """Thin wrapper around aeppl.factorized_joint_logprob, extended with PyMC specific + concerns such as transforms, jacobian, and scaling""" + + transform_rewrite = None + values_to_transforms = { + rvs_to_values[rv]: transform + for rv, transform in rvs_to_transforms.items() + if transform is not None + } + if values_to_transforms: + # There seems to be an incorrect type hint in TransformValuesRewrite + transform_rewrite = TransformValuesRewrite(values_to_transforms) # type: ignore + + temp_logp_terms = factorized_joint_logprob( + rvs_to_values, + extra_rewrites=transform_rewrite, + use_jacobian=jacobian, + **kwargs, + ) + + # aeppl returns the logp for every single value term we provided to it. This includes + # the extra values we plugged in above, so we filter those we actually wanted in the + # same order they were given in. + logp_terms = {} + for rv in rvs: + value_var = rvs_to_values[rv] + logp_term = temp_logp_terms[value_var] + total_size = rvs_to_total_sizes.get(rv, None) + if total_size is not None: + scaling = _get_scaling(total_size, value_var.shape, value_var.ndim) + logp_term *= scaling + logp_terms[value_var] = logp_term + + _check_no_rvs(list(logp_terms.values())) + return list(logp_terms.values()) + + def logp(rv: TensorVariable, value) -> TensorVariable: """Return the log-probability graph of a Random Variable""" diff --git a/pymc/distributions/shape_utils.py b/pymc/distributions/shape_utils.py index c9cf4bc6f8..666cde5c1d 100644 --- a/pymc/distributions/shape_utils.py +++ b/pymc/distributions/shape_utils.py @@ -50,6 +50,7 @@ from pymc.aesaraf import PotentialShapeType from pymc.exceptions import ShapeError +from pymc.util import _add_future_warning_tag def to_tuple(shape): @@ -600,6 +601,7 @@ def change_dist_size( new_size = tuple(new_size) # type: ignore new_dist = _change_dist_size(dist.owner.op, dist, new_size=new_size, expand=expand) + _add_future_warning_tag(new_dist) new_dist.name = dist.name for k, v in dist.tag.__dict__.items(): diff --git a/pymc/initial_point.py b/pymc/initial_point.py index 7b09f856b1..2b5d28dd51 100644 --- a/pymc/initial_point.py +++ b/pymc/initial_point.py @@ -20,6 +20,7 @@ import aesara.tensor as at import numpy as np +from aeppl.transforms import RVTransform from aesara.graph.basic import Variable from aesara.graph.fg import FunctionGraph from aesara.tensor.var import TensorVariable @@ -43,9 +44,7 @@ def convert_str_to_rv_dict( if isinstance(key, str): if is_transformed_name(key): rv = model[get_untransformed_name(key)] - initvals[rv] = model.rvs_to_values[rv].tag.transform.backward( - initval, *rv.owner.inputs - ) + initvals[rv] = model.rvs_to_transforms[rv].backward(initval, *rv.owner.inputs) else: initvals[model[key]] = initval else: @@ -158,7 +157,7 @@ def make_initial_point_fn( initial_values = 
make_initial_point_expression( free_rvs=model.free_RVs, - rvs_to_values=model.rvs_to_values, + rvs_to_transforms=model.rvs_to_transforms, initval_strategies=initval_strats, jitter_rvs=jitter_rvs, default_strategy=default_strategy, @@ -172,7 +171,7 @@ def make_initial_point_fn( varnames = [] for var in model.free_RVs: - transform = getattr(model.rvs_to_values[var].tag, "transform", None) + transform = model.rvs_to_transforms[var] if transform is not None and return_transformed: name = get_transformed_name(var.name, transform) else: @@ -197,7 +196,7 @@ def inner(seed, *args, **kwargs): def make_initial_point_expression( *, free_rvs: Sequence[TensorVariable], - rvs_to_values: Dict[TensorVariable, TensorVariable], + rvs_to_transforms: Dict[TensorVariable, RVTransform], initval_strategies: Dict[TensorVariable, Optional[Union[np.ndarray, Variable, str]]], jitter_rvs: Set[TensorVariable] = None, default_strategy: str = "moment", @@ -265,7 +264,7 @@ def make_initial_point_expression( else: value = at.as_tensor(strategy, dtype=variable.dtype).astype(variable.dtype) - transform = getattr(rvs_to_values[variable].tag, "transform", None) + transform = rvs_to_transforms.get(variable, None) if transform is not None: value = transform.forward(value, *variable.owner.inputs) diff --git a/pymc/model.py b/pymc/model.py index 248c7be462..c74375c00f 100644 --- a/pymc/model.py +++ b/pymc/model.py @@ -56,18 +56,18 @@ gradient, hessian, inputvars, - rvs_to_value_vars, + replace_rvs_by_values, ) from pymc.blocking import DictToArrayBijection, RaveledVars from pymc.data import GenTensorVariable, Minibatch -from pymc.distributions import joint_logp -from pymc.distributions.logprob import _get_scaling +from pymc.distributions.logprob import _joint_logp from pymc.distributions.transforms import _default_transform from pymc.exceptions import ImputationWarning, SamplingError, ShapeError, ShapeWarning from pymc.initial_point import make_initial_point_fn from pymc.util import ( UNSET, WithMemoization, + _add_future_warning_tag, get_transformed_name, get_value_vars_from_user_vars, get_var_name, @@ -555,6 +555,8 @@ def __init__( self.named_vars = treedict(parent=self.parent.named_vars) self.values_to_rvs = treedict(parent=self.parent.values_to_rvs) self.rvs_to_values = treedict(parent=self.parent.rvs_to_values) + self.rvs_to_transforms = treedict(parent=self.parent.rvs_to_transforms) + self.rvs_to_total_sizes = treedict(parent=self.parent.rvs_to_total_sizes) self.free_RVs = treelist(parent=self.parent.free_RVs) self.observed_RVs = treelist(parent=self.parent.observed_RVs) self.auto_deterministics = treelist(parent=self.parent.auto_deterministics) @@ -567,6 +569,8 @@ def __init__( self.named_vars = treedict() self.values_to_rvs = treedict() self.rvs_to_values = treedict() + self.rvs_to_transforms = treedict() + self.rvs_to_total_sizes = treedict() self.free_RVs = treelist() self.observed_RVs = treelist() self.auto_deterministics = treelist() @@ -725,13 +729,13 @@ def logp( # We need to separate random variables from potential terms, and remember their # original order so that we can merge them together in the same order at the end - rv_values = {} + rvs = [] potentials = [] rv_order, potential_order = [], [] for i, var in enumerate(varlist): - value_var = self.rvs_to_values.get(var) - if value_var is not None: - rv_values[var] = value_var + rv = self.values_to_rvs.get(var, var) + if rv in self.basic_RVs: + rvs.append(rv) rv_order.append(i) else: if var in self.potentials: @@ -743,14 +747,20 @@ def logp( ) rv_logps: 
List[TensorVariable] = [] - if rv_values: - rv_logps = joint_logp(list(rv_values.keys()), rv_values, sum=False, jacobian=jacobian) + if rvs: + rv_logps = _joint_logp( + rvs=rvs, + rvs_to_values=self.rvs_to_values, + rvs_to_transforms=self.rvs_to_transforms, + rvs_to_total_sizes=self.rvs_to_total_sizes, + jacobian=jacobian, + ) assert isinstance(rv_logps, list) # Replace random variables by their value variables in potential terms potential_logps = [] if potentials: - potential_logps = rvs_to_value_vars(potentials) + potential_logps = self.replace_rvs_by_values(potentials) logp_factors = [None] * len(varlist) for logp_order, logp in zip((rv_order + potential_order), (rv_logps + potential_logps)): @@ -870,7 +880,7 @@ def potentiallogp(self) -> Variable: """Aesara scalar of log-probability of the Potential terms""" # Convert random variables in Potential expression into their log-likelihood # inputs and apply their transforms, if any - potentials = rvs_to_value_vars(self.potentials) + potentials = self.replace_rvs_by_values(self.potentials) if potentials: return at.sum([at.sum(factor) for factor in potentials]) else: @@ -890,23 +900,19 @@ def unobserved_value_vars(self): log-likelihood graph """ vars = [] - untransformed_vars = [] + transformed_rvs = [] for rv in self.free_RVs: value_var = self.rvs_to_values[rv] - transform = getattr(value_var.tag, "transform", None) + transform = self.rvs_to_transforms[rv] if transform is not None: - # We need to create and add an un-transformed version of - # each transformed variable - untrans_value_var = transform.backward(value_var, *rv.owner.inputs) - untrans_value_var.name = rv.name - untransformed_vars.append(untrans_value_var) + transformed_rvs.append(rv) vars.append(value_var) # Remove rvs from untransformed values graph - untransformed_vars = rvs_to_value_vars(untransformed_vars) + untransformed_vars = self.replace_rvs_by_values(transformed_rvs) # Remove rvs from deterministics graph - deterministics = rvs_to_value_vars(self.deterministics) + deterministics = self.replace_rvs_by_values(self.deterministics) return vars + untransformed_vars + deterministics @@ -944,7 +950,7 @@ def basic_RVs(self): These are the actual random variable terms that make up the "sample-space" graph (i.e. you can sample these graphs by compiling them with `aesara.function`). If you want the corresponding log-likelihood terms, - use `var.tag.value_var`. + use `model.value_vars` instead. """ return self.free_RVs + self.observed_RVs @@ -955,7 +961,7 @@ def unobserved_RVs(self): These are the actual random variable terms that make up the "sample-space" graph (i.e. you can sample these graphs by compiling them with `aesara.function`). If you want the corresponding log-likelihood terms, - use `var.tag.value_var`. + use `var.unobserved_value_vars` instead. """ return self.free_RVs + self.deterministics @@ -980,17 +986,6 @@ def dim_lengths(self) -> Dict[str, Variable]: """ return self._dim_lengths - @property - def unobserved_RVs(self): - """List of all random variables, including deterministic ones. - - These are the actual random variable terms that make up the - "sample-space" graph (i.e. you can sample these graphs by compiling them - with `aesara.function`). If you want the corresponding log-likelihood terms, - use `var.tag.value_var`. 
- """ - return self.free_RVs + self.deterministics - @property def test_point(self) -> Dict[str, np.ndarray]: """Deprecated alias for `Model.initial_point(seed=None)`.""" @@ -1320,8 +1315,9 @@ def register_rv( """ name = self.name_for(name) rv_var.name = name + _add_future_warning_tag(rv_var) rv_var.tag.total_size = total_size - rv_var.tag.scaling = _get_scaling(total_size, shape=rv_var.shape, ndim=rv_var.ndim) + self.rvs_to_total_sizes[rv_var] = total_size # Associate previously unknown dimension names with # the length of the corresponding RV dimension. @@ -1389,7 +1385,7 @@ def make_obs_var( if test_value is not None: # We try to reuse the old test value - rv_var.tag.test_value = np.broadcast_to(test_value, rv_var.tag.test_value.shape) + rv_var.tag.test_value = np.broadcast_to(test_value, rv_var.shape) else: rv_var.tag.test_value = data @@ -1501,6 +1497,7 @@ def create_value_var( if aesara.config.compute_test_value != "off": value_var.tag.test_value = rv_var.tag.test_value + _add_future_warning_tag(value_var) rv_var.tag.value_var = value_var # Make the value variable a transformed value variable, @@ -1516,7 +1513,7 @@ def create_value_var( value_var, *rv_var.owner.inputs ).tag.test_value self.named_vars[value_var.name] = value_var - + self.rvs_to_transforms[rv_var] = transform self.rvs_to_values[rv_var] = value_var self.values_to_rvs[value_var] = rv_var @@ -1582,9 +1579,29 @@ def __getitem__(self, key): def __contains__(self, key): return key in self.named_vars or self.name_for(key) in self.named_vars + def replace_rvs_by_values( + self, + graphs: Sequence[TensorVariable], + **kwargs, + ) -> List[TensorVariable]: + """Clone and replace random variables in graphs with their value variables. + + This will *not* recompute test values in the resulting graphs. + + Parameters + ---------- + graphs + The graphs in which to perform the replacements. + """ + return replace_rvs_by_values( + graphs, + rvs_to_values=self.rvs_to_values, + rvs_to_transforms=self.rvs_to_transforms, + ) + def compile_fn( self, - outs: Sequence[Variable], + outs: Union[Variable, Sequence[Variable]], *, inputs: Optional[Sequence[Variable]] = None, mode=None, @@ -1679,8 +1696,7 @@ def eval_rv_shapes(self) -> Dict[str, Tuple[int, ...]]: names = [] outputs = [] for rv in self.free_RVs: - rv_var = self.rvs_to_values[rv] - transform = getattr(rv_var.tag, "transform", None) + transform = self.rvs_to_transforms[rv] if transform is not None: names.append(get_transformed_name(rv.name, transform)) outputs.append(transform.forward(rv, *rv.owner.inputs).shape) @@ -1689,7 +1705,7 @@ def eval_rv_shapes(self) -> Dict[str, Tuple[int, ...]]: f = aesara.function( inputs=[], outputs=outputs, - givens=[(obs, obs.tag.observations) for obs in self.observed_RVs], + givens=[(obs, self.rvs_to_values[obs]) for obs in self.observed_RVs], mode=aesara.compile.mode.FAST_COMPILE, on_unused_input="ignore", ) @@ -1851,15 +1867,22 @@ def set_data(new_data, model=None, *, coords=None): def compile_fn( - outs, mode=None, point_fn: bool = True, model: Optional[Model] = None, **kwargs + outs: Union[Variable, Sequence[Variable]], + *, + inputs: Optional[Sequence[Variable]] = None, + mode=None, + point_fn: bool = True, + model: Optional[Model] = None, + **kwargs, ) -> Union[PointFunc, Callable[[Sequence[np.ndarray]], Sequence[np.ndarray]]]: - """Compiles an Aesara function which returns ``outs`` and takes values of model - vars as a dict as an argument. + """Compiles an Aesara function Parameters ---------- outs Aesara variable or iterable of Aesara variables. 
+ inputs + Aesara input variables, defaults to aesaraf.inputvars(outs). mode Aesara compilation mode, default=None. point_fn : bool @@ -1870,10 +1893,17 @@ def compile_fn( Returns ------- - Compiled Aesara function as point function. + Compiled Aesara function """ + model = modelcontext(model) - return model.compile_fn(outs, mode=mode, point_fn=point_fn, **kwargs) + return model.compile_fn( + outs, + inputs=inputs, + mode=mode, + point_fn=point_fn, + **kwargs, + ) def Point(*args, filter_model_vars=False, **kwargs) -> Dict[str, np.ndarray]: @@ -1998,7 +2028,6 @@ def Potential(name, var, model=None): """ model = modelcontext(model) var.name = model.name_for(name) - var.tag.scaling = 1.0 model.potentials.append(var) model.add_random_variable(var) diff --git a/pymc/model_graph.py b/pymc/model_graph.py index f5dd3807e9..93eed6ce01 100644 --- a/pymc/model_graph.py +++ b/pymc/model_graph.py @@ -79,8 +79,8 @@ def vars_to_plot(self, var_names: Optional[Iterable[VarName]] = None) -> List[Va raise ValueError(f"{var_name} is not in this model.") for model_var in self.var_list: - if hasattr(model_var.tag, "observations"): - if model_var.tag.observations == self.model[var_name]: + if model_var in self.model.observed_RVs: + if self.model.rvs_to_values[model_var] == self.model[var_name]: selected_names.add(model_var.name) selected_ancestors = set( @@ -91,8 +91,8 @@ def vars_to_plot(self, var_names: Optional[Iterable[VarName]] = None) -> List[Va ) for var in selected_ancestors.copy(): - if hasattr(var.tag, "observations"): - selected_ancestors.add(var.tag.observations) + if var in self.model.observed_RVs: + selected_ancestors.add(self.model.rvs_to_values[var]) # ordering of self._all_var_names is important return [var.name for var in selected_ancestors] @@ -108,8 +108,8 @@ def make_compute_graph( parent_name = self.get_parent_names(var) input_map[var_name] = input_map[var_name].union(parent_name) - if hasattr(var.tag, "observations"): - obs_node = var.tag.observations + if var in self.model.observed_RVs: + obs_node = self.model.rvs_to_values[var] # loop created so that the elif block can go through this again # and remove any intermediate ops, notably dtype casting, to observations diff --git a/pymc/sampling/forward.py b/pymc/sampling/forward.py index a5450f951a..ed91b77f3f 100644 --- a/pymc/sampling/forward.py +++ b/pymc/sampling/forward.py @@ -388,7 +388,7 @@ def sample_prior_predictive( for name in sorted(missing_names): transformed_value_var = model[name] rv_var = model.values_to_rvs[transformed_value_var] - transform = transformed_value_var.tag.transform + transform = model.rvs_to_transforms[rv_var] transformed_rv_var = transform.forward(rv_var, *rv_var.owner.inputs) names.append(name) diff --git a/pymc/sampling/mcmc.py b/pymc/sampling/mcmc.py index ac04f0f465..f111eb0d87 100644 --- a/pymc/sampling/mcmc.py +++ b/pymc/sampling/mcmc.py @@ -210,11 +210,8 @@ def _print_step_hierarchy(s: Step, level: int = 0) -> None: def all_continuous(vars): - """Check that vars not include discrete variables, excepting observed RVs.""" - - vars_ = [var for var in vars if not hasattr(var.tag, "observations")] - - if any([(var.dtype in discrete_types) for var in vars_]): + """Check that vars not include discrete variables""" + if any([(var.dtype in discrete_types) for var in vars]): return False else: return True diff --git a/pymc/step_methods/metropolis.py b/pymc/step_methods/metropolis.py index 62f556251b..dcd4b3b6d5 100644 --- a/pymc/step_methods/metropolis.py +++ b/pymc/step_methods/metropolis.py @@ -31,7 +31,6 @@ 
floatX, join_nonshared_inputs, replace_rng_nodes, - rvs_to_value_vars, ) from pymc.blocking import DictToArrayBijection, RaveledVars from pymc.step_methods.arraystep import ( @@ -585,7 +584,7 @@ def __init__(self, vars, proposal="uniform", order="random", model=None): if isinstance(distr, CategoricalRV): k_graph = rv_var.owner.inputs[3].shape[-1] - (k_graph,) = rvs_to_value_vars((k_graph,)) + (k_graph,) = model.replace_rvs_by_values((k_graph,)) k = model.compile_fn(k_graph, inputs=model.value_vars, on_unused_input="ignore")( initial_point ) diff --git a/pymc/tests/backends/fixtures.py b/pymc/tests/backends/fixtures.py index 287ac79dae..a1ac8ce89e 100644 --- a/pymc/tests/backends/fixtures.py +++ b/pymc/tests/backends/fixtures.py @@ -143,9 +143,9 @@ def setup_class(cls): cls.test_point, cls.model, _ = models.beta_bernoulli(cls.shape) if hasattr(cls, "write_partial_chain") and cls.write_partial_chain is True: - cls.chain_vars = [v.tag.value_var for v in cls.model.unobserved_RVs[1:]] + cls.chain_vars = [cls.model.rvs_to_values[v] for v in cls.model.unobserved_RVs[1:]] else: - cls.chain_vars = [v.tag.value_var for v in cls.model.unobserved_RVs] + cls.chain_vars = [cls.model.rvs_to_values[v] for v in cls.model.unobserved_RVs] with cls.model: strace0 = cls.backend(cls.name, vars=cls.chain_vars) diff --git a/pymc/tests/distributions/test_bound.py b/pymc/tests/distributions/test_bound.py index 57773b2577..905aaa683d 100644 --- a/pymc/tests/distributions/test_bound.py +++ b/pymc/tests/distributions/test_bound.py @@ -22,8 +22,6 @@ import pymc as pm -from pymc.distributions import joint_logp - class TestBound: """Tests for pm.Bound distribution""" @@ -47,29 +45,38 @@ def test_continuous(self): UpperNormalTransform = pm.Bound("uppertrans", dist, upper=10) BoundedNormalTransform = pm.Bound("boundedtrans", dist, lower=1, upper=10) - assert joint_logp(LowerNormal, -1).eval() == -np.inf - assert joint_logp(UpperNormal, 1).eval() == -np.inf - assert joint_logp(BoundedNormal, 0).eval() == -np.inf - assert joint_logp(BoundedNormal, 11).eval() == -np.inf + assert model.compile_fn(model.logp(LowerNormal), point_fn=False)(-1) == -np.inf + assert model.compile_fn(model.logp(UpperNormal), point_fn=False)(1) == -np.inf + assert model.compile_fn(model.logp(BoundedNormal), point_fn=False)(0) == -np.inf + assert model.compile_fn(model.logp(BoundedNormal), point_fn=False)(11) == -np.inf - assert joint_logp(UnboundedNormal, 0).eval() != -np.inf - assert joint_logp(UnboundedNormal, 11).eval() != -np.inf - assert joint_logp(InfBoundedNormal, 0).eval() != -np.inf - assert joint_logp(InfBoundedNormal, 11).eval() != -np.inf + assert model.compile_fn(model.logp(UnboundedNormal), point_fn=False)(0) != -np.inf + assert model.compile_fn(model.logp(UnboundedNormal), point_fn=False)(11) != -np.inf + assert model.compile_fn(model.logp(InfBoundedNormal), point_fn=False)(0) != -np.inf + assert model.compile_fn(model.logp(InfBoundedNormal), point_fn=False)(11) != -np.inf - value = model.rvs_to_values[LowerNormalTransform] - assert joint_logp(LowerNormalTransform, value).eval({value: -1}) != -np.inf - value = model.rvs_to_values[UpperNormalTransform] - assert joint_logp(UpperNormalTransform, value).eval({value: 1}) != -np.inf - value = model.rvs_to_values[BoundedNormalTransform] - assert joint_logp(BoundedNormalTransform, value).eval({value: 0}) != -np.inf - assert joint_logp(BoundedNormalTransform, value).eval({value: 11}) != -np.inf + assert model.compile_fn(model.logp(LowerNormalTransform), point_fn=False)(-1) != -np.inf + assert 
model.compile_fn(model.logp(UpperNormalTransform), point_fn=False)(1) != -np.inf + assert model.compile_fn(model.logp(BoundedNormalTransform), point_fn=False)(0) != -np.inf + assert model.compile_fn(model.logp(BoundedNormalTransform), point_fn=False)(11) != -np.inf ref_dist = pm.Normal.dist(mu=0, sigma=1) - assert np.allclose(joint_logp(UnboundedNormal, 5).eval(), joint_logp(ref_dist, 5).eval()) - assert np.allclose(joint_logp(LowerNormal, 5).eval(), joint_logp(ref_dist, 5).eval()) - assert np.allclose(joint_logp(UpperNormal, -5).eval(), joint_logp(ref_dist, 5).eval()) - assert np.allclose(joint_logp(BoundedNormal, 5).eval(), joint_logp(ref_dist, 5).eval()) + assert np.allclose( + model.compile_fn(model.logp(UnboundedNormal), point_fn=False)(5), + pm.logp(ref_dist, 5).eval(), + ) + assert np.allclose( + model.compile_fn(model.logp(LowerNormal), point_fn=False)(5), + pm.logp(ref_dist, 5).eval(), + ) + assert np.allclose( + model.compile_fn(model.logp(UpperNormal), point_fn=False)(-5), + pm.logp(ref_dist, 5).eval(), + ) + assert np.allclose( + model.compile_fn(model.logp(BoundedNormal), point_fn=False)(5), + pm.logp(ref_dist, 5).eval(), + ) def test_discrete(self): with pm.Model() as model: @@ -84,19 +91,31 @@ def test_discrete(self): UpperPoisson = pm.Bound("upper", dist, upper=10) BoundedPoisson = pm.Bound("bounded", dist, lower=1, upper=10) - assert joint_logp(LowerPoisson, 0).eval() == -np.inf - assert joint_logp(UpperPoisson, 11).eval() == -np.inf - assert joint_logp(BoundedPoisson, 0).eval() == -np.inf - assert joint_logp(BoundedPoisson, 11).eval() == -np.inf + assert model.compile_fn(model.logp(LowerPoisson), point_fn=False)(0) == -np.inf + assert model.compile_fn(model.logp(UpperPoisson), point_fn=False)(11) == -np.inf + assert model.compile_fn(model.logp(BoundedPoisson), point_fn=False)(0) == -np.inf + assert model.compile_fn(model.logp(BoundedPoisson), point_fn=False)(11) == -np.inf - assert joint_logp(UnboundedPoisson, 0).eval() != -np.inf - assert joint_logp(UnboundedPoisson, 11).eval() != -np.inf + assert model.compile_fn(model.logp(UnboundedPoisson), point_fn=False)(0) != -np.inf + assert model.compile_fn(model.logp(UnboundedPoisson), point_fn=False)(11) != -np.inf ref_dist = pm.Poisson.dist(mu=4) - assert np.allclose(joint_logp(UnboundedPoisson, 5).eval(), joint_logp(ref_dist, 5).eval()) - assert np.allclose(joint_logp(LowerPoisson, 5).eval(), joint_logp(ref_dist, 5).eval()) - assert np.allclose(joint_logp(UpperPoisson, 5).eval(), joint_logp(ref_dist, 5).eval()) - assert np.allclose(joint_logp(BoundedPoisson, 5).eval(), joint_logp(ref_dist, 5).eval()) + assert np.allclose( + model.compile_fn(model.logp(UnboundedPoisson), point_fn=False)(5), + pm.logp(ref_dist, 5).eval(), + ) + assert np.allclose( + model.compile_fn(model.logp(LowerPoisson), point_fn=False)(5), + pm.logp(ref_dist, 5).eval(), + ) + assert np.allclose( + model.compile_fn(model.logp(UpperPoisson), point_fn=False)(5), + pm.logp(ref_dist, 5).eval(), + ) + assert np.allclose( + model.compile_fn(model.logp(BoundedPoisson), point_fn=False)(5), + pm.logp(ref_dist, 5).eval(), + ) def create_invalid_distribution(self): class MyNormal(RandomVariable): @@ -220,18 +239,26 @@ def test_array_bound(self): "bounded", dist, lower=[1, 2], upper=[9, 10], transform=None ) - first, second = joint_logp(LowerPoisson, [0, 0], sum=False)[0].eval() + first, second = model.compile_fn(model.logp(LowerPoisson, sum=False)[0], point_fn=False)( + [0, 0] + ) assert first == -np.inf assert second != -np.inf - first, second = 
joint_logp(UpperPoisson, [11, 11], sum=False)[0].eval() + first, second = model.compile_fn(model.logp(UpperPoisson, sum=False)[0], point_fn=False)( + [11, 11] + ) assert first != -np.inf assert second == -np.inf - first, second = joint_logp(BoundedPoisson, [1, 1], sum=False)[0].eval() + first, second = model.compile_fn(model.logp(BoundedPoisson, sum=False)[0], point_fn=False)( + [1, 1] + ) assert first != -np.inf assert second == -np.inf - first, second = joint_logp(BoundedPoisson, [10, 10], sum=False)[0].eval() + first, second = model.compile_fn(model.logp(BoundedPoisson, sum=False)[0], point_fn=False)( + [10, 10] + ) assert first == -np.inf assert second != -np.inf diff --git a/pymc/tests/distributions/test_continuous.py b/pymc/tests/distributions/test_continuous.py index 8e46228e8c..952b6087dd 100644 --- a/pymc/tests/distributions/test_continuous.py +++ b/pymc/tests/distributions/test_continuous.py @@ -73,7 +73,7 @@ def get_dist_params_and_interval_bounds(self, model, rv_name): interval_rv = model.named_vars[f"{rv_name}_interval__"] rv = model.named_vars[rv_name] dist_params = rv.owner.inputs - lower_interval, upper_interval = interval_rv.tag.transform.args_fn(*rv.owner.inputs) + lower_interval, upper_interval = model.rvs_to_transforms[rv].args_fn(*rv.owner.inputs) return ( dist_params, lower_interval, diff --git a/pymc/tests/distributions/test_discrete.py b/pymc/tests/distributions/test_discrete.py index aedfb2925f..23229bdd74 100644 --- a/pymc/tests/distributions/test_discrete.py +++ b/pymc/tests/distributions/test_discrete.py @@ -30,7 +30,7 @@ import pymc as pm from pymc.aesaraf import floatX -from pymc.distributions import joint_logp, logcdf, logp +from pymc.distributions import logcdf, logp from pymc.distributions.discrete import _OrderedLogistic, _OrderedProbit from pymc.tests.distributions.util import ( BaseTestDistributionRandom, @@ -574,8 +574,8 @@ def test_orderedlogistic_dimensions(shape): p=p, observed=obs, ) - ologp = joint_logp(ol, np.ones_like(obs), sum=True).eval() * loge - clogp = joint_logp(c, np.ones_like(obs), sum=True).eval() * loge + ologp = pm.logp(ol, np.ones_like(obs)).sum().eval() * loge + clogp = pm.logp(c, np.ones_like(obs)).sum().eval() * loge expected = -np.prod((size,) + shape) assert c.owner.inputs[3].ndim == (len(shape) + 1) diff --git a/pymc/tests/distributions/test_distribution.py b/pymc/tests/distributions/test_distribution.py index 27ac0d3d63..73810f3fd9 100644 --- a/pymc/tests/distributions/test_distribution.py +++ b/pymc/tests/distributions/test_distribution.py @@ -26,10 +26,11 @@ import pymc as pm -from pymc.distributions import DiracDelta, Flat, MvNormal, MvStudentT, joint_logp, logp +from pymc.distributions import DiracDelta, Flat, MvNormal, MvStudentT, logp from pymc.distributions.distribution import SymbolicRandomVariable, _moment, moment -from pymc.distributions.shape_utils import to_tuple +from pymc.distributions.shape_utils import change_dist_size, to_tuple from pymc.tests.distributions.util import assert_moment_is_expected +from pymc.util import _FutureWarningValidatingScratchpad class TestBugfixes: @@ -215,14 +216,13 @@ def logp(value, mu): mu = pm.Normal("mu", size=supp_shape) a = pm.DensityDist("a", mu, logp=logp, ndims_params=[1], ndim_supp=1, size=size) - mu_val = npr.normal(loc=0, scale=1, size=supp_shape).astype(aesara.config.floatX) - a_val = npr.normal(loc=mu_val, scale=1, size=to_tuple(size) + (supp_shape,)).astype( - aesara.config.floatX - ) - log_densityt = joint_logp(a, a.tag.value_var, sum=False)[0] - assert 
log_densityt.eval( - {a.tag.value_var: a_val, mu.tag.value_var: mu_val}, - ).shape == to_tuple(size) + + mu_test_value = npr.normal(loc=0, scale=1, size=supp_shape).astype(aesara.config.floatX) + a_test_value = npr.normal( + loc=mu_test_value, scale=1, size=to_tuple(size) + (supp_shape,) + ).astype(aesara.config.floatX) + log_densityf = model.compile_logp(vars=[a], sum=False) + assert log_densityf({"a": a_test_value, "mu": mu_test_value})[0].shape == to_tuple(size) @pytest.mark.parametrize( "moment, size, expected", @@ -359,3 +359,40 @@ class TestSymbolicRV(SymbolicRandomVariable): dirac_delta_2_ = DiracDelta.dist(10) node = TestSymbolicRV([], [dirac_delta_1_, dirac_delta_2_], ndim_supp=0)().owner assert get_measurable_outputs(node.op, node) == [node.outputs[default_output_idx]] + + +def test_tag_future_warning_dist(): + # Test no unexpected warnings + with warnings.catch_warnings(): + warnings.simplefilter("error") + + x = pm.Normal.dist() + assert isinstance(x.tag, _FutureWarningValidatingScratchpad) + + x.tag.banana = "banana" + assert x.tag.banana == "banana" + + # Check we didn't break test_value filtering + x.tag.test_value = np.array(1) + assert x.tag.test_value == 1 + with pytest.raises(TypeError, match="Wrong number of dimensions"): + x.tag.test_value = np.array([1, 1]) + assert x.tag.test_value == 1 + + # No warning if deprecated attribute is not present + with pytest.raises(AttributeError): + x.tag.value_var + + # Warning if present + x.tag.value_var = "1" + with pytest.warns(FutureWarning, match="Use model.rvs_to_values"): + value_var = x.tag.value_var + assert value_var == "1" + + # Check that PyMC method that copies tag contents does not erase special tag + new_x = change_dist_size(x, new_size=5) + assert new_x.tag is not x.tag + assert isinstance(new_x.tag, _FutureWarningValidatingScratchpad) + with pytest.warns(FutureWarning, match="Use model.rvs_to_values"): + value_var = new_x.tag.value_var + assert value_var == "1" diff --git a/pymc/tests/distributions/test_logprob.py b/pymc/tests/distributions/test_logprob.py index 4212b4baa7..8ca6ee97f4 100644 --- a/pymc/tests/distributions/test_logprob.py +++ b/pymc/tests/distributions/test_logprob.py @@ -45,19 +45,13 @@ from pymc.distributions.discrete import Bernoulli from pymc.distributions.logprob import ( _get_scaling, + _joint_logp, ignore_logprob, - joint_logp, - joint_logpt, logcdf, logp, ) from pymc.model import Model, Potential -from pymc.tests.helpers import select_by_precision - - -def assert_no_rvs(var): - assert not any(isinstance(v.owner.op, RandomVariable) for v in ancestors([var]) if v.owner) - return var +from pymc.tests.helpers import assert_no_rvs, select_by_precision def test_get_scaling(): @@ -117,25 +111,24 @@ def test_joint_logp_basic(): b = Uniform("b", b_l, b_l + 1.0) a_value_var = m.rvs_to_values[a] - assert a_value_var.tag.transform + assert m.rvs_to_transforms[a] b_value_var = m.rvs_to_values[b] - assert b_value_var.tag.transform + assert m.rvs_to_transforms[b] c_value_var = m.rvs_to_values[c] - b_logp = joint_logp(b, b_value_var, sum=False) - - with pytest.warns(FutureWarning): - b_logpt = joint_logpt(b, b_value_var, sum=False) - - res_ancestors = list(walk_model(b_logp)) - res_rv_ancestors = [ - v for v in res_ancestors if v.owner and isinstance(v.owner.op, RandomVariable) - ] + (b_logp,) = _joint_logp( + (b,), + rvs_to_values=m.rvs_to_values, + rvs_to_transforms=m.rvs_to_transforms, + rvs_to_total_sizes={}, + ) # There shouldn't be any `RandomVariable`s in the resulting graph - assert 
len(res_rv_ancestors) == 0 + assert_no_rvs(b_logp) + + res_ancestors = list(walk_model((b_logp,))) assert b_value_var in res_ancestors assert c_value_var in res_ancestors assert a_value_var in res_ancestors @@ -171,7 +164,12 @@ def test_joint_logp_incsubtensor(indices, size): a_idx_value_var = a_idx.type() a_idx_value_var.name = "a_idx_value" - a_idx_logp = joint_logp(a_idx, {a_idx: a_value_var}, sum=False) + a_idx_logp = _joint_logp( + (a_idx,), + rvs_to_values={a_idx: a_value_var}, + rvs_to_transforms={}, + rvs_to_total_sizes={}, + ) logp_vals = a_idx_logp[0].eval({a_value_var: a_val}) @@ -213,7 +211,12 @@ def test_joint_logp_subtensor(): I_value_var = I_rv.type() I_value_var.name = "I_value" - A_idx_logps = joint_logp(A_idx, {A_idx: A_idx_value_var, I_rv: I_value_var}, sum=False) + A_idx_logps = _joint_logp( + (A_idx, I_rv), + rvs_to_values={A_idx: A_idx_value_var, I_rv: I_value_var}, + rvs_to_transforms={}, + rvs_to_total_sizes={}, + ) A_idx_logp = at.add(*A_idx_logps) logp_vals_fn = aesara.function([A_idx_value_var, I_value_var], A_idx_logp) @@ -342,7 +345,12 @@ def logp(value, x): match="Found a random variable that was neither among the observations " "nor the conditioned variables", ): - assert joint_logp([y], {y: y.type()}) + _joint_logp( + [y], + rvs_to_values={y: y.type()}, + rvs_to_transforms={}, + rvs_to_total_sizes={}, + ) # The above warning should go away with ignore_logprob. with Model() as m: @@ -350,7 +358,12 @@ def logp(value, x): y = DensityDist("y", x, logp=logp) with warnings.catch_warnings(): warnings.simplefilter("error") - assert joint_logp([y], {y: y.type()}) + assert _joint_logp( + [y], + rvs_to_values={y: y.type()}, + rvs_to_transforms={}, + rvs_to_total_sizes={}, + ) def test_hierarchical_logp(): @@ -363,8 +376,8 @@ def test_hierarchical_logp(): ops = {a.owner.op for a in logp_ancestors if a.owner} assert len(ops) > 0 assert not any(isinstance(o, RandomVariable) for o in ops) - assert x.tag.value_var in logp_ancestors - assert y.tag.value_var in logp_ancestors + assert m.rvs_to_values[x] in logp_ancestors + assert m.rvs_to_values[y] in logp_ancestors def test_hierarchical_obs_logp(): diff --git a/pymc/tests/distributions/test_simulator.py b/pymc/tests/distributions/test_simulator.py index 5bc5b498ed..9576f9afc7 100644 --- a/pymc/tests/distributions/test_simulator.py +++ b/pymc/tests/distributions/test_simulator.py @@ -196,15 +196,8 @@ def test_multiple_simulators(self): assert self.count_rvs(m.logp()) == 2 # Check that the logps use the correct methods - a_val = m.rvs_to_values[a] - sim1_val = m.rvs_to_values[sim1] - logp_sim1 = pm.joint_logp(sim1, sim1_val) - logp_sim1_fn = aesara.function([a_val], logp_sim1) - - b_val = m.rvs_to_values[b] - sim2_val = m.rvs_to_values[sim2] - logp_sim2 = pm.joint_logp(sim2, sim2_val) - logp_sim2_fn = aesara.function([b_val], logp_sim2) + logp_sim1_fn = m.compile_fn(m.logp(sim1), point_fn=False) + logp_sim2_fn = m.compile_fn(m.logp(sim2), point_fn=False) assert any( node for node in logp_sim1_fn.maker.fgraph.toposort() if isinstance(node.op, SortOp) diff --git a/pymc/tests/distributions/test_transform.py b/pymc/tests/distributions/test_transform.py index 71f4868d7a..850b285f4d 100644 --- a/pymc/tests/distributions/test_transform.py +++ b/pymc/tests/distributions/test_transform.py @@ -24,7 +24,7 @@ import pymc.distributions.transforms as tr from pymc.aesaraf import floatX, jacobian -from pymc.distributions import joint_logp +from pymc.distributions.logprob import _joint_logp from pymc.tests.checks import close_to, 
close_to_logical from pymc.tests.distributions.util import ( Circ, @@ -276,32 +276,49 @@ def build_model(self, distfam, params, size, transform, initval=None): def check_transform_elementwise_logp(self, model): x = model.free_RVs[0] - x_val_transf = x.tag.value_var + x_val_transf = model.rvs_to_values[x] pt = model.initial_point(0) test_array_transf = floatX(np.random.randn(*pt[x_val_transf.name].shape)) - transform = x_val_transf.tag.transform + transform = model.rvs_to_transforms[x] test_array_untransf = transform.backward(test_array_transf, *x.owner.inputs).eval() # Create input variable with same dimensionality as untransformed test_array x_val_untransf = at.constant(test_array_untransf).type() jacob_det = transform.log_jac_det(test_array_transf, *x.owner.inputs) - assert joint_logp(x, sum=False)[0].ndim == x.ndim == jacob_det.ndim - - v1 = joint_logp(x, x_val_transf, jacobian=False).eval({x_val_transf: test_array_transf}) - v2 = joint_logp(x, x_val_untransf, transformed=False).eval( - {x_val_untransf: test_array_untransf} + assert model.logp(x, sum=False)[0].ndim == x.ndim == jacob_det.ndim + + v1 = ( + _joint_logp( + (x,), + rvs_to_values={x: x_val_transf}, + rvs_to_transforms={x: transform}, + rvs_to_total_sizes={}, + jacobian=False, + )[0] + .sum() + .eval({x_val_transf: test_array_transf}) + ) + v2 = ( + _joint_logp( + (x,), + rvs_to_values={x: x_val_untransf}, + rvs_to_transforms={}, + rvs_to_total_sizes={}, + )[0] + .sum() + .eval({x_val_untransf: test_array_untransf}) ) close_to(v1, v2, tol) def check_vectortransform_elementwise_logp(self, model): x = model.free_RVs[0] - x_val_transf = x.tag.value_var + x_val_transf = model.rvs_to_values[x] pt = model.initial_point(0) test_array_transf = floatX(np.random.randn(*pt[x_val_transf.name].shape)) - transform = x_val_transf.tag.transform + transform = model.rvs_to_transforms[x] test_array_untransf = transform.backward(test_array_transf, *x.owner.inputs).eval() # Create input variable with same dimensionality as untransformed test_array @@ -310,14 +327,31 @@ def check_vectortransform_elementwise_logp(self, model): jacob_det = transform.log_jac_det(test_array_transf, *x.owner.inputs) # Original distribution is univariate if x.owner.op.ndim_supp == 0: - assert joint_logp(x, sum=False)[0].ndim == x.ndim == (jacob_det.ndim + 1) + assert model.logp(x, sum=False)[0].ndim == x.ndim == (jacob_det.ndim + 1) # Original distribution is multivariate else: - assert joint_logp(x, sum=False)[0].ndim == (x.ndim - 1) == jacob_det.ndim - - a = joint_logp(x, x_val_transf, jacobian=False).eval({x_val_transf: test_array_transf}) - b = joint_logp(x, x_val_untransf, transformed=False).eval( - {x_val_untransf: test_array_untransf} + assert model.logp(x, sum=False)[0].ndim == (x.ndim - 1) == jacob_det.ndim + + a = ( + _joint_logp( + (x,), + rvs_to_values={x: x_val_transf}, + rvs_to_transforms={x: transform}, + rvs_to_total_sizes={}, + jacobian=False, + )[0] + .sum() + .eval({x_val_transf: test_array_transf}) + ) + b = ( + _joint_logp( + (x,), + rvs_to_values={x: x_val_untransf}, + rvs_to_transforms={}, + rvs_to_total_sizes={}, + )[0] + .sum() + .eval({x_val_untransf: test_array_untransf}) ) # Hack to get relative tolerance close_to(a, b, np.abs(0.5 * (a + b) * tol)) @@ -544,7 +578,7 @@ def test_triangular_transform(): with pm.Model() as m: x = pm.Triangular("x", lower=0, c=1, upper=2) - transform = x.tag.value_var.tag.transform + transform = m.rvs_to_transforms[x] assert np.isclose(transform.backward(-np.inf, *x.owner.inputs).eval(), 0) assert 
np.isclose(transform.backward(np.inf, *x.owner.inputs).eval(), 2) diff --git a/pymc/tests/distributions/util.py b/pymc/tests/distributions/util.py index 3f65350694..0a501da4e8 100644 --- a/pymc/tests/distributions/util.py +++ b/pymc/tests/distributions/util.py @@ -20,7 +20,7 @@ from pymc.aesaraf import compile_pymc, floatX, intX from pymc.distributions import logcdf, logp -from pymc.distributions.logprob import joint_logp +from pymc.distributions.logprob import _joint_logp from pymc.distributions.shape_utils import change_dist_size from pymc.initial_point import make_initial_point_fn from pymc.tests.helpers import SeededTest, select_by_precision @@ -288,9 +288,9 @@ def _model_input_dict(model, param_vars, pt): for k, v in pt.items(): rv_var = model.named_vars.get(k) nv = param_vars.get(k, rv_var) - nv = getattr(nv.tag, "value_var", nv) + nv = model.rvs_to_values.get(nv, nv) - transform = getattr(nv.tag, "transform", None) + transform = model.rvs_to_transforms.get(rv_var, None) if transform: # todo: the compiled graph behind this should be cached and # reused (if it isn't already). @@ -582,7 +582,16 @@ def assert_moment_is_expected(model, expected, check_finite_logp=True): assert np.allclose(moment, expected) if check_finite_logp: - logp_moment = joint_logp(model["x"], at.constant(moment), transformed=False).eval() + logp_moment = ( + _joint_logp( + (model["x"],), + rvs_to_values={model["x"]: at.constant(moment)}, + rvs_to_transforms={}, + rvs_to_total_sizes={}, + )[0] + .sum() + .eval() + ) assert np.isfinite(logp_moment) diff --git a/pymc/tests/helpers.py b/pymc/tests/helpers.py index e291ad2faa..85ab0f68cb 100644 --- a/pymc/tests/helpers.py +++ b/pymc/tests/helpers.py @@ -24,8 +24,10 @@ import numpy.random as nr from aesara.gradient import verify_grad as at_verify_grad +from aesara.graph import ancestors from aesara.graph.rewriting.basic import in2out from aesara.sandbox.rng_mrg import MRG_RandomStream as RandomStream +from aesara.tensor.random.op import RandomVariable import pymc as pm @@ -218,3 +220,8 @@ def continuous_steps(self, step, step_kwargs): assert {m.rvs_to_values[c1], m.rvs_to_values[c2]} == set( step([c1, c2], **step_kwargs).vars ) + + +def assert_no_rvs(var): + assert not any(isinstance(v.owner.op, RandomVariable) for v in ancestors([var]) if v.owner) + return var diff --git a/pymc/tests/test_aesaraf.py b/pymc/tests/test_aesaraf.py index f579df7c69..f627d932fa 100644 --- a/pymc/tests/test_aesaraf.py +++ b/pymc/tests/test_aesaraf.py @@ -39,6 +39,7 @@ convert_observed_data, extract_obs_data, replace_rng_nodes, + replace_rvs_by_values, reseed_rngs, rvs_to_value_vars, walk_model, @@ -47,6 +48,7 @@ from pymc.distributions.distribution import SymbolicRandomVariable from pymc.distributions.transforms import Interval from pymc.exceptions import NotConstantValueError +from pymc.tests.helpers import assert_no_rvs from pymc.vartypes import int_types @@ -98,7 +100,7 @@ def test_make_shared_replacements(self): ) assert ( test_model.test1.broadcastable - == replacement[test_model.test1.tag.value_var].broadcastable + == replacement[test_model.rvs_to_values[test_model.test1]].broadcastable ) def test_metropolis_sampling(self): @@ -305,110 +307,6 @@ def test_walk_model(): assert e in res -@pytest.mark.parametrize("symbolic_rv", (False, True)) -@pytest.mark.parametrize("apply_transforms", (True, False)) -def test_rvs_to_value_vars(symbolic_rv, apply_transforms): - - # Interval transform between last two arguments - interval = Interval(bounds_fn=lambda *args: (args[-2], args[-1])) - - with 
pm.Model() as m: - a = pm.Uniform("a", 0.0, 1.0) - if symbolic_rv: - raw_b = pm.Uniform.dist(0, a + 1.0) - b = pm.Censored("b", raw_b, lower=0, upper=a + 1.0, transform=interval) - # If not True, another distribution has to be used - assert isinstance(b.owner.op, SymbolicRandomVariable) - else: - b = pm.Uniform("b", 0, a + 1.0, transform=interval) - c = pm.Normal("c") - d = at.log(c + b) + 2.0 - - a_value_var = m.rvs_to_values[a] - assert a_value_var.tag.transform - - b_value_var = m.rvs_to_values[b] - c_value_var = m.rvs_to_values[c] - - (res,) = rvs_to_value_vars((d,), apply_transforms=apply_transforms) - - assert res.owner.op == at.add - log_output = res.owner.inputs[0] - assert log_output.owner.op == at.log - log_add_output = res.owner.inputs[0].owner.inputs[0] - assert log_add_output.owner.op == at.add - c_output = log_add_output.owner.inputs[0] - - # We make sure that the random variables were replaced - # with their value variables - assert c_output == c_value_var - b_output = log_add_output.owner.inputs[1] - # When transforms are applied, the input is the back-transformation of the value_var, - # otherwise it is the value_var itself - if apply_transforms: - assert b_output != b_value_var - else: - assert b_output == b_value_var - - res_ancestors = list(walk_model((res,))) - res_rv_ancestors = [ - v for v in res_ancestors if v.owner and isinstance(v.owner.op, RandomVariable) - ] - - # There shouldn't be any `RandomVariable`s in the resulting graph - assert len(res_rv_ancestors) == 0 - assert b_value_var in res_ancestors - assert c_value_var in res_ancestors - # When transforms are used, `d` depends on `a` through the back-transformation of - # `b`, otherwise there is no direct connection between `d` and `a` - if apply_transforms: - assert a_value_var in res_ancestors - else: - assert a_value_var not in res_ancestors - - -def test_rvs_to_value_vars_nested(): - # Test that calling rvs_to_value_vars in models with nested transformations - # does not change the original rvs in place. 
See issue #5172 - with pm.Model() as m: - one = pm.LogNormal("one", mu=0) - two = pm.LogNormal("two", mu=at.log(one)) - - # We add potentials or deterministics that are not in topological order - pm.Potential("two_pot", two) - pm.Potential("one_pot", one) - - before = aesara.clone_replace(m.free_RVs) - - # This call would change the model free_RVs in place in #5172 - res = rvs_to_value_vars(m.potentials, apply_transforms=True) - - after = aesara.clone_replace(m.free_RVs) - - assert equal_computations(before, after) - - -def test_rvs_to_value_vars_unvalued_rv(): - with pm.Model() as m: - x = pm.Normal("x") - y = pm.Normal.dist(x) - z = pm.Normal("z", y) - out = z + y - - x_value = m.rvs_to_values[x] - z_value = m.rvs_to_values[z] - - (res,) = rvs_to_value_vars((out,)) - - assert res.owner.op == at.add - assert res.owner.inputs[0] is z_value - res_y = res.owner.inputs[1] - # Graph should have be cloned, and therefore y and res_y should have different ids - assert res_y is not y - assert res_y.owner.op == at.random.normal - assert res_y.owner.inputs[3] is x_value - - class TestCompilePyMC: def test_check_bounds_flag(self): """Test that CheckParameterValue Ops are replaced or removed when using compile_pymc""" @@ -633,3 +531,196 @@ def test_constant_fold_raises(): res = constant_fold((y, y.shape), raise_not_constant=False) assert tuple(res[1].eval()) == (5,) + + +class TestReplaceRVsByValues: + @pytest.mark.parametrize("symbolic_rv", (False, True)) + @pytest.mark.parametrize("apply_transforms", (True, False)) + @pytest.mark.parametrize("test_deprecated_fn", (True, False)) + def test_basic(self, symbolic_rv, apply_transforms, test_deprecated_fn): + + # Interval transform between last two arguments + interval = ( + Interval(bounds_fn=lambda *args: (args[-2], args[-1])) if apply_transforms else None + ) + + with pm.Model() as m: + a = pm.Uniform("a", 0.0, 1.0) + if symbolic_rv: + raw_b = pm.Uniform.dist(0, a + 1.0) + b = pm.Censored("b", raw_b, lower=0, upper=a + 1.0, transform=interval) + # If not True, another distribution has to be used + assert isinstance(b.owner.op, SymbolicRandomVariable) + else: + b = pm.Uniform("b", 0, a + 1.0, transform=interval) + c = pm.Normal("c") + d = at.log(c + b) + 2.0 + + a_value_var = m.rvs_to_values[a] + assert m.rvs_to_transforms[a] is not None + + b_value_var = m.rvs_to_values[b] + c_value_var = m.rvs_to_values[c] + + if test_deprecated_fn: + with pytest.warns(FutureWarning, match="Use model.replace_rvs_by_values instead"): + (res,) = rvs_to_value_vars((d,), apply_transforms=apply_transforms) + else: + (res,) = replace_rvs_by_values( + (d,), + rvs_to_values=m.rvs_to_values, + rvs_to_transforms=m.rvs_to_transforms, + ) + + assert res.owner.op == at.add + log_output = res.owner.inputs[0] + assert log_output.owner.op == at.log + log_add_output = res.owner.inputs[0].owner.inputs[0] + assert log_add_output.owner.op == at.add + c_output = log_add_output.owner.inputs[0] + + # We make sure that the random variables were replaced + # with their value variables + assert c_output == c_value_var + b_output = log_add_output.owner.inputs[1] + # When transforms are applied, the input is the back-transformation of the value_var, + # otherwise it is the value_var itself + if apply_transforms: + assert b_output != b_value_var + else: + assert b_output == b_value_var + + res_ancestors = list(walk_model((res,))) + res_rv_ancestors = [ + v for v in res_ancestors if v.owner and isinstance(v.owner.op, RandomVariable) + ] + + # There shouldn't be any `RandomVariable`s in the 
+        assert len(res_rv_ancestors) == 0
+        assert b_value_var in res_ancestors
+        assert c_value_var in res_ancestors
+        # When transforms are used, `d` depends on `a` through the back-transformation of
+        # `b`, otherwise there is no direct connection between `d` and `a`
+        if apply_transforms:
+            assert a_value_var in res_ancestors
+        else:
+            assert a_value_var not in res_ancestors
+
+    @pytest.mark.parametrize("test_deprecated_fn", (True, False))
+    def test_unvalued_rv(self, test_deprecated_fn):
+        with pm.Model() as m:
+            x = pm.Normal("x")
+            y = pm.Normal.dist(x)
+            z = pm.Normal("z", y)
+            out = z + y
+
+        x_value = m.rvs_to_values[x]
+        z_value = m.rvs_to_values[z]
+
+        if test_deprecated_fn:
+            with pytest.warns(FutureWarning, match="Use model.replace_rvs_by_values instead"):
+                (res,) = rvs_to_value_vars((out,))
+        else:
+            (res,) = replace_rvs_by_values(
+                (out,),
+                rvs_to_values=m.rvs_to_values,
+                rvs_to_transforms=m.rvs_to_transforms,
+            )
+
+        assert res.owner.op == at.add
+        assert res.owner.inputs[0] is z_value
+        res_y = res.owner.inputs[1]
+        # Graph should have been cloned, and therefore y and res_y should have different ids
+        assert res_y is not y
+        assert res_y.owner.op == at.random.normal
+        assert res_y.owner.inputs[3] is x_value
+
+    @pytest.mark.parametrize("test_deprecated_fn", (True, False))
+    def test_no_change_inplace(self, test_deprecated_fn):
+        # Test that calling rvs_to_value_vars in models with nested transformations
+        # does not change the original rvs in place. See issue #5172
+        with pm.Model() as m:
+            one = pm.LogNormal("one", mu=0)
+            two = pm.LogNormal("two", mu=at.log(one))
+
+            # We add potentials or deterministics that are not in topological order
+            pm.Potential("two_pot", two)
+            pm.Potential("one_pot", one)
+
+        before = aesara.clone_replace(m.free_RVs)
+
+        # This call would change the model free_RVs in place in #5172
+        if test_deprecated_fn:
+            with pytest.warns(FutureWarning, match="Use model.replace_rvs_by_values instead"):
+                rvs_to_value_vars(m.potentials)
+        else:
+            replace_rvs_by_values(
+                m.potentials,
+                rvs_to_values=m.rvs_to_values,
+                rvs_to_transforms=m.rvs_to_transforms,
+            )
+
+        after = aesara.clone_replace(m.free_RVs)
+        assert equal_computations(before, after)
+
+    @pytest.mark.parametrize("test_deprecated_fn", (True, False))
+    @pytest.mark.parametrize("reversed", (False, True))
+    def test_interdependent_transformed_rvs(self, reversed, test_deprecated_fn):
+        # Test that nested transformed variables, whose transformed values depend on other
+        # RVs are properly replaced
+        with pm.Model() as m:
+            transform = pm.distributions.transforms.Interval(
+                bounds_fn=lambda *inputs: (inputs[-2], inputs[-1])
+            )
+            x = pm.Uniform("x", lower=0, upper=1, transform=transform)
+            y = pm.Uniform("y", lower=0, upper=x, transform=transform)
+            z = pm.Uniform("z", lower=0, upper=y, transform=transform)
+            w = pm.Uniform("w", lower=0, upper=z, transform=transform)
+
+        rvs = [x, y, z, w]
+        if reversed:
+            rvs = rvs[::-1]
+
+        if test_deprecated_fn:
+            with pytest.warns(FutureWarning, match="Use model.replace_rvs_by_values instead"):
+                transform_values = rvs_to_value_vars(rvs)
+        else:
+            transform_values = replace_rvs_by_values(
+                rvs,
+                rvs_to_values=m.rvs_to_values,
+                rvs_to_transforms=m.rvs_to_transforms,
+            )
+
+        for transform_value in transform_values:
+            assert_no_rvs(transform_value)
+
+        if reversed:
+            transform_values = transform_values[::-1]
+        transform_values_fn = m.compile_fn(transform_values, point_fn=False)
+
+        x_interval_test_value = np.random.rand()
+        y_interval_test_value =
np.random.rand() + z_interval_test_value = np.random.rand() + w_interval_test_value = np.random.rand() + + # The 3 Nones correspond to unused rng, dtype and size arguments + expected_x = transform.backward(x_interval_test_value, None, None, None, 0, 1).eval() + expected_y = transform.backward( + y_interval_test_value, None, None, None, 0, expected_x + ).eval() + expected_z = transform.backward( + z_interval_test_value, None, None, None, 0, expected_y + ).eval() + expected_w = transform.backward( + w_interval_test_value, None, None, None, 0, expected_z + ).eval() + + np.testing.assert_allclose( + transform_values_fn( + x_interval__=x_interval_test_value, + y_interval__=y_interval_test_value, + z_interval__=z_interval_test_value, + w_interval__=w_interval_test_value, + ), + [expected_x, expected_y, expected_z, expected_w], + ) diff --git a/pymc/tests/test_data.py b/pymc/tests/test_data.py index c294ef3c3c..6983f3e342 100644 --- a/pymc/tests/test_data.py +++ b/pymc/tests/test_data.py @@ -621,12 +621,12 @@ def test_gradient_with_scaling(self): genvar = pm.generator(gen1()) m = pm.Normal("m") pm.Normal("n", observed=genvar, total_size=1000) - grad1 = aesara.function([m.tag.value_var], at.grad(model1.logp(), m.tag.value_var)) + grad1 = model1.compile_fn(model1.dlogp(vars=m), point_fn=False) with pm.Model() as model2: m = pm.Normal("m") shavar = aesara.shared(np.ones((1000, 100))) pm.Normal("n", observed=shavar) - grad2 = aesara.function([m.tag.value_var], at.grad(model2.logp(), m.tag.value_var)) + grad2 = model2.compile_fn(model2.dlogp(vars=m), point_fn=False) for i in range(10): shavar.set_value(np.ones((100, 100)) * i) @@ -709,11 +709,11 @@ def test_mixed2(self): def test_free_rv(self): with pm.Model() as model4: pm.Normal("n", observed=[[1, 1], [1, 1]], total_size=[2, 2]) - p4 = aesara.function([], model4.logp()) + p4 = model4.compile_fn(model4.logp(), point_fn=False) with pm.Model() as model5: n = pm.Normal("n", total_size=[2, Ellipsis, 2], size=(2, 2)) - p5 = aesara.function([n.tag.value_var], model5.logp()) + p5 = model5.compile_fn(model5.logp(), point_fn=False) assert p4() == p5(pm.floatX([[1]])) assert p4() == p5(pm.floatX([[1, 1], [1, 1]])) diff --git a/pymc/tests/test_initial_point.py b/pymc/tests/test_initial_point.py index e170360b34..eeef25ec65 100644 --- a/pymc/tests/test_initial_point.py +++ b/pymc/tests/test_initial_point.py @@ -25,12 +25,12 @@ from pymc.initial_point import make_initial_point_fn, make_initial_point_fns_per_chain -def transform_fwd(rv, expected_untransformed): - return rv.tag.value_var.tag.transform.forward(expected_untransformed, *rv.owner.inputs).eval() +def transform_fwd(rv, expected_untransformed, model): + return model.rvs_to_transforms[rv].forward(expected_untransformed, *rv.owner.inputs).eval() -def transform_back(rv, transformed) -> np.ndarray: - return rv.tag.value_var.tag.transform.backward(transformed, *rv.owner.inputs).eval() +def transform_back(rv, transformed, model) -> np.ndarray: + return model.rvs_to_transforms[rv].backward(transformed, *rv.owner.inputs).eval() class TestInitvalAssignment: @@ -48,7 +48,7 @@ def test_new_warnings(self): with pytest.warns(FutureWarning, match="`testval` argument is deprecated"): rv = pm.Uniform("u", 0, 1, testval=0.75) initial_point = pmodel.initial_point(seed=0) - assert initial_point["u_interval__"] == transform_fwd(rv, 0.75) + assert initial_point["u_interval__"] == transform_fwd(rv, 0.75, model=pmodel) assert not hasattr(rv.tag, "test_value") pass @@ -163,7 +163,7 @@ def test_adds_jitter(self): # Moment of the 
HalfFlat is 1, but HalfFlat is log-transformed by default # so the transformed initial value with jitter will be zero plus a jitter between [-1, 1]. b_transformed = iv["B_log__"] - b_untransformed = transform_back(B, b_transformed) + b_untransformed = transform_back(B, b_transformed, model=pmodel) assert b_transformed != 0 assert -1 < b_transformed < 1 # C is centered on 0 + untransformed initval of B diff --git a/pymc/tests/test_model.py b/pymc/tests/test_model.py index 08c5d4cbee..e2c2dada95 100644 --- a/pymc/tests/test_model.py +++ b/pymc/tests/test_model.py @@ -28,6 +28,7 @@ import scipy.sparse as sps import scipy.stats as st +from aeppl.transforms import IntervalTransform from aesara.graph import graph_inputs from aesara.tensor import TensorVariable from aesara.tensor.random.op import RandomVariable @@ -39,10 +40,13 @@ from pymc import Deterministic, Potential from pymc.blocking import DictToArrayBijection, RaveledVars from pymc.distributions import Normal, transforms +from pymc.distributions.logprob import _joint_logp +from pymc.distributions.transforms import log from pymc.exceptions import ImputationWarning, ShapeError, ShapeWarning from pymc.model import Point, ValueGradFunction, modelcontext from pymc.tests.helpers import SeededTest from pymc.tests.models import simple_model +from pymc.util import _FutureWarningValidatingScratchpad class NewModel(pm.Model): @@ -495,10 +499,22 @@ def test_model_value_vars(): def test_model_var_maps(): with pm.Model() as model: a = pm.Uniform("a") - x = pm.Normal("x", a) + x = pm.Normal("x", a, total_size=5) + + assert set(model.rvs_to_values.keys()) == {a, x} + a_value = model.rvs_to_values[a] + x_value = model.rvs_to_values[x] + assert a_value.owner is None + assert x_value.owner is None + assert model.values_to_rvs == {a_value: a, x_value: x} + + assert set(model.rvs_to_transforms.keys()) == {a, x} + assert isinstance(model.rvs_to_transforms[a], IntervalTransform) + assert model.rvs_to_transforms[x] is None - assert model.rvs_to_values == {a: a.tag.value_var, x: x.tag.value_var} - assert model.values_to_rvs == {a.tag.value_var: a, x.tag.value_var: x} + assert set(model.rvs_to_total_sizes.keys()) == {a, x} + assert model.rvs_to_total_sizes[a] is None + assert model.rvs_to_total_sizes[x] == 5 def test_make_obs_var(): @@ -531,12 +547,12 @@ def test_make_obs_var(): dense_output = fake_model.make_obs_var(fake_distribution, dense_input, None, None) assert dense_output == fake_distribution - assert isinstance(dense_output.tag.observations, TensorConstant) + assert isinstance(fake_model.rvs_to_values[dense_output], TensorConstant) del fake_model.named_vars[fake_distribution.name] sparse_output = fake_model.make_obs_var(fake_distribution, sparse_input, None, None) assert sparse_output == fake_distribution - assert sparse.basic._is_sparse_variable(sparse_output.tag.observations) + assert sparse.basic._is_sparse_variable(fake_model.rvs_to_values[sparse_output]) del fake_model.named_vars[fake_distribution.name] # Here the RandomVariable is split into observed/imputed and a Deterministic is returned @@ -567,8 +583,7 @@ def test_initial_point(): with pytest.warns(FutureWarning), model: b = pm.Uniform("b", testval=b_initval) - b_value_var = model.rvs_to_values[b] - b_initval_trans = b_value_var.tag.transform.forward(b_initval, *b.owner.inputs).eval() + b_initval_trans = model.rvs_to_transforms[b].forward(b_initval, *b.owner.inputs).eval() y_initval = np.array(-2.4, dtype=aesara.config.floatX) @@ -635,12 +650,8 @@ def test_valid_start_point(self): b = 
pm.Uniform("b", lower=2.0, upper=3.0) start = { - "a_interval__": model.rvs_to_values[a] - .tag.transform.forward(0.3, *a.owner.inputs) - .eval(), - "b_interval__": model.rvs_to_values[b] - .tag.transform.forward(2.1, *b.owner.inputs) - .eval(), + "a_interval__": model.rvs_to_transforms[a].forward(0.3, *a.owner.inputs).eval(), + "b_interval__": model.rvs_to_transforms[b].forward(2.1, *b.owner.inputs).eval(), } model.check_start_vals(start) @@ -651,9 +662,7 @@ def test_invalid_start_point(self): start = { "a_interval__": np.nan, - "b_interval__": model.rvs_to_values[b] - .tag.transform.forward(2.1, *b.owner.inputs) - .eval(), + "b_interval__": model.rvs_to_transforms[b].forward(2.1, *b.owner.inputs).eval(), } with pytest.raises(pm.exceptions.SamplingError): model.check_start_vals(start) @@ -664,12 +673,8 @@ def test_invalid_variable_name(self): b = pm.Uniform("b", lower=2.0, upper=3.0) start = { - "a_interval__": model.rvs_to_values[a] - .tag.transform.forward(0.3, *a.owner.inputs) - .eval(), - "b_interval__": model.rvs_to_values[b] - .tag.transform.forward(2.1, *b.owner.inputs) - .eval(), + "a_interval__": model.rvs_to_transforms[a].forward(0.3, *a.owner.inputs).eval(), + "b_interval__": model.rvs_to_transforms[b].forward(2.1, *b.owner.inputs).eval(), "c": 1.0, } with pytest.raises(KeyError): @@ -1207,9 +1212,7 @@ def test_interval_missing_observations(self): assert "theta1_observed" in model.named_vars assert "theta1_missing_interval__" in model.named_vars - assert not hasattr( - model.rvs_to_values[model.named_vars["theta1_observed"]].tag, "transform" - ) + assert model.rvs_to_transforms[model.named_vars["theta1_observed"]] is None prior_trace = pm.sample_prior_predictive(return_inferencedata=False) @@ -1348,7 +1351,7 @@ def test_missing_symmetric(self): This would fail in a previous implementation because the two variables would be equivalent and one of them would be discarded during MergeOptimization while - buling the logp graph + building the logp graph """ with pm.Model() as m: with pytest.warns(ImputationWarning): @@ -1360,8 +1363,13 @@ def test_missing_symmetric(self): x_unobs_rv = m["x_missing"] x_unobs_vv = m.rvs_to_values[x_unobs_rv] - logp = pm.joint_logp([x_obs_rv, x_unobs_rv], {x_obs_rv: x_obs_vv, x_unobs_rv: x_unobs_vv}) - logp_inputs = list(graph_inputs([logp])) + logp = _joint_logp( + [x_obs_rv, x_unobs_rv], + rvs_to_values={x_obs_rv: x_obs_vv, x_unobs_rv: x_unobs_vv}, + rvs_to_transforms={}, + rvs_to_total_sizes={}, + ) + logp_inputs = list(graph_inputs(logp)) assert x_obs_vv in logp_inputs assert x_unobs_vv in logp_inputs @@ -1400,3 +1408,59 @@ def test_deterministic(self): assert np.all( np.isclose(model.compile_logp(sum=False)({}), st.norm().logpdf(data_values)) ) + + +def test_tag_future_warning_model(): + # Test no unexpected warnings + with warnings.catch_warnings(): + warnings.simplefilter("error") + + model = pm.Model() + + x = at.random.normal() + x.tag.something_else = "5" + x.tag.test_value = 0 + assert not isinstance(x.tag, _FutureWarningValidatingScratchpad) + + # Test that model changes the tag type, but copies exsiting contents + x = model.register_rv(x, name="x", transform=log) + assert isinstance(x.tag, _FutureWarningValidatingScratchpad) + assert x.tag.something_else == "5" + assert x.tag.test_value == 0 + + # Test expected warnings + with pytest.warns(FutureWarning, match="model.rvs_to_values"): + x_value = x.tag.value_var + + assert isinstance(x_value.tag, _FutureWarningValidatingScratchpad) + with pytest.warns(FutureWarning, 
match="model.rvs_to_transforms"): + transform = x_value.tag.transform + assert transform is log + + with pytest.raises(AttributeError): + x.tag.observations + + with pytest.warns(FutureWarning, match="model.rvs_to_total_sizes"): + total_size = x.tag.total_size + assert total_size is None + + # Cloning a node will keep the same tag type and contents + y = x.owner.clone().default_output() + assert y is not x + assert y.tag is not x.tag + assert isinstance(y.tag, _FutureWarningValidatingScratchpad) + y = model.register_rv(y, name="y", data=5) + assert isinstance(y.tag, _FutureWarningValidatingScratchpad) + + # Test expected warnings + with pytest.warns(FutureWarning, match="model.rvs_to_values"): + y_value = y.tag.value_var + with pytest.warns(FutureWarning, match="model.rvs_to_values"): + y_obs = y.tag.observations + assert y_value is y_obs + assert y_value.eval() == 5 + + assert isinstance(y_value.tag, _FutureWarningValidatingScratchpad) + with pytest.warns(FutureWarning, match="model.rvs_to_total_sizes"): + total_size = y.tag.total_size + assert total_size is None diff --git a/pymc/tuning/starting.py b/pymc/tuning/starting.py index f9050e3dcf..7bce1f5ecc 100644 --- a/pymc/tuning/starting.py +++ b/pymc/tuning/starting.py @@ -104,10 +104,8 @@ def find_MAP( vars = get_value_vars_from_user_vars(vars, model) except ValueError as exc: # Accomodate case where user passed non-pure RV nodes - vars = pm.inputvars(pm.aesaraf.rvs_to_value_vars(vars)) + vars = pm.inputvars(model.replace_rvs_by_values(vars)) if vars: - # Make sure they belong to current model again... - vars = get_value_vars_from_user_vars(vars, model) warnings.warn( "Intermediate variables (such as Deterministic or Potential) were passed. " "find_MAP will optimize the underlying free_RVs instead.", diff --git a/pymc/util.py b/pymc/util.py index 64b3a376fa..010ab20baf 100644 --- a/pymc/util.py +++ b/pymc/util.py @@ -13,6 +13,7 @@ # limitations under the License. import functools +import warnings from typing import Any, Dict, List, Optional, Sequence, Tuple, Union, cast @@ -23,6 +24,7 @@ from aesara import Variable from aesara.compile import SharedVariable +from aesara.graph.utils import ValidatingScratchpad from cachetools import LRUCache, cachedmethod @@ -481,3 +483,32 @@ def get_value_vars_from_user_vars( ) return value_vars + + +class _FutureWarningValidatingScratchpad(ValidatingScratchpad): + def __getattribute__(self, name): + for deprecated_names, alternative in ( + (("value_var", "observations"), "model.rvs_to_values[rv]"), + (("transform",), "model.rvs_to_transforms[rv]"), + (("total_size",), "model.rvs_to_total_sizes[rv]"), + ): + if name in deprecated_names: + try: + super().__getattribute__(name) + except AttributeError: + pass + else: + warnings.warn( + f"The tag attribute {name} is deprecated. 
Use {alternative} instead", + FutureWarning, + ) + return super().__getattribute__(name) + + +def _add_future_warning_tag(var) -> None: + old_tag = var.tag + if not isinstance(old_tag, _FutureWarningValidatingScratchpad): + new_tag = _FutureWarningValidatingScratchpad("test_value", var.type.filter) + for k, v in old_tag.__dict__.items(): + new_tag.__dict__.setdefault(k, v) + var.tag = new_tag diff --git a/pymc/variational/opvi.py b/pymc/variational/opvi.py index be9e430462..1b3c331a8f 100644 --- a/pymc/variational/opvi.py +++ b/pymc/variational/opvi.py @@ -64,11 +64,11 @@ find_rng_nodes, identity, reseed_rngs, - rvs_to_value_vars, ) from pymc.backends.base import MultiTrace from pymc.backends.ndarray import NDArray from pymc.blocking import DictToArrayBijection +from pymc.distributions.logprob import _get_scaling from pymc.initial_point import make_initial_point_fn from pymc.model import modelcontext from pymc.util import ( @@ -1039,7 +1039,14 @@ def make_size_and_deterministic_replacements(self, s, d, more_replacements=None) @node_property def symbolic_normalizing_constant(self): """*Dev* - normalizing constant for `self.logq`, scales it to `minibatch_size` instead of `total_size`""" - t = self.to_flat_input(at.max([v.tag.scaling for v in self.group])) + t = self.to_flat_input( + at.max( + [ + _get_scaling(self.model.rvs_to_total_sizes.get(v, None), v.shape, v.ndim) + for v in self.group + ] + ) + ) t = self.symbolic_single_sample(t) return pm.floatX(t) @@ -1171,7 +1178,14 @@ def symbolic_normalizing_constant(self): """ t = at.max( self.collect("symbolic_normalizing_constant") - + [var.tag.scaling for var in self.model.observed_RVs] + + [ + _get_scaling( + self.model.rvs_to_total_sizes.get(obs, None), + obs.shape, + obs.ndim, + ) + for obs in self.model.observed_RVs + ] ) t = at.switch(self._scale_cost_to_minibatch, t, at.constant(1, dtype=t.dtype)) return pm.floatX(t) @@ -1391,7 +1405,7 @@ def sample_node(self, node, size=None, deterministic=False, more_replacements=No node = aesara.clone_replace(node, more_replacements) if not isinstance(node, (list, tuple)): node = [node] - node = rvs_to_value_vars(node) + node = self.model.replace_rvs_by_values(node) if not isinstance(node_in, (list, tuple)): node = node[0] if size is None:
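Illustrative sketch of the tag deprecation path exercised above (the model and variable below are made up; the mappings are the ones the new FutureWarning messages point to):

    import warnings

    import pymc as pm

    with pm.Model() as m:
        x = pm.HalfNormal("x")

    # Preferred lookups after this change
    x_value = m.rvs_to_values[x]            # e.g. the "x_log__" value variable
    x_transform = m.rvs_to_transforms[x]    # the log transform used by HalfNormal
    x_total_size = m.rvs_to_total_sizes[x]  # None unless total_size was passed

    # The old tag attributes still resolve, but now emit a FutureWarning
    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always")
        legacy_value = x.tag.value_var
    assert any(issubclass(w.category, FutureWarning) for w in caught)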