diff --git a/pymc/distributions/continuous.py b/pymc/distributions/continuous.py index 63c93f4778..e029a1c6e4 100644 --- a/pymc/distributions/continuous.py +++ b/pymc/distributions/continuous.py @@ -79,6 +79,8 @@ def polyagamma_cdf(*args, **kwargs): from pymc.distributions import transforms from pymc.distributions.dist_math import ( SplineWrapper, + check_icdf_parameters, + check_icdf_value, check_parameters, clipped_beta_rvs, i0e, @@ -532,7 +534,13 @@ def logcdf(value, mu, sigma): ) def icdf(value, mu, sigma): - return mu + sigma * -np.sqrt(2.0) * at.erfcinv(2 * value) + res = mu + sigma * -np.sqrt(2.0) * at.erfcinv(2 * value) + res = check_icdf_value(res, value) + return check_icdf_parameters( + res, + sigma > 0, + msg="sigma > 0", + ) class TruncatedNormalRV(RandomVariable): diff --git a/pymc/distributions/discrete.py b/pymc/distributions/discrete.py index 12f57aa697..bb275b24ac 100644 --- a/pymc/distributions/discrete.py +++ b/pymc/distributions/discrete.py @@ -36,6 +36,8 @@ from pymc.distributions.dist_math import ( betaln, binomln, + check_icdf_parameters, + check_icdf_value, check_parameters, factln, log_diff_normal_cdf, @@ -820,7 +822,14 @@ def logcdf(value, p): ) def icdf(value, p): - return at.ceil(at.log1p(-value) / at.log1p(-p)).astype("int64") + res = at.ceil(at.log1p(-value) / at.log1p(-p)).astype("int64") + res = check_icdf_value(res, value) + return check_icdf_parameters( + res, + 0 <= p, + p <= 1, + msg="0 <= p <= 1", + ) class HyperGeometric(Discrete): diff --git a/pymc/distributions/dist_math.py b/pymc/distributions/dist_math.py index fbdea97440..0cf646990a 100644 --- a/pymc/distributions/dist_math.py +++ b/pymc/distributions/dist_math.py @@ -19,6 +19,7 @@ """ import warnings +from functools import partial from typing import Iterable import numpy as np @@ -50,13 +51,21 @@ } -def check_parameters(logp: Variable, *conditions: Iterable[Variable], msg: str = ""): - """ - Wrap a log probability graph in a CheckParameterValue that asserts several - conditions are True. When conditions are not met a ParameterValueError assertion is - raised, with an optional custom message defined by `msg` +def check_parameters( + expr: Variable, + *conditions: Iterable[Variable], + msg: str = "", + can_be_replaced_by_ninf: bool = True, +): + """Wrap an expression in a CheckParameterValue that asserts several conditions are met. + + When conditions are not met a ParameterValueError assertion is raised, + with an optional custom message defined by `msg`. + + When the flag `can_be_replaced_by_ninf` is True (default), PyMC is allowed to replace the + assertion by a switch(condition, expr, -inf). This is used for logp graphs! - Note that check_parameter should not be used to enforce the logic of the logp + Note that check_parameter should not be used to enforce the logic of the expression under the normal parameter support as it can be disabled by the user via check_bounds = False in pm.Model() """ @@ -65,7 +74,23 @@ def check_parameters(logp: Variable, *conditions: Iterable[Variable], msg: str = cond if (cond is not True and cond is not False) else np.array(cond) for cond in conditions ] all_true_scalar = at.all([at.all(cond) for cond in conditions_]) - return CheckParameterValue(msg)(logp, all_true_scalar) + + return CheckParameterValue(msg, can_be_replaced_by_ninf)(expr, all_true_scalar) + + +check_icdf_parameters = partial(check_parameters, can_be_replaced_by_ninf=False) + + +def check_icdf_value(expr: Variable, value: Variable) -> Variable: + """Wrap icdf expression in nan switch for value.""" + value = at.as_tensor_variable(value) + expr = at.switch( + at.and_(value >= 0, value <= 1), + expr, + np.nan, + ) + expr.name = "0 <= value <= 1" + return expr def logpow(x, m): diff --git a/pymc/logprob/utils.py b/pymc/logprob/utils.py index b88d56d3ee..b93962fd56 100644 --- a/pymc/logprob/utils.py +++ b/pymc/logprob/utils.py @@ -210,8 +210,11 @@ class CheckParameterValue(CheckAndRaise): Raises `ParameterValueError` if the check is not True. """ - def __init__(self, msg=""): + __props__ = ("msg", "exc_type", "can_be_replaced_by_ninf") + + def __init__(self, msg: str = "", can_be_replaced_by_ninf: bool = False): super().__init__(ParameterValueError, msg) + self.can_be_replaced_by_ninf = can_be_replaced_by_ninf def __str__(self): return f"Check{{{self.msg}}}" diff --git a/pymc/pytensorf.py b/pymc/pytensorf.py index bca1c7bdca..033fc8aefa 100644 --- a/pymc/pytensorf.py +++ b/pymc/pytensorf.py @@ -913,19 +913,21 @@ def local_remove_check_parameter(fgraph, node): @node_rewriter(tracks=[CheckParameterValue]) def local_check_parameter_to_ninf_switch(fgraph, node): - if isinstance(node.op, CheckParameterValue): - logp_expr, *logp_conds = node.inputs - if len(logp_conds) > 1: - logp_cond = at.all(logp_conds) - else: - (logp_cond,) = logp_conds - out = at.switch(logp_cond, logp_expr, -np.inf) - out.name = node.op.msg + if not node.op.can_be_replaced_by_ninf: + return None + + logp_expr, *logp_conds = node.inputs + if len(logp_conds) > 1: + logp_cond = at.all(logp_conds) + else: + (logp_cond,) = logp_conds + out = at.switch(logp_cond, logp_expr, -np.inf) + out.name = node.op.msg - if out.dtype != node.outputs[0].dtype: - out = at.cast(out, node.outputs[0].dtype) + if out.dtype != node.outputs[0].dtype: + out = at.cast(out, node.outputs[0].dtype) - return [out] + return [out] pytensor.compile.optdb["canonicalize"].register( diff --git a/pymc/testing.py b/pymc/testing.py index e56f3fe88d..a0747bdaca 100644 --- a/pymc/testing.py +++ b/pymc/testing.py @@ -14,7 +14,7 @@ import functools as ft import itertools as it -from typing import Callable, List, Optional +from typing import Any, Callable, Dict, List, Optional, Tuple, Union import numpy as np import pytensor @@ -26,20 +26,23 @@ from pytensor.compile.mode import Mode from pytensor.graph.basic import ancestors from pytensor.graph.rewriting.basic import in2out +from pytensor.tensor import TensorVariable from pytensor.tensor.random.op import RandomVariable from scipy import special as sp from scipy import stats as st import pymc as pm -from pymc import logcdf, logp +from pymc import Distribution, logcdf, logp from pymc.distributions.shape_utils import change_dist_size from pymc.initial_point import make_initial_point_fn from pymc.logprob import joint_logp +from pymc.logprob.abstract import icdf from pymc.logprob.utils import ParameterValueError from pymc.pytensorf import ( compile_pymc, floatX, + inputvars, intX, local_check_parameter_to_ninf_switch, ) @@ -246,17 +249,69 @@ def build_model(distfam, valuedomain, vardomains, extra_args=None): return m, param_vars +def create_dist_from_paramdomains( + pymc_dist: Distribution, + paramdomains: Dict[str, Domain], + extra_args: Optional[Dict[str, Any]] = None, +) -> TensorVariable: + """Create a PyMC distribution from a dictionary of parameter domains. + + Returns + ------- + PyMC distribution variable: TensorVariable + Value variable: TensorVariable + """ + if extra_args is None: + extra_args = {} + + param_vars = {} + for param, domain in paramdomains.items(): + param_type = pt.constant(np.asarray(domain.vals[0])).type() + param_type.name = param + param_vars[param] = param_type + + return pymc_dist.dist(**param_vars, **extra_args) + + +def find_invalid_scalar_params( + paramdomains: Dict["str", Domain] +) -> Dict["str", Tuple[Union[None, float], Union[None, float]]]: + """Find invalid parameter values from bounded scalar parameter domains. + + For use in `check_logp`-like testing helpers. + + Returns + ------- + Invalid paramemeter values: + Dictionary mapping each parameter, to a lower and upper invalid values (out of domain). + If no lower or upper invalid values exist, None is returned for that entry. + """ + invalid_params = {} + for param, paramdomain in paramdomains.items(): + lower_edge, upper_edge = None, None + + if np.ndim(paramdomain.lower) == 0: + if np.isfinite(paramdomain.lower): + lower_edge = paramdomain.lower - 1 + + if np.isfinite(paramdomain.upper): + upper_edge = paramdomain.upper + 1 + + invalid_params[param] = (lower_edge, upper_edge) + return invalid_params + + def check_logp( - pymc_dist, - domain, - paramdomains, - scipy_logp, - decimal=None, - n_samples=100, - extra_args=None, - scipy_args=None, - skip_paramdomain_outside_edge_test=False, -): + pymc_dist: Distribution, + domain: Domain, + paramdomains: Dict[str, Domain], + scipy_logp: Callable, + decimal: Optional[int] = None, + n_samples: int = 100, + extra_args: Optional[Dict[str, Any]] = None, + scipy_args: Optional[Dict[str, Any]] = None, + skip_paramdomain_outside_edge_test: bool = False, +) -> None: """ Generic test for PyMC logp methods @@ -291,122 +346,77 @@ def check_logp( if decimal is None: decimal = select_by_precision(float64=6, float32=3) - if extra_args is None: - extra_args = {} - if scipy_args is None: scipy_args = {} - def logp_reference(args): + def scipy_logp_with_scipy_args(**args): args.update(scipy_args) return scipy_logp(**args) - def _model_input_dict(model, param_vars, point): - """Create a dict with only the necessary, transformed logp inputs.""" - pt_d = {} - for k, v in point.items(): - rv_var = model.named_vars.get(k) - nv = param_vars.get(k, rv_var) - nv = model.rvs_to_values.get(nv, nv) - - transform = model.rvs_to_transforms.get(rv_var, None) - if transform: - # todo: the compiled graph behind this should be cached and - # reused (if it isn't already). - v = transform.forward(rv_var, v).eval() + dist = create_dist_from_paramdomains(pymc_dist, paramdomains, extra_args) + value = dist.type() + value.name = "value" + pymc_dist_logp = logp(dist, value).sum() + pymc_logp = pytensor.function(list(inputvars(pymc_dist_logp)), pymc_dist_logp) - if nv.name in param_vars: - # update the shared parameter variables in `param_vars` - param_vars[nv.name].set_value(v) - else: - # create an argument entry for the (potentially - # transformed) "value" variable - pt_d[nv.name] = v - - return pt_d - - model, param_vars = build_model(pymc_dist, domain, paramdomains, extra_args) - logp_pymc = model.compile_logp(jacobian=False) - - # Test supported value and parameters domain matches scipy + # Test supported value and parameters domain matches Scipy domains = paramdomains.copy() domains["value"] = domain for point in product(domains, n_samples=n_samples): point = dict(point) - pt_d = _model_input_dict(model, param_vars, point) - pt_logp = pm.Point(pt_d, model=model) - pt_ref = pm.Point(point, filter_model_vars=False, model=model) npt.assert_almost_equal( - logp_pymc(pt_logp), - logp_reference(pt_ref), + pymc_logp(**point), + scipy_logp_with_scipy_args(**point), decimal=decimal, err_msg=str(point), ) valid_value = domain.vals[0] valid_params = {param: paramdomain.vals[0] for param, paramdomain in paramdomains.items()} - valid_dist = pymc_dist.dist(**valid_params, **extra_args) + valid_params["value"] = valid_value # Test pymc distribution raises ParameterValueError for scalar parameters outside # the supported domain edges (excluding edges) if not skip_paramdomain_outside_edge_test: - # Step1: collect potential invalid parameters - invalid_params = {param: [None, None] for param in paramdomains} - for param, paramdomain in paramdomains.items(): - if np.ndim(paramdomain.lower) != 0: - continue - if np.isfinite(paramdomain.lower): - invalid_params[param][0] = paramdomain.lower - 1 - if np.isfinite(paramdomain.upper): - invalid_params[param][1] = paramdomain.upper + 1 + invalid_params = find_invalid_scalar_params(paramdomains) - # Step2: test invalid parameters, one a time for invalid_param, invalid_edges in invalid_params.items(): for invalid_edge in invalid_edges: if invalid_edge is None: continue - test_params = valid_params.copy() # Shallow copy should be okay - test_params[invalid_param] = pt.as_tensor_variable(invalid_edge) - # We need to remove `Assert`s introduced by checks like - # `assert_negative_support` and disable test values; - # otherwise, we won't be able to create the `RandomVariable` - with pytensor.config.change_flags(compute_test_value="off"): - invalid_dist = pymc_dist.dist(**test_params, **extra_args) - with pytensor.config.change_flags(mode=Mode("py")): - with pytest.raises(ParameterValueError): - logp(invalid_dist, valid_value).eval() - pytest.fail(f"test_params={test_params}, valid_value={valid_value}") + + point = valid_params.copy() # Shallow copy should be okay + point[invalid_param] = invalid_edge + with pytest.raises(ParameterValueError): + pymc_logp(**point) + pytest.fail(f"test_params={point}") # Test that values outside of scalar domain support evaluate to -np.inf - if np.ndim(domain.lower) != 0: - return - invalid_values = [None, None] - if np.isfinite(domain.lower): - invalid_values[0] = domain.lower - 1 - if np.isfinite(domain.upper): - invalid_values[1] = domain.upper + 1 + invalid_values = find_invalid_scalar_params({"value": domain})["value"] for invalid_value in invalid_values: if invalid_value is None: continue - with pytensor.config.change_flags(mode=Mode("py")): - npt.assert_equal( - logp(valid_dist, invalid_value).eval(), - -np.inf, - err_msg=str(invalid_value), - ) + + point = valid_params.copy() + point["value"] = invalid_value + npt.assert_equal( + pymc_logp(**point), + -np.inf, + err_msg=str(point), + ) def check_logcdf( - pymc_dist, - domain, - paramdomains, - scipy_logcdf, - decimal=None, - n_samples=100, - skip_paramdomain_inside_edge_test=False, - skip_paramdomain_outside_edge_test=False, -): + pymc_dist: Distribution, + domain: Domain, + paramdomains: Dict[str, Domain], + scipy_logcdf: Callable, + decimal: Optional[int] = None, + n_samples: int = 100, + skip_paramdomain_inside_edge_test: bool = False, + skip_paramdomain_outside_edge_test: bool = False, +) -> None: """ Generic test for PyMC logcdf methods @@ -448,133 +458,194 @@ def check_logcdf( returns -inf for invalid parameter values outside the supported domain edge """ + if decimal is None: + decimal = select_by_precision(float64=6, float32=3) + + dist = create_dist_from_paramdomains(pymc_dist, paramdomains) + value = dist.type() + value.name = "value" + dist_logcdf = logcdf(dist, value) + pymc_logcdf = pytensor.function(list(inputvars(dist_logcdf)), dist_logcdf) + # Test pymc and scipy distributions match for values and parameters # within the supported domain edges (excluding edges) if not skip_paramdomain_inside_edge_test: domains = paramdomains.copy() domains["value"] = domain - - model, param_vars = build_model(pymc_dist, domain, paramdomains) - rv = model["value"] - value = model.rvs_to_values[rv] - pymc_logcdf = model.compile_fn(logcdf(rv, value)) - - if decimal is None: - decimal = select_by_precision(float64=6, float32=3) - for point in product(domains, n_samples=n_samples): - params = dict(point) - scipy_eval = scipy_logcdf(**params) - - value = params.pop("value") - # Update shared parameter variables in pymc_logcdf function - for param_name, param_value in params.items(): - param_vars[param_name].set_value(param_value) - pymc_eval = pymc_logcdf({"value": value}) - - params["value"] = value # for displaying in err_msg + point = dict(point) npt.assert_almost_equal( - pymc_eval, - scipy_eval, + pymc_logcdf(**point), + scipy_logcdf(**point), decimal=decimal, - err_msg=str(params), + err_msg=str(point), ) valid_value = domain.vals[0] valid_params = {param: paramdomain.vals[0] for param, paramdomain in paramdomains.items()} - valid_dist = pymc_dist.dist(**valid_params) + valid_params["value"] = valid_value # Test pymc distribution raises ParameterValueError for parameters outside the # supported domain edges (excluding edges) if not skip_paramdomain_outside_edge_test: - # Step1: collect potential invalid parameters - invalid_params = {param: [None, None] for param in paramdomains} - for param, paramdomain in paramdomains.items(): - if np.isfinite(paramdomain.lower): - invalid_params[param][0] = paramdomain.lower - 1 - if np.isfinite(paramdomain.upper): - invalid_params[param][1] = paramdomain.upper + 1 - # Step2: test invalid parameters, one a time + invalid_params = find_invalid_scalar_params(paramdomains) + for invalid_param, invalid_edges in invalid_params.items(): for invalid_edge in invalid_edges: - if invalid_edge is not None: - test_params = valid_params.copy() # Shallow copy should be okay - test_params[invalid_param] = pt.as_tensor_variable(invalid_edge) - # We need to remove `Assert`s introduced by checks like - # `assert_negative_support` and disable test values; - # otherwise, we won't be able to create the - # `RandomVariable` - with pytensor.config.change_flags(compute_test_value="off"): - invalid_dist = pymc_dist.dist(**test_params) - with pytensor.config.change_flags(mode=Mode("py")): - with pytest.raises(ParameterValueError): - logcdf(invalid_dist, valid_value).eval() - - # Test that values below domain edge evaluate to -np.inf - if np.isfinite(domain.lower): - below_domain = domain.lower - 1 - with pytensor.config.change_flags(mode=Mode("py")): - npt.assert_equal( - logcdf(valid_dist, below_domain).eval(), - -np.inf, - err_msg=str(below_domain), - ) + if invalid_edge is None: + continue - # Test that values above domain edge evaluate to 0 - if np.isfinite(domain.upper): - above_domain = domain.upper + 1 - with pytensor.config.change_flags(mode=Mode("py")): - npt.assert_equal( - logcdf(valid_dist, above_domain).eval(), - 0, - err_msg=str(above_domain), - ) + point = valid_params.copy() + point[invalid_param] = invalid_edge + with pytest.raises(ParameterValueError): + pymc_logcdf(**point) + pytest.fail(f"test_params={point}") + + # Test that values below domain edge evaluate to -np.inf, and above evaluates to 0 + invalid_lower, invalid_upper = find_invalid_scalar_params({"value": domain})["value"] + if invalid_lower is not None: + point = valid_params.copy() + point["value"] = invalid_lower + npt.assert_equal( + pymc_logcdf(**point), + -np.inf, + err_msg=str(point), + ) + if invalid_upper is not None: + point = valid_params.copy() + point["value"] = invalid_upper + npt.assert_equal( + pymc_logcdf(**point), + 0, + err_msg=str(point), + ) - # Test that method works with multiple values or raises informative TypeError - valid_dist = pymc_dist.dist(**valid_params, size=2) - with pytensor.config.change_flags(mode=Mode("py")): - try: - logcdf(valid_dist, np.array([valid_value, valid_value])).eval() - except TypeError as err: - assert str(err).endswith( - "logcdf expects a scalar value but received a 1-dimensional object." + +def check_icdf( + pymc_dist: Distribution, + paramdomains: Dict[str, Domain], + scipy_icdf: Callable, + decimal: Optional[int] = None, + n_samples: int = 100, +) -> None: + """ + Generic test for PyMC icdf methods + + The following tests are performed by default: + 1. Test PyMC icdf and equivalent scipy icdf (ppf) methods give similar + results for parameters inside the supported edges. + Edges are excluded by default, but can be artificially included by + creating a domain with repeated values (e.g., `Domain([0, 0, .5, 1, 1]`) + 2. Test PyMC icdf method raises for invalid parameter values + outside the supported edges. + 3. Test PyMC icdf method returns np.nan for values below 0 or above 1, + when using valid parameters. + + Parameters + ---------- + pymc_dist: PyMC distribution + paramdomains : Dictionary of Parameter : Domain pairs + Supported domains of distribution parameters + scipy_icdf : Scipy icdf method + Scipy icdf (ppp) method of equivalent pymc_dist distribution + decimal : int, optional + Level of precision with which pymc_dist and scipy_icdf are compared. + Defaults to 6 for float64 and 3 for float32 + n_samples : int + Upper limit on the number of valid domain and value combinations that + are compared between pymc and scipy methods. If n_samples is below the + total number of combinations, a random subset is evaluated. Setting + n_samples = -1, will return all possible combinations. Defaults to 100 + + """ + if decimal is None: + decimal = select_by_precision(float64=6, float32=3) + + dist = create_dist_from_paramdomains(pymc_dist, paramdomains) + q = pt.scalar(dtype="float64", name="q") + dist_icdf = icdf(dist, q) + pymc_icdf = pytensor.function(list(inputvars(dist_icdf)), dist_icdf) + + # Test pymc and scipy distributions match for values and parameters + # within the supported domain edges (excluding edges) + domains = paramdomains.copy() + domain = Domain([0, 0.1, 0.5, 0.75, 0.95, 0.99, 1]) # Values we test the icdf at + domains["q"] = domain + + for point in product(domains, n_samples=n_samples): + point = dict(point) + npt.assert_almost_equal( + pymc_icdf(**point), + scipy_icdf(**point), + decimal=decimal, + err_msg=str(point), + ) + + valid_value = domain.vals[0] + valid_params = {param: paramdomain.vals[0] for param, paramdomain in paramdomains.items()} + valid_params["q"] = valid_value + + # Test pymc distribution raises ParameterValueError for parameters outside the + # supported domain edges (excluding edges) + invalid_params = find_invalid_scalar_params(paramdomains) + for invalid_param, invalid_edges in invalid_params.items(): + for invalid_edge in invalid_edges: + if invalid_edge is None: + continue + + point = valid_params.copy() + point[invalid_param] = invalid_edge + with pytest.raises(ParameterValueError): + pymc_icdf(**point) + pytest.fail(f"test_params={point}") + + # Test that values below 0 or above 1 evaluate to nan + invalid_values = find_invalid_scalar_params({"q": domain})["q"] + for invalid_value in invalid_values: + if invalid_value is not None: + point = valid_params.copy() + point["q"] = invalid_value + npt.assert_equal( + pymc_icdf(**point), + np.nan, + err_msg=str(point), ) def check_selfconsistency_discrete_logcdf( - distribution, - domain, - paramdomains, - decimal=None, - n_samples=100, -): + distribution: Distribution, + domain: Domain, + paramdomains: Dict[str, Domain], + decimal: Optional[int] = None, + n_samples: int = 100, +) -> None: """ - Check that logcdf of discrete distributions matches sum of logps up to value + Check that logcdf of discrete distributions matches sum of logps up to value. """ - domains = paramdomains.copy() - domains["value"] = domain if decimal is None: decimal = select_by_precision(float64=6, float32=3) - model, param_vars = build_model(distribution, domain, paramdomains) - rv = model["value"] - value = model.rvs_to_values[rv] - dist_logcdf = model.compile_fn(logcdf(rv, value)) - dist_logp = model.compile_fn(logp(rv, value)) + dist = create_dist_from_paramdomains(distribution, paramdomains) + value = dist.type() + value.name = "value" + dist_logp = logp(dist, value) + dist_logp_fn = pytensor.function(list(inputvars(dist_logp)), dist_logp) + + dist_logcdf = logcdf(dist, value) + dist_logcdf_fn = compile_pymc(list(inputvars(dist_logcdf)), dist_logcdf) + + domains = paramdomains.copy() + domains["value"] = domain for point in product(domains, n_samples=n_samples): - params = dict(point) - value = params.pop("value") + point = dict(point) + value = point.pop("value") values = np.arange(domain.lower, value + 1) - # Update shared parameter variables in logp/logcdf function - for param_name, param_value in params.items(): - param_vars[param_name].set_value(param_value) - with pytensor.config.change_flags(mode=Mode("py")): npt.assert_almost_equal( - dist_logcdf({"value": value}), - sp.logsumexp([dist_logp({"value": value}) for value in values]), + dist_logcdf_fn(**point, value=value), + sp.logsumexp([dist_logp_fn(value=value, **point) for value in values]), decimal=decimal, err_msg=str(point), ) diff --git a/tests/distributions/test_continuous.py b/tests/distributions/test_continuous.py index 6ad652594b..bf4f349e4a 100644 --- a/tests/distributions/test_continuous.py +++ b/tests/distributions/test_continuous.py @@ -43,6 +43,7 @@ Runif, Unit, assert_moment_is_expected, + check_icdf, check_logcdf, check_logp, continuous_random_tester, @@ -159,14 +160,6 @@ def laplace_asymmetric_logpdf(value, kappa, b, mu): return lPx -def beta_mu_sigma(value, mu, sigma): - kappa = mu * (1 - mu) / sigma**2 - 1 - if kappa > 0: - return st.beta.logpdf(value, mu * kappa, (1 - mu) * kappa) - else: - return -np.inf - - class TestMatchesScipy: def test_uniform(self): check_logp( @@ -278,6 +271,11 @@ def test_normal(self): lambda value, mu, sigma: st.norm.logcdf(value, mu, sigma), decimal=select_by_precision(float64=6, float32=1), ) + check_icdf( + pm.Normal, + {"mu": R, "sigma": Rplus}, + lambda q, mu, sigma: st.norm.ppf(q, mu, sigma), + ) def test_half_normal(self): check_logp( @@ -367,10 +365,18 @@ def test_beta_logp(self): {"alpha": Rplus, "beta": Rplus}, lambda value, alpha, beta: st.beta.logpdf(value, alpha, beta), ) + + def beta_mu_sigma(value, mu, sigma): + kappa = mu * (1 - mu) / sigma**2 - 1 + return st.beta.logpdf(value, mu * kappa, (1 - mu) * kappa) + + # The mu/sigma parametrization is not always valid + safe_mu_domain = Domain([0, 0.3, 0.5, 0.8, 1]) + safe_sigma_domain = Domain([0, 0.05, 0.1, np.inf]) check_logp( pm.Beta, Unit, - {"mu": Unit, "sigma": Rplus}, + {"mu": safe_mu_domain, "sigma": safe_sigma_domain}, beta_mu_sigma, ) @@ -2269,21 +2275,3 @@ def dist(cls, **kwargs): extra_args={"rng": pytensor.shared(rng)}, ref_rand=ref_rand, ) - - -class TestICDF: - @pytest.mark.parametrize( - "dist_params, obs, size", - [ - ((0, 1), np.array([-0.5, 0, 0.3, 0.5, 1, 1.5], dtype=np.float64), ()), - ((-1, 20), np.array([-0.5, 0, 0.3, 0.5, 1, 1.5], dtype=np.float64), ()), - ((-1, 20), np.array([-0.5, 0, 0.3, 0.5, 1, 1.5], dtype=np.float64), (2, 3)), - ], - ) - def test_normal_icdf(self, dist_params, obs, size): - dist_params_at, obs_at, size_at = create_pytensor_params(dist_params, obs, size) - dist_params = dict(zip(dist_params_at, dist_params)) - - x = Normal.dist(*dist_params_at, size=size_at) - - scipy_logprob_tester(x, obs, dist_params, test_fn=st.norm.ppf, test="icdf") diff --git a/tests/distributions/test_discrete.py b/tests/distributions/test_discrete.py index 7c929e3f1f..1330b5ba54 100644 --- a/tests/distributions/test_discrete.py +++ b/tests/distributions/test_discrete.py @@ -51,6 +51,7 @@ UnitSortedVector, Vector, assert_moment_is_expected, + check_icdf, check_logcdf, check_logp, check_selfconsistency_discrete_logcdf, @@ -143,15 +144,14 @@ def test_geometric(self): Nat, {"p": Unit}, ) + check_icdf( + pm.Geometric, + {"p": Unit}, + st.geom.ppf, + ) def test_hypergeometric(self): - def modified_scipy_hypergeom_logpmf(value, N, k, n): - # Convert nan to -np.inf - original_res = st.hypergeom.logpmf(value, N, k, n) - return original_res if not np.isnan(original_res) else -np.inf - def modified_scipy_hypergeom_logcdf(value, N, k, n): - # Convert nan to -np.inf original_res = st.hypergeom.logcdf(value, N, k, n) # Correct for scipy bug in logcdf method (see https://github.com/scipy/scipy/issues/13280) @@ -160,24 +160,27 @@ def modified_scipy_hypergeom_logcdf(value, N, k, n): if np.all(np.isnan(pmfs)): original_res = np.nan - return original_res if not np.isnan(original_res) else -np.inf + return original_res + + N_domain = Domain([0, 10, 20, 30, np.inf], dtype="int64") + n_domain = k_domain = Domain([0, 1, 2, 3, np.inf], dtype="int64") check_logp( pm.HyperGeometric, Nat, - {"N": NatSmall, "k": NatSmall, "n": NatSmall}, - modified_scipy_hypergeom_logpmf, + {"N": N_domain, "k": k_domain, "n": n_domain}, + lambda value, N, k, n: st.hypergeom.logpmf(value, N, k, n), ) check_logcdf( pm.HyperGeometric, Nat, - {"N": NatSmall, "k": NatSmall, "n": NatSmall}, + {"N": N_domain, "k": k_domain, "n": n_domain}, modified_scipy_hypergeom_logcdf, ) check_selfconsistency_discrete_logcdf( pm.HyperGeometric, Nat, - {"N": NatSmall, "k": NatSmall, "n": NatSmall}, + {"N": N_domain, "k": k_domain, "n": n_domain}, ) @pytest.mark.xfail( @@ -535,15 +538,17 @@ def test_categorical_p_not_normalized_symbolic(self): @pytest.mark.parametrize("n", [2, 3, 4]) def test_orderedlogistic(self, n): - with warnings.catch_warnings(): - warnings.filterwarnings("ignore", "invalid value encountered in log", RuntimeWarning) - warnings.filterwarnings("ignore", "divide by zero encountered in log", RuntimeWarning) - check_logp( - pm.OrderedLogistic, - Domain(range(n), dtype="int64", edges=(None, None)), - {"eta": R, "cutpoints": Vector(R, n - 1)}, - lambda value, eta, cutpoints: orderedlogistic_logpdf(value, eta, cutpoints), - ) + cutpoints_domain = Vector(R, n - 1) + # Filter out invalid non-monotonic values + cutpoints_domain.vals = [v for v in cutpoints_domain.vals if np.all(np.diff(v) > 0)] + assert len(cutpoints_domain.vals) > 0 + + check_logp( + pm.OrderedLogistic, + Domain(range(n), dtype="int64", edges=(None, None)), + {"eta": R, "cutpoints": cutpoints_domain}, + lambda value, eta, cutpoints: orderedlogistic_logpdf(value, eta, cutpoints), + ) @pytest.mark.parametrize("n", [2, 3, 4]) def test_orderedprobit(self, n): @@ -1149,29 +1154,3 @@ def test_shape_inputs(self, eta, cutpoints, sigma, expected): ) p = categorical.owner.inputs[3].eval() assert p.shape == expected - - -class TestICDF: - @pytest.mark.parametrize( - "dist_params, obs, size", - [ - ((0.1,), np.array([-0.5, 0, 0.1, 0.5, 0.9, 1.0, 1.5], dtype=np.int64), ()), - ((0.5,), np.array([-0.5, 0, 0.1, 0.5, 0.9, 1.0, 1.5], dtype=np.int64), (3, 2)), - ( - (np.array([0.0, 0.2, 0.5, 1.0]),), - np.array([0.7, 0.7, 0.7, 0.7], dtype=np.int64), - (), - ), - ], - ) - def test_geometric_icdf(self, dist_params, obs, size): - dist_params_at, obs_at, size_at = create_pytensor_params(dist_params, obs, size) - dist_params = dict(zip(dist_params_at, dist_params)) - - x = Geometric.dist(*dist_params_at, size=size_at) - - def scipy_geom_icdf(value, p): - # Scipy ppf returns floats - return st.geom.ppf(value, p).astype(value.dtype) - - scipy_logprob_tester(x, obs, dist_params, test_fn=scipy_geom_icdf, test="icdf") diff --git a/tests/distributions/test_truncated.py b/tests/distributions/test_truncated.py index ec7e886aee..9455a3f1cc 100644 --- a/tests/distributions/test_truncated.py +++ b/tests/distributions/test_truncated.py @@ -174,7 +174,6 @@ def test_truncation_discrete_random(op_type, lower, upper): x = geometric_op(p, name="x", size=500) xt = Truncated.dist(x, lower=lower, upper=upper) assert isinstance(xt.owner.op, TruncatedRV) - assert xt.type.dtype == x.type.dtype xt_draws = draw(xt) assert np.all(xt_draws >= lower) diff --git a/tests/test_pytensorf.py b/tests/test_pytensorf.py index ddaf86ed47..0fb90c6a51 100644 --- a/tests/test_pytensorf.py +++ b/tests/test_pytensorf.py @@ -326,6 +326,21 @@ def test_check_bounds_flag(self): with m: assert np.all(compile_pymc([], bound)() == -np.inf) + def test_check_parameters_can_be_replaced_by_ninf(self): + expr = at.vector("expr", shape=(3,)) + cond = at.ge(expr, 0) + + final_expr = check_parameters(expr, cond, can_be_replaced_by_ninf=True) + fn = compile_pymc([expr], final_expr) + np.testing.assert_array_equal(fn(expr=[1, 2, 3]), [1, 2, 3]) + np.testing.assert_array_equal(fn(expr=[-1, 2, 3]), [-np.inf, -np.inf, -np.inf]) + + final_expr = check_parameters(expr, cond, msg="test", can_be_replaced_by_ninf=False) + fn = compile_pymc([expr], final_expr) + np.testing.assert_array_equal(fn(expr=[1, 2, 3]), [1, 2, 3]) + with pytest.raises(ParameterValueError, match="test"): + fn([-1, 2, 3]) + def test_compile_pymc_sets_rng_updates(self): rng = pytensor.shared(np.random.default_rng(0)) x = pm.Normal.dist(rng=rng)