diff --git a/pymc/aesaraf.py b/pymc/aesaraf.py index f9ce65ff82..4aa5981f03 100644 --- a/pymc/aesaraf.py +++ b/pymc/aesaraf.py @@ -11,8 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import warnings - from typing import ( Callable, Dict, @@ -147,32 +145,6 @@ def dataframe_to_tensor_variable(df: pd.DataFrame, *args, **kwargs) -> TensorVar return at.as_tensor_variable(df.to_numpy(), *args, **kwargs) -def extract_rv_and_value_vars( - var: TensorVariable, -) -> Tuple[TensorVariable, TensorVariable]: - """Return a random variable and it's observations or value variable, or ``None``. - - Parameters - ========== - var - A variable corresponding to a ``RandomVariable``. - - Returns - ======= - The first value in the tuple is the ``RandomVariable``, and the second is the - measure/log-likelihood value variable that corresponds with the latter. - - """ - if not var.owner: - return None, None - - if isinstance(var.owner.op, RandomVariable): - rv_value = getattr(var.tag, "observations", getattr(var.tag, "value_var", None)) - return var, rv_value - - return None, None - - def extract_obs_data(x: TensorVariable) -> np.ndarray: """Extract data from observed symbolic variables. @@ -200,20 +172,15 @@ def extract_obs_data(x: TensorVariable) -> np.ndarray: def walk_model( graphs: Iterable[TensorVariable], - walk_past_rvs: bool = False, stop_at_vars: Optional[Set[TensorVariable]] = None, expand_fn: Callable[[TensorVariable], Iterable[TensorVariable]] = lambda var: [], ) -> Generator[TensorVariable, None, None]: """Walk model graphs and yield their nodes. - By default, these walks will not go past ``RandomVariable`` nodes. - Parameters ========== graphs The graphs to walk. - walk_past_rvs - If ``True``, the walk will not terminate at ``RandomVariable``s. stop_at_vars A list of variables at which the walk will terminate. 
expand_fn @@ -225,16 +192,12 @@ def walk_model( def expand(var): new_vars = expand_fn(var) - if ( - var.owner - and (walk_past_rvs or not isinstance(var.owner.op, RandomVariable)) - and (var not in stop_at_vars) - ): + if var.owner and var not in stop_at_vars: new_vars.extend(reversed(var.owner.inputs)) return new_vars - yield from walk(graphs, expand, False) + yield from walk(graphs, expand, bfs=False) def replace_rvs_in_graphs( @@ -263,7 +226,11 @@ def replace_rvs_in_graphs( def expand_replace(var): new_nodes = [] - if var.owner and isinstance(var.owner.op, RandomVariable): + if var.owner: + # Call replacement_fn to update replacements dict inplace and, optionally, + # specify new nodes that should also be walked for replacements. This + # includes `value` variables that are not simple input variables, and may + # contain other `random` variables in their graphs (e.g., IntervalTransform) new_nodes.extend(replacement_fn(var, replacements)) return new_nodes @@ -290,10 +257,10 @@ def expand_replace(var): def rvs_to_value_vars( graphs: Iterable[TensorVariable], - apply_transforms: bool = False, + apply_transforms: bool = True, initial_replacements: Optional[Dict[TensorVariable, TensorVariable]] = None, **kwargs, -) -> Tuple[TensorVariable, Dict[TensorVariable, TensorVariable]]: +) -> TensorVariable: """Clone and replace random variables in graphs with their value variables. This will *not* recompute test values in the resulting graphs. 
@@ -309,38 +276,30 @@ def rvs_to_value_vars( """ - # Avoid circular dependency - from pymc.distributions import NoDistribution - - def transform_replacements(var, replacements): - rv_var, rv_value_var = extract_rv_and_value_vars(var) - - if rv_value_var is None: - # If RandomVariable does not have a value_var and corresponds to - # a NoDistribution, we allow further replacements in upstream graph - if isinstance(rv_var.owner.op, NoDistribution): - return rv_var.owner.inputs + def populate_replacements( + random_var: TensorVariable, replacements: Dict[TensorVariable, TensorVariable] + ) -> List[TensorVariable]: + # Populate replacements dict with {rv: value} pairs indicating which graph + # RVs should be replaced by what value variables. - else: - warnings.warn( - f"No value variable found for {rv_var}; " - "the random variable will not be replaced." - ) - return [] + value_var = getattr( + random_var.tag, "observations", getattr(random_var.tag, "value_var", None) + ) - transform = getattr(rv_value_var.tag, "transform", None) + # No value variable to replace RV with + if value_var is None: + return [] - if transform is None or not apply_transforms: - replacements[var] = rv_value_var - # In case the value variable is itself a graph, we walk it for - # potential replacements - return [rv_value_var] + transform = getattr(value_var.tag, "transform", None) + if transform is not None and apply_transforms: + # We want to replace uses of the RV by the back-transformation of its value + value_var = transform.backward(value_var, *random_var.owner.inputs) - trans_rv_value = transform.backward(rv_value_var, *rv_var.owner.inputs) - replacements[var] = trans_rv_value + replacements[random_var] = value_var - # Walk the transformed variable and make replacements - return [trans_rv_value] + # Also walk the graph of the value variable to make any additional replacements + # if that is not a simple input variable + return [value_var] # Clone original graphs inputs = [i for i in 
graph_inputs(graphs) if not isinstance(i, Constant)] @@ -352,7 +311,14 @@ def transform_replacements(var, replacements): equiv.get(k, k): equiv.get(v, v) for k, v in initial_replacements.items() } - return replace_rvs_in_graphs(graphs, transform_replacements, initial_replacements, **kwargs) + graphs, _ = replace_rvs_in_graphs( + graphs, + replacement_fn=populate_replacements, + initial_replacements=initial_replacements, + **kwargs, + ) + + return graphs def inputvars(a): diff --git a/pymc/gp/util.py b/pymc/gp/util.py index f82ee7f2bf..a25a393e77 100644 --- a/pymc/gp/util.py +++ b/pymc/gp/util.py @@ -57,7 +57,7 @@ def replace_with_values(vars_needed, replacements=None, model=None): model = modelcontext(model) inputs, input_names = [], [] - for rv in walk_model(vars_needed, walk_past_rvs=True): + for rv in walk_model(vars_needed): if rv in model.named_vars.values() and not isinstance(rv, SharedVariable): inputs.append(rv) input_names.append(rv.name) diff --git a/pymc/model.py b/pymc/model.py index 037385d534..cd025fece3 100644 --- a/pymc/model.py +++ b/pymc/model.py @@ -761,7 +761,7 @@ def logp( # Replace random variables by their value variables in potential terms potential_logps = [] if potentials: - potential_logps, _ = rvs_to_value_vars(potentials, apply_transforms=True) + potential_logps = rvs_to_value_vars(potentials) logp_factors = [None] * len(varlist) for logp_order, logp in zip((rv_order + potential_order), (rv_logps + potential_logps)): @@ -935,7 +935,7 @@ def potentiallogp(self) -> Variable: """Aesara scalar of log-probability of the Potential terms""" # Convert random variables in Potential expression into their log-likelihood # inputs and apply their transforms, if any - potentials, _ = rvs_to_value_vars(self.potentials, apply_transforms=True) + potentials = rvs_to_value_vars(self.potentials) if potentials: return at.sum([at.sum(factor) for factor in potentials]) else: @@ -976,10 +976,10 @@ def unobserved_value_vars(self): vars.append(value_var) # 
Remove rvs from untransformed values graph - untransformed_vars, _ = rvs_to_value_vars(untransformed_vars, apply_transforms=True) + untransformed_vars = rvs_to_value_vars(untransformed_vars) # Remove rvs from deterministics graph - deterministics, _ = rvs_to_value_vars(self.deterministics, apply_transforms=True) + deterministics = rvs_to_value_vars(self.deterministics) return vars + untransformed_vars + deterministics diff --git a/pymc/step_methods/metropolis.py b/pymc/step_methods/metropolis.py index e036b0f667..68909c018c 100644 --- a/pymc/step_methods/metropolis.py +++ b/pymc/step_methods/metropolis.py @@ -583,7 +583,7 @@ def __init__(self, vars, proposal="uniform", order="random", model=None): if isinstance(distr, CategoricalRV): k_graph = rv_var.owner.inputs[3].shape[-1] - (k_graph,), _ = rvs_to_value_vars((k_graph,), apply_transforms=True) + (k_graph,) = rvs_to_value_vars((k_graph,)) k = model.compile_fn(k_graph, inputs=model.value_vars, on_unused_input="ignore")( initial_point ) diff --git a/pymc/tests/distributions/test_logprob.py b/pymc/tests/distributions/test_logprob.py index 26bc0c4ff7..4212b4baa7 100644 --- a/pymc/tests/distributions/test_logprob.py +++ b/pymc/tests/distributions/test_logprob.py @@ -129,7 +129,7 @@ def test_joint_logp_basic(): with pytest.warns(FutureWarning): b_logpt = joint_logpt(b, b_value_var, sum=False) - res_ancestors = list(walk_model(b_logp, walk_past_rvs=True)) + res_ancestors = list(walk_model(b_logp)) res_rv_ancestors = [ v for v in res_ancestors if v.owner and isinstance(v.owner.op, RandomVariable) ] diff --git a/pymc/tests/test_aesaraf.py b/pymc/tests/test_aesaraf.py index bd618372b2..8fc641d01c 100644 --- a/pymc/tests/test_aesaraf.py +++ b/pymc/tests/test_aesaraf.py @@ -44,6 +44,7 @@ ) from pymc.distributions.dist_math import check_parameters from pymc.distributions.distribution import SymbolicRandomVariable +from pymc.distributions.transforms import Interval from pymc.vartypes import int_types @@ -271,33 +272,53 @@ def 
test_pandas_to_array_pandas_index(): def test_walk_model(): - d = at.vector("d") - b = at.vector("b") - c = uniform(0.0, d) + a = at.vector("a") + b = uniform(0.0, a, name="b") + c = at.log(b) c.name = "c" - e = at.log(c) - a = normal(e, b) - a.name = "a" + d = at.vector("d") + e = normal(c, d, name="e") + + test_graph = at.exp(e + 1) - test_graph = at.exp(a + 1) res = list(walk_model((test_graph,))) assert a in res - assert c not in res + assert b in res + assert c in res + assert d in res + assert e in res - res = list(walk_model((test_graph,), walk_past_rvs=True)) - assert a in res + res = list(walk_model((test_graph,), stop_at_vars={c})) + assert a not in res + assert b not in res assert c in res + assert d in res + assert e in res - res = list(walk_model((test_graph,), walk_past_rvs=True, stop_at_vars={e})) - assert a in res - assert c not in res + res = list(walk_model((test_graph,), stop_at_vars={b})) + assert a not in res + assert b in res + assert c in res + assert d in res + assert e in res -def test_rvs_to_value_vars(): +@pytest.mark.parametrize("symbolic_rv", (False, True)) +@pytest.mark.parametrize("apply_transforms", (True, False)) +def test_rvs_to_value_vars(symbolic_rv, apply_transforms): + + # Interval transform between last two arguments + interval = Interval(bounds_fn=lambda *args: (args[-2], args[-1])) with pm.Model() as m: a = pm.Uniform("a", 0.0, 1.0) - b = pm.Uniform("b", 0, a + 1.0) + if symbolic_rv: + raw_b = pm.Uniform.dist(0, a + 1.0) + b = pm.Censored("b", raw_b, lower=0, upper=a + 1.0, transform=interval) + # If not True, another distribution has to be used + assert isinstance(b.owner.op, SymbolicRandomVariable) + else: + b = pm.Uniform("b", 0, a + 1.0, transform=interval) c = pm.Normal("c") d = at.log(c + b) + 2.0 @@ -307,7 +328,7 @@ def test_rvs_to_value_vars(): b_value_var = m.rvs_to_values[b] c_value_var = m.rvs_to_values[c] - (res,), replaced = rvs_to_value_vars((d,)) + (res,) = rvs_to_value_vars((d,), 
apply_transforms=apply_transforms) assert res.owner.op == at.add log_output = res.owner.inputs[0] @@ -320,9 +341,14 @@ def test_rvs_to_value_vars(): # with their value variables assert c_output == c_value_var b_output = log_add_output.owner.inputs[1] - assert b_output == b_value_var + # When transforms are applied, the input is the back-transformation of the value_var, + # otherwise it is the value_var itself + if apply_transforms: + assert b_output != b_value_var + else: + assert b_output == b_value_var - res_ancestors = list(walk_model((res,), walk_past_rvs=True)) + res_ancestors = list(walk_model((res,))) res_rv_ancestors = [ v for v in res_ancestors if v.owner and isinstance(v.owner.op, RandomVariable) ] @@ -331,19 +357,12 @@ def test_rvs_to_value_vars(): assert len(res_rv_ancestors) == 0 assert b_value_var in res_ancestors assert c_value_var in res_ancestors - assert a_value_var not in res_ancestors - - (res,), replaced = rvs_to_value_vars((d,), apply_transforms=True) - - res_ancestors = list(walk_model((res,), walk_past_rvs=True)) - res_rv_ancestors = [ - v for v in res_ancestors if v.owner and isinstance(v.owner.op, RandomVariable) - ] - - assert len(res_rv_ancestors) == 0 - assert a_value_var in res_ancestors - assert b_value_var in res_ancestors - assert c_value_var in res_ancestors + # When transforms are used, `d` depends on `a` through the back-transformation of + # `b`, otherwise there is no direct connection between `d` and `a` + if apply_transforms: + assert a_value_var in res_ancestors + else: + assert a_value_var not in res_ancestors def test_rvs_to_value_vars_nested(): @@ -360,13 +379,34 @@ def test_rvs_to_value_vars_nested(): before = aesara.clone_replace(m.free_RVs) # This call would change the model free_RVs in place in #5172 - res, _ = rvs_to_value_vars(m.potentials, apply_transforms=True) + res = rvs_to_value_vars(m.potentials, apply_transforms=True) after = aesara.clone_replace(m.free_RVs) assert equal_computations(before, after) +def 
test_rvs_to_value_vars_unvalued_rv(): + with pm.Model() as m: + x = pm.Normal("x") + y = pm.Normal.dist(x) + z = pm.Normal("z", y) + out = z + y + + x_value = m.rvs_to_values[x] + z_value = m.rvs_to_values[z] + + (res,) = rvs_to_value_vars((out,)) + + assert res.owner.op == at.add + assert res.owner.inputs[0] is z_value + res_y = res.owner.inputs[1] + # Graph should have been cloned, and therefore y and res_y should have different ids + assert res_y is not y + assert res_y.owner.op == at.random.normal + assert res_y.owner.inputs[3] is x_value + + class TestCompilePyMC: def test_check_bounds_flag(self): """Test that CheckParameterValue Ops are replaced or removed when using compile_pymc""" diff --git a/pymc/tests/test_sampling.py b/pymc/tests/test_sampling.py index 89b3de7269..a004c486d2 100644 --- a/pymc/tests/test_sampling.py +++ b/pymc/tests/test_sampling.py @@ -39,6 +39,7 @@ from pymc.aesaraf import compile_pymc from pymc.backends.base import MultiTrace from pymc.backends.ndarray import NDArray +from pymc.distributions import transforms from pymc.exceptions import IncorrectArgumentsError, SamplingError from pymc.sampling import _get_seeds_per_chain, compile_forward_sampling_function from pymc.tests.helpers import SeededTest, fast_unstable_sampling_mode @@ -408,9 +409,15 @@ def test_exceptions(self): with pytest.raises(NotImplementedError): xvars = [t["mu"] for t in trace] - def test_deterministic_of_unobserved(self): + @pytest.mark.parametrize("symbolic_rv", (False, True)) + def test_deterministic_of_unobserved(self, symbolic_rv): with pm.Model() as model: - x = pm.HalfNormal("x", 1) + if symbolic_rv: + x = pm.Censored( + "x", pm.HalfNormal.dist(1), lower=None, upper=10, transform=transforms.log + ) + else: + x = pm.HalfNormal("x", 1) y = pm.Deterministic("y", x + 100) idata = pm.sample( chains=1, @@ -421,10 +428,15 @@ def test_deterministic_of_unobserved(self): np.testing.assert_allclose(idata.posterior["y"], idata.posterior["x"] + 100) - def 
test_transform_with_rv_dependency(self): + @pytest.mark.parametrize("symbolic_rv", (False, True)) + def test_transform_with_rv_dependency(self, symbolic_rv): # Test that untransformed variables that depend on upstream variables are properly handled with pm.Model() as m: - x = pm.HalfNormal("x", observed=1) + if symbolic_rv: + x = pm.Censored("x", pm.HalfNormal.dist(1), lower=0, upper=1, observed=1) + else: + x = pm.HalfNormal("x", observed=1) + transform = pm.distributions.transforms.Interval( bounds_fn=lambda *inputs: (inputs[-2], inputs[-1]) ) diff --git a/pymc/variational/opvi.py b/pymc/variational/opvi.py index e90c3b9be5..747b5582f4 100644 --- a/pymc/variational/opvi.py +++ b/pymc/variational/opvi.py @@ -1386,7 +1386,7 @@ def sample_node(self, node, size=None, deterministic=False, more_replacements=No node = aesara.clone_replace(node, more_replacements) if not isinstance(node, (list, tuple)): node = [node] - node, _ = rvs_to_value_vars(node, apply_transforms=True) + node = rvs_to_value_vars(node) if not isinstance(node_in, (list, tuple)): node = node[0] if size is None: