Do not rely on tag information for rv and logp conversions #6281

Merged · 6 commits · Nov 11, 2022
98 changes: 82 additions & 16 deletions pymc/aesaraf.py
@@ -11,6 +11,8 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import warnings

from typing import (
Callable,
Dict,
@@ -31,6 +33,7 @@
import scipy.sparse as sps

from aeppl.logprob import CheckParameterValue
from aeppl.transforms import RVTransform
from aesara import scalar
from aesara.compile.mode import Mode, get_mode
from aesara.gradient import grad
@@ -205,10 +208,9 @@ def expand(var):
yield from walk(graphs, expand, bfs=False)


def replace_rvs_in_graphs(
def _replace_rvs_in_graphs(
graphs: Iterable[TensorVariable],
replacement_fn: Callable[[TensorVariable], Dict[TensorVariable, TensorVariable]],
initial_replacements: Optional[Dict[TensorVariable, TensorVariable]] = None,
**kwargs,
) -> Tuple[List[TensorVariable], Dict[TensorVariable, TensorVariable]]:
"""Replace random variables in graphs
@@ -226,8 +228,6 @@ def replace_rvs_in_graphs(
that were made.
"""
replacements = {}
if initial_replacements:
replacements.update(initial_replacements)

def expand_replace(var):
new_nodes = []
@@ -239,6 +239,7 @@ def expand_replace(var):
new_nodes.extend(replacement_fn(var, replacements))
return new_nodes

# This iteration populates the replacements
for var in walk_model(graphs, expand_fn=expand_replace, **kwargs):
pass

@@ -253,7 +254,15 @@ def expand_replace(var):
clone=False,
)

fg.replace_all(replacements.items(), import_missing=True)
# replacements have to be done in reverse topological order so that nested
# expressions get recursively replaced correctly
toposort = fg.toposort()
sorted_replacements = sorted(
tuple(replacements.items()),
key=lambda pair: toposort.index(pair[0].owner),
reverse=True,
)
fg.replace_all(sorted_replacements, import_missing=True)

graphs = list(fg.outputs)

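For context, here is a runnable sketch of the ordering logic above. The toy graph is illustrative and not part of the PR; the point is that `FunctionGraph.toposort()` lists apply nodes from inputs to outputs, so sorting replaced variables by that index in reverse applies outermost replacements before the nested expressions they contain.

```python
import aesara.tensor as at
from aesara.graph.fg import FunctionGraph

# Toy graph: `outer` contains the nested expression `inner`.
x = at.scalar("x")
inner = at.exp(x)
outer = inner + 1.0

fg = FunctionGraph(outputs=[outer], clone=False)
toposort = fg.toposort()

# Both variables are slated for replacement; `outer` must go first.
replacements = {outer: inner * 2.0, inner: at.log1p(x)}
ordered = sorted(
    replacements.items(),
    key=lambda pair: toposort.index(pair[0].owner),
    reverse=True,
)
fg.replace_all(ordered, import_missing=True)
```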
@@ -263,7 +272,6 @@ def expand_replace(var):
def rvs_to_value_vars(
graphs: Iterable[Variable],
apply_transforms: bool = True,
initial_replacements: Optional[Dict[Variable, Variable]] = None,
**kwargs,
) -> List[Variable]:
"""Clone and replace random variables in graphs with their value variables.
@@ -276,10 +284,11 @@
The graphs in which to perform the replacements.
apply_transforms
If ``True``, apply each value variable's transform.
initial_replacements
A ``dict`` containing the initial replacements to be made.

"""
warnings.warn(
"rvs_to_value_vars is deprecated. Use model.replace_rvs_by_values instead",
FutureWarning,
)

def populate_replacements(
random_var: TensorVariable, replacements: Dict[TensorVariable, TensorVariable]
@@ -311,15 +320,72 @@ def populate_replacements(
equiv = clone_get_equiv(inputs, graphs, False, False, {})
graphs = [equiv[n] for n in graphs]

if initial_replacements:
initial_replacements = {
equiv.get(k, k): equiv.get(v, v) for k, v in initial_replacements.items()
}

graphs, _ = replace_rvs_in_graphs(
graphs, _ = _replace_rvs_in_graphs(
graphs,
replacement_fn=populate_replacements,
initial_replacements=initial_replacements,
**kwargs,
)

return graphs


def replace_rvs_by_values(
graphs: Sequence[TensorVariable],
*,
rvs_to_values: Dict[TensorVariable, TensorVariable],
rvs_to_transforms: Dict[TensorVariable, RVTransform],
**kwargs,
) -> List[TensorVariable]:
"""Clone and replace random variables in graphs with their value variables.

This will *not* recompute test values in the resulting graphs.

Parameters
----------
graphs
The graphs in which to perform the replacements.
rvs_to_values
Mapping between the original graph RVs and respective value variables
rvs_to_transforms
Mapping between the original graph RVs and respective value transforms
"""

# Clone original graphs so that we don't modify variables in place
inputs = [i for i in graph_inputs(graphs) if not isinstance(i, Constant)]
equiv = clone_get_equiv(inputs, graphs, False, False, {})
graphs = [equiv[n] for n in graphs]

# Get needed mappings for equivalent cloned variables
equiv_rvs_to_values = {}
equiv_rvs_to_transforms = {}
for rv, value in rvs_to_values.items():
equiv_rv = equiv.get(rv, rv)
equiv_rvs_to_values[equiv_rv] = equiv.get(value, value)
equiv_rvs_to_transforms[equiv_rv] = rvs_to_transforms[rv]

def populate_replacements(rv, replacements):
# Populate replacements dict with {rv: value} pairs indicating which graph
# RVs should be replaced by what value variables.

# No value variable to replace RV with
value = equiv_rvs_to_values.get(rv, None)
if value is None:
return []

transform = equiv_rvs_to_transforms.get(rv, None)
if transform is not None:
# We want to replace uses of the RV by the back-transformation of its value
value = transform.backward(value, *rv.owner.inputs)
value.name = rv.name

replacements[rv] = value
# Also walk the graph of the value variable to make any additional
# replacements if that is not a simple input variable
return [value]

graphs, _ = _replace_rvs_in_graphs(
graphs,
replacement_fn=populate_replacements,
**kwargs,
)

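A hedged usage sketch for the new helper (the `rvs_to_transforms` model attribute is an assumption of this example; both mappings can also be assembled by hand):

```python
import pymc as pm
from pymc.aesaraf import replace_rvs_by_values

with pm.Model() as m:
    sigma = pm.HalfNormal("sigma")
    x = pm.Normal("x", 0.0, sigma)

# Rewrite an arbitrary expression on RVs into one on their
# (back-transformed) value variables, with no reliance on tag metadata.
[expr_on_values] = replace_rvs_by_values(
    [x + sigma],
    rvs_to_values=m.rvs_to_values,
    rvs_to_transforms=m.rvs_to_transforms,
)
```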
4 changes: 2 additions & 2 deletions pymc/backends/arviz.py
@@ -47,7 +47,7 @@ def find_observations(model: "Model") -> Dict[str, Var]:
"""If there are observations available, return them as a dictionary."""
observations = {}
for obs in model.observed_RVs:
aux_obs = getattr(obs.tag, "observations", None)
aux_obs = model.rvs_to_values.get(obs, None)
if aux_obs is not None:
try:
obs_data = extract_obs_data(aux_obs)
@@ -261,7 +261,7 @@ def log_likelihood_vals_point(self, point, var, log_like_fun):

if isinstance(var.owner.op, (AdvancedIncSubtensor, AdvancedIncSubtensor1)):
try:
obs_data = extract_obs_data(var.tag.observations)
obs_data = extract_obs_data(self.model.rvs_to_values[var])
except TypeError:
warnings.warn(f"Could not extract data from symbolic observation {var}")

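These two hunks are the heart of the PR: observed-value lookups move from tag metadata to the model-owned `rvs_to_values` mapping, which survives graph cloning. A minimal sketch of the new access pattern:

```python
import pymc as pm

with pm.Model() as model:
    y = pm.Normal("y", 0.0, 1.0, observed=[1.0, 2.0])

# Old, tag-based lookup (deprecated by this PR):
#   value = y.tag.observations
# New, explicit model-owned mapping:
value = model.rvs_to_values[y]
```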
2 changes: 0 additions & 2 deletions pymc/distributions/__init__.py
@@ -16,7 +16,6 @@
logcdf,
logp,
joint_logp,
joint_logpt,
)

from pymc.distributions.bound import Bound
@@ -199,7 +198,6 @@
"Censored",
"CAR",
"PolyaGamma",
"joint_logpt",
"joint_logp",
"logp",
"logcdf",
3 changes: 2 additions & 1 deletion pymc/distributions/distribution.py
@@ -49,7 +49,7 @@
shape_from_dims,
)
from pymc.printing import str_for_dist
from pymc.util import UNSET
from pymc.util import UNSET, _add_future_warning_tag
from pymc.vartypes import string_types

__all__ = [
@@ -371,6 +371,7 @@ def dist(
rv_out.logp = _make_nice_attr_error("rv.logp(x)", "pm.logp(rv, x)")
rv_out.logcdf = _make_nice_attr_error("rv.logcdf(x)", "pm.logcdf(rv, x)")
rv_out.random = _make_nice_attr_error("rv.random()", "pm.draw(rv)")
_add_future_warning_tag(rv_out)
return rv_out


122 changes: 75 additions & 47 deletions pymc/distributions/logprob.py
@@ -25,26 +25,18 @@
from aeppl.logprob import logcdf as logcdf_aeppl
from aeppl.logprob import logprob as logp_aeppl
from aeppl.tensor import MeasurableJoin
from aeppl.transforms import TransformValuesRewrite
from aeppl.transforms import RVTransform, TransformValuesRewrite
from aesara import tensor as at
from aesara.graph.basic import graph_inputs, io_toposort
from aesara.tensor.random.op import RandomVariable
from aesara.tensor.subtensor import (
AdvancedIncSubtensor,
AdvancedIncSubtensor1,
AdvancedSubtensor,
AdvancedSubtensor1,
IncSubtensor,
Subtensor,
)
from aesara.tensor.var import TensorVariable

from pymc.aesaraf import constant_fold, floatX

TOTAL_SIZE = Union[int, Sequence[int], None]

def _get_scaling(
total_size: Optional[Union[int, Sequence[int]]], shape, ndim: int
) -> TensorVariable:

def _get_scaling(total_size: TOTAL_SIZE, shape, ndim: int) -> TensorVariable:
"""
Gets scaling constant for logp.

@@ -112,22 +104,26 @@ def _get_scaling(
return at.as_tensor(coef, dtype=aesara.config.floatX)

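A quick numeric illustration of the scaling coefficient (a sketch; `_get_scaling` is private and its exact return is floatX-cast):

```python
import aesara
from pymc.distributions.logprob import _get_scaling

# 100 observed rows stand in for a population of 1000, so the minibatch
# logp is scaled by 1000 / 100 = 10 to match the full-data magnitude.
coef = _get_scaling(total_size=1000, shape=(100,), ndim=1)
print(aesara.function([], coef)())  # ~10.0
```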

subtensor_types = (
AdvancedIncSubtensor,
AdvancedIncSubtensor1,
AdvancedSubtensor,
AdvancedSubtensor1,
IncSubtensor,
Subtensor,
)

def _check_no_rvs(logp_terms: Sequence[TensorVariable]):
# Raise if there are unexpected RandomVariables in the logp graph
# Only SimulatorRVs are allowed
from pymc.distributions.simulator import SimulatorRV

def joint_logpt(*args, **kwargs):
warnings.warn(
"joint_logpt has been deprecated. Use joint_logp instead.",
FutureWarning,
)
return joint_logp(*args, **kwargs)
unexpected_rv_nodes = [
node
for node in aesara.graph.ancestors(logp_terms)
if (
node.owner
and isinstance(node.owner.op, RandomVariable)
and not isinstance(node.owner.op, SimulatorRV)
)
]
if unexpected_rv_nodes:
raise ValueError(
f"Random variables detected in the logp graph: {unexpected_rv_nodes}.\n"
"This can happen when DensityDist logp or Interval transform functions "
"reference nonlocal variables."
)

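For illustration, the failure mode this guard (and its error message) describes, as a hypothetical sketch: a `DensityDist` logp that closes over the RV itself, rather than its value variable, leaks a `RandomVariable` into the logp graph.

```python
import pymc as pm

with pm.Model() as m:
    a = pm.Normal("a")
    # The lambda references the RV `a` directly, so `a` ends up in the
    # logp graph and _check_no_rvs raises a ValueError.
    b = pm.DensityDist("b", logp=lambda value: value + a)

m.logp()  # ValueError: Random variables detected in the logp graph: ...
```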

def joint_logp(
@@ -169,6 +165,10 @@ def joint_logp(
Sum the log-likelihood or return each term as a separate list item.

"""
warnings.warn(
"joint_logp has been deprecated, use model.logp instead",
FutureWarning,
)
# TODO: In future when we drop support for tag.value_var most of the following
# logic can be removed and logp can just be a wrapper function that calls aeppl's
# joint_logprob directly.
@@ -241,33 +241,15 @@ def joint_logp(
**kwargs,
)

# Raise if there are unexpected RandomVariables in the logp graph
# Only SimulatorRVs are allowed
from pymc.distributions.simulator import SimulatorRV

unexpected_rv_nodes = [
node
for node in aesara.graph.ancestors(list(temp_logp_var_dict.values()))
if (
node.owner
and isinstance(node.owner.op, RandomVariable)
and not isinstance(node.owner.op, SimulatorRV)
)
]
if unexpected_rv_nodes:
raise ValueError(
f"Random variables detected in the logp graph: {unexpected_rv_nodes}.\n"
"This can happen when DensityDist logp or Interval transform functions "
"reference nonlocal variables."
)

# aeppl returns the logp for every single value term we provided to it. This includes
# the extra values we plugged in above, so we filter those we actually wanted in the
# same order they were given in.
logp_var_dict = {}
for value_var in rv_values.values():
logp_var_dict[value_var] = temp_logp_var_dict[value_var]

_check_no_rvs(list(logp_var_dict.values()))

if scaling:
for value_var in logp_var_dict.keys():
if value_var in rv_scalings:
@@ -281,6 +263,52 @@
return logp_var


def _joint_logp(
rvs: Sequence[TensorVariable],
*,
rvs_to_values: Dict[TensorVariable, TensorVariable],
rvs_to_transforms: Dict[TensorVariable, RVTransform],
jacobian: bool = True,
rvs_to_total_sizes: Dict[TensorVariable, TOTAL_SIZE],
**kwargs,
) -> List[TensorVariable]:
"""Thin wrapper around aeppl.factorized_joint_logprob, extended with PyMC specific
concerns such as transforms, jacobian, and scaling"""

transform_rewrite = None
values_to_transforms = {
rvs_to_values[rv]: transform
for rv, transform in rvs_to_transforms.items()
if transform is not None
}
if values_to_transforms:
# There seems to be an incorrect type hint in TransformValuesRewrite
transform_rewrite = TransformValuesRewrite(values_to_transforms) # type: ignore

temp_logp_terms = factorized_joint_logprob(
rvs_to_values,
extra_rewrites=transform_rewrite,
use_jacobian=jacobian,
**kwargs,
)

# aeppl returns a logp term for every value variable we provided. Filter the
# ones that were actually requested, preserving the order in which they were
# given.
logp_terms = {}
for rv in rvs:
value_var = rvs_to_values[rv]
logp_term = temp_logp_terms[value_var]
total_size = rvs_to_total_sizes.get(rv, None)
if total_size is not None:
scaling = _get_scaling(total_size, value_var.shape, value_var.ndim)
logp_term *= scaling
logp_terms[value_var] = logp_term

_check_no_rvs(list(logp_terms.values()))
return list(logp_terms.values())

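A hedged sketch of driving the new private helper directly (in practice the model machinery assembles these mappings; `rvs_to_transforms` as a model attribute is an assumption of this example):

```python
import pymc as pm
from pymc.distributions.logprob import _joint_logp

with pm.Model() as m:
    x = pm.HalfNormal("x")

# One logp term per requested RV; an empty total_size mapping means no
# minibatch scaling is applied.
[x_logp] = _joint_logp(
    [x],
    rvs_to_values=m.rvs_to_values,
    rvs_to_transforms=m.rvs_to_transforms,
    rvs_to_total_sizes={},
)
```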

def logp(rv: TensorVariable, value) -> TensorVariable:
"""Return the log-probability graph of a Random Variable"""
