Skip to content

Implement check_icdf helper to test icdf implementations #6583

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Mar 13, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 9 additions & 1 deletion pymc/distributions/continuous.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,8 @@ def polyagamma_cdf(*args, **kwargs):
from pymc.distributions import transforms
from pymc.distributions.dist_math import (
SplineWrapper,
check_icdf_parameters,
check_icdf_value,
check_parameters,
clipped_beta_rvs,
i0e,
Expand Down Expand Up @@ -532,7 +534,13 @@ def logcdf(value, mu, sigma):
)

def icdf(value, mu, sigma):
return mu + sigma * -np.sqrt(2.0) * at.erfcinv(2 * value)
res = mu + sigma * -np.sqrt(2.0) * at.erfcinv(2 * value)
res = check_icdf_value(res, value)
return check_icdf_parameters(
res,
sigma > 0,
msg="sigma > 0",
)


class TruncatedNormalRV(RandomVariable):
Expand Down
11 changes: 10 additions & 1 deletion pymc/distributions/discrete.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,8 @@
from pymc.distributions.dist_math import (
betaln,
binomln,
check_icdf_parameters,
check_icdf_value,
check_parameters,
factln,
log_diff_normal_cdf,
Expand Down Expand Up @@ -820,7 +822,14 @@ def logcdf(value, p):
)

def icdf(value, p):
return at.ceil(at.log1p(-value) / at.log1p(-p)).astype("int64")
res = at.ceil(at.log1p(-value) / at.log1p(-p)).astype("int64")
res = check_icdf_value(res, value)
return check_icdf_parameters(
res,
0 <= p,
p <= 1,
msg="0 <= p <= 1",
)


class HyperGeometric(Discrete):
Expand Down
39 changes: 32 additions & 7 deletions pymc/distributions/dist_math.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
"""
import warnings

from functools import partial
from typing import Iterable

import numpy as np
Expand Down Expand Up @@ -50,13 +51,21 @@
}


def check_parameters(logp: Variable, *conditions: Iterable[Variable], msg: str = ""):
"""
Wrap a log probability graph in a CheckParameterValue that asserts several
conditions are True. When conditions are not met a ParameterValueError assertion is
raised, with an optional custom message defined by `msg`
def check_parameters(
expr: Variable,
*conditions: Iterable[Variable],
msg: str = "",
can_be_replaced_by_ninf: bool = True,
):
"""Wrap an expression in a CheckParameterValue that asserts several conditions are met.

When conditions are not met a ParameterValueError assertion is raised,
with an optional custom message defined by `msg`.

When the flag `can_be_replaced_by_ninf` is True (default), PyMC is allowed to replace the
assertion by a switch(condition, expr, -inf). This is used for logp graphs!

Note that check_parameter should not be used to enforce the logic of the logp
Note that check_parameter should not be used to enforce the logic of the
expression under the normal parameter support as it can be disabled by the user via
check_bounds = False in pm.Model()
"""
Expand All @@ -65,7 +74,23 @@ def check_parameters(logp: Variable, *conditions: Iterable[Variable], msg: str =
cond if (cond is not True and cond is not False) else np.array(cond) for cond in conditions
]
all_true_scalar = at.all([at.all(cond) for cond in conditions_])
return CheckParameterValue(msg)(logp, all_true_scalar)

return CheckParameterValue(msg, can_be_replaced_by_ninf)(expr, all_true_scalar)


check_icdf_parameters = partial(check_parameters, can_be_replaced_by_ninf=False)


def check_icdf_value(expr: Variable, value: Variable) -> Variable:
"""Wrap icdf expression in nan switch for value."""
value = at.as_tensor_variable(value)
expr = at.switch(
at.and_(value >= 0, value <= 1),
expr,
np.nan,
)
expr.name = "0 <= value <= 1"
return expr


def logpow(x, m):
Expand Down
5 changes: 4 additions & 1 deletion pymc/logprob/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -210,8 +210,11 @@ class CheckParameterValue(CheckAndRaise):
Raises `ParameterValueError` if the check is not True.
"""

def __init__(self, msg=""):
__props__ = ("msg", "exc_type", "can_be_replaced_by_ninf")

def __init__(self, msg: str = "", can_be_replaced_by_ninf: bool = False):
super().__init__(ParameterValueError, msg)
self.can_be_replaced_by_ninf = can_be_replaced_by_ninf

def __str__(self):
return f"Check{{{self.msg}}}"
Expand Down
24 changes: 13 additions & 11 deletions pymc/pytensorf.py
Original file line number Diff line number Diff line change
Expand Up @@ -913,19 +913,21 @@ def local_remove_check_parameter(fgraph, node):

@node_rewriter(tracks=[CheckParameterValue])
def local_check_parameter_to_ninf_switch(fgraph, node):
if isinstance(node.op, CheckParameterValue):
logp_expr, *logp_conds = node.inputs
if len(logp_conds) > 1:
logp_cond = at.all(logp_conds)
else:
(logp_cond,) = logp_conds
out = at.switch(logp_cond, logp_expr, -np.inf)
out.name = node.op.msg
if not node.op.can_be_replaced_by_ninf:
return None

logp_expr, *logp_conds = node.inputs
if len(logp_conds) > 1:
logp_cond = at.all(logp_conds)
else:
(logp_cond,) = logp_conds
out = at.switch(logp_cond, logp_expr, -np.inf)
out.name = node.op.msg

if out.dtype != node.outputs[0].dtype:
out = at.cast(out, node.outputs[0].dtype)
if out.dtype != node.outputs[0].dtype:
out = at.cast(out, node.outputs[0].dtype)

return [out]
return [out]


pytensor.compile.optdb["canonicalize"].register(
Expand Down
Loading