diff --git a/pymc_experimental/distributions/multivariate/r2d2m2cp.py b/pymc_experimental/distributions/multivariate/r2d2m2cp.py index fbdb01af9..e7b43e679 100644 --- a/pymc_experimental/distributions/multivariate/r2d2m2cp.py +++ b/pymc_experimental/distributions/multivariate/r2d2m2cp.py @@ -15,18 +15,25 @@ from typing import Sequence, Union +import numpy as np import pymc as pm import pytensor.tensor as pt __all__ = ["R2D2M2CP"] -def _psivar2musigma(psi: pt.TensorVariable, explained_var: pt.TensorVariable): +def _psivar2musigma(psi: pt.TensorVariable, explained_var: pt.TensorVariable, psi_mask): pi = pt.erfinv(2 * psi - 1) f = (1 / (2 * pi**2 + 1)) ** 0.5 sigma = explained_var**0.5 * f mu = sigma * pi * 2**0.5 - return mu, sigma + if psi_mask is not None: + return ( + pt.where(psi_mask, mu, pt.sign(pi) * explained_var**0.5), + pt.where(psi_mask, sigma, 0), + ) + else: + return mu, sigma def _R2D2M2CP_beta( @@ -37,6 +44,7 @@ def _R2D2M2CP_beta( phi: pt.TensorVariable, psi: pt.TensorVariable, *, + psi_mask, dims: Union[str, Sequence[str]], centered=False, ): @@ -59,16 +67,141 @@ def _R2D2M2CP_beta( """ tau2 = r2 / (1 - r2) explained_variance = phi * pt.expand_dims(tau2 * output_sigma**2, -1) - mu_param, std_param = _psivar2musigma(psi, explained_variance) + mu_param, std_param = _psivar2musigma(psi, explained_variance, psi_mask=psi_mask) if not centered: with pm.Model(name): - raw = pm.Normal("raw", dims=dims) + if psi_mask is not None and psi_mask.any(): + # limit case where some probs are not 1 or 0 + # setsubtensor is required + r_idx = psi_mask.nonzero() + with pm.Model("raw"): + raw = pm.Normal("masked", shape=len(r_idx[0])) + raw = pt.set_subtensor(pt.zeros_like(mu_param)[r_idx], raw) + raw = pm.Deterministic("raw", raw, dims=dims) + elif psi_mask is not None: + # all variables are deterministic + raw = pt.zeros_like(mu_param) + else: + raw = pm.Normal("raw", dims=dims) beta = pm.Deterministic(name, (raw * std_param + mu_param) / input_sigma, dims=dims) else: 
- beta = pm.Normal(name, mu_param / input_sigma, std_param / input_sigma, dims=dims) + if psi_mask is not None and psi_mask.any(): + # limit case where some probs are not 1 or 0 + # setsubtensor is required + r_idx = psi_mask.nonzero() + with pm.Model(name): + mean = (mu_param / input_sigma)[r_idx] + sigma = (std_param / input_sigma)[r_idx] + masked = pm.Normal( + "masked", + mean, + sigma, + shape=len(r_idx[0]), + ) + beta = pt.set_subtensor(mean, masked) + beta = pm.Deterministic(name, beta, dims=dims) + elif psi_mask is not None: + # all variables are deterministic + beta = pm.Deterministic(name, (mu_param / input_sigma), dims=dims) + else: + beta = pm.Normal(name, mu_param / input_sigma, std_param / input_sigma, dims=dims) return beta +def _broadcast_as_dims(*values, dims): + model = pm.modelcontext(None) + shape = [len(model.coords[d]) for d in dims] + ret = tuple(np.broadcast_to(v, shape) for v in values) + # strip output + if len(values) == 1: + ret = ret[0] + return ret + + +def _psi_masked(positive_probs, positive_probs_std, *, dims): + if not ( + isinstance(positive_probs, pt.Constant) and isinstance(positive_probs_std, pt.Constant) + ): + raise TypeError( + "Only constant values for positive_probs and positive_probs_std are accepted" + ) + positive_probs, positive_probs_std = _broadcast_as_dims( + positive_probs.data, positive_probs_std.data, dims=dims + ) + mask = ~np.bitwise_or(positive_probs == 1, positive_probs == 0) + if np.bitwise_and(~mask, positive_probs_std != 0).any(): + raise ValueError("Can't have both positive_probs == '1 or 0' and positive_probs_std != 0") + if (~mask).any() and mask.any(): + # limit case where some probs are not 1 or 0 + # setsubtensor is required + r_idx = mask.nonzero() + with pm.Model("psi"): + psi = pm.Beta( + "masked", + mu=positive_probs[r_idx], + sigma=positive_probs_std[r_idx], + shape=len(r_idx[0]), + ) + psi = pt.set_subtensor(pt.as_tensor(positive_probs)[r_idx], psi) + psi = pm.Deterministic("psi", psi, 
dims=dims) + elif (~mask).all(): + # limit case where all the probs are 0 or 1 + psi = pt.as_tensor(positive_probs) + else: + psi = pm.Beta("psi", mu=positive_probs, sigma=positive_probs_std, dims=dims) + mask = None + return mask, psi + + + def _psi(positive_probs, positive_probs_std, *, dims): + if positive_probs_std is not None: + mask, psi = _psi_masked( + positive_probs=pt.as_tensor(positive_probs), + positive_probs_std=pt.as_tensor(positive_probs_std), + dims=dims, + ) + else: + positive_probs = pt.as_tensor(positive_probs) + if not isinstance(positive_probs, pt.Constant): + raise TypeError("Only constant values for positive_probs are allowed") + psi = _broadcast_as_dims(positive_probs.data, dims=dims) + mask = np.atleast_1d(~np.bitwise_or(psi == 1, psi == 0)) + if mask.all(): + mask = None + return mask, psi + + + def _phi( + variables_importance, + variance_explained, + importance_concentration, + *, + dims, + ): + *broadcast_dims, dim = dims + model = pm.modelcontext(None) + if variables_importance is not None: + if variance_explained is not None: + raise TypeError("Can't use variable importance with variance explained") + if len(model.coords[dim]) <= 1: + raise TypeError("Can't use variable importance with less than two variables") + variables_importance = pt.as_tensor(variables_importance) + if importance_concentration is not None: + variables_importance *= importance_concentration + return pm.Dirichlet("phi", variables_importance, dims=broadcast_dims + [dim]) + elif variance_explained is not None: + if len(model.coords[dim]) <= 1: + raise TypeError("Can't use variance explained with less than two variables") + phi = pt.as_tensor(variance_explained) + else: + phi = 1 / len(model.coords[dim]) + phi = _broadcast_as_dims(phi, dims=dims) + if importance_concentration is not None: + return pm.Dirichlet("phi", importance_concentration * phi, dims=broadcast_dims + [dim]) + else: + return phi + + + def R2D2M2CP( name, output_sigma, @@ -78,6 +211,7 @@ def
R2D2M2CP( r2, variables_importance=None, variance_explained=None, + importance_concentration=None, r2_std=None, positive_probs=0.5, positive_probs_std=None, @@ -102,6 +236,8 @@ def R2D2M2CP( variance_explained : tensor, optional Alternative estimate for variables importance which is point estimate of variance explained, should sum up to one, by default None + importance_concentration : tensor, optional + Confidence around variance explained or variable importance estimate r2_std : tensor, optional Optional uncertainty over :math:`R^2`, by default None positive_probs : tensor, optional @@ -125,8 +261,8 @@ def R2D2M2CP( ----- The R2D2M2CP prior is a modification of R2D2M2 prior. - - ``(R2D2M2)``CP is taken from https://arxiv.org/abs/2208.07132 - - R2D2M2``(CP)``, (Correlation Probability) is proposed and implemented by Max Kochurov (@ferrine) + - ``(R2D2M2)`` CP is taken from https://arxiv.org/abs/2208.07132 + - R2D2M2 ``(CP)``, (Correlation Probability) is proposed and implemented by Max Kochurov (@ferrine) Examples -------- @@ -259,31 +395,20 @@ def R2D2M2CP( input_sigma = pt.as_tensor(input_sigma) output_sigma = pt.as_tensor(output_sigma) with pm.Model(name) as model: - if variables_importance is not None: - if variance_explained is not None: - raise TypeError("Can't use variable importance with variance explained") - if len(model.coords[dim]) <= 1: - raise TypeError("Can't use variable importance with less than two variables") - phi = pm.Dirichlet( - "phi", pt.as_tensor(variables_importance), dims=broadcast_dims + [dim] - ) - elif variance_explained is not None: - if len(model.coords[dim]) <= 1: - raise TypeError("Can't use variance explained with less than two variables") - phi = pt.as_tensor(variance_explained) - else: - phi = 1 / len(model.coords[dim]) + if not all(isinstance(model.dim_lengths[d], pt.TensorConstant) for d in dims): + raise ValueError(f"{dims!r} should be constant length immutable dims") if r2_std is not None: r2 = pm.Beta("r2", mu=r2, 
sigma=r2_std, dims=broadcast_dims) - if positive_probs_std is not None: - psi = pm.Beta( - "psi", - mu=pt.as_tensor(positive_probs), - sigma=pt.as_tensor(positive_probs_std), - dims=broadcast_dims + [dim], - ) - else: - psi = pt.as_tensor(positive_probs) + phi = _phi( + variables_importance=variables_importance, + variance_explained=variance_explained, + importance_concentration=importance_concentration, + dims=dims, + ) + mask, psi = _psi( + positive_probs=positive_probs, positive_probs_std=positive_probs_std, dims=dims + ) + beta = _R2D2M2CP_beta( name, output_sigma, @@ -293,6 +418,7 @@ def R2D2M2CP( psi, dims=broadcast_dims + [dim], centered=centered, + psi_mask=mask, ) resid_sigma = (1 - r2) ** 0.5 * output_sigma return resid_sigma, beta diff --git a/pymc_experimental/tests/distributions/test_multivariate.py b/pymc_experimental/tests/distributions/test_multivariate.py index ee28c3e33..4dd81b2b3 100644 --- a/pymc_experimental/tests/distributions/test_multivariate.py +++ b/pymc_experimental/tests/distributions/test_multivariate.py @@ -1,5 +1,6 @@ import numpy as np import pymc as pm +import pytensor import pytest import pymc_experimental as pmx @@ -50,22 +51,35 @@ def r2(self): def r2_std(self, request): return request.param - @pytest.fixture(params=[True, False], ids=["probs", "no-probs"]) + @pytest.fixture(params=["true", "false", "limit-1", "limit-0", "limit-all"]) def positive_probs(self, input_std, request): - if request.param: + if request.param == "true": return np.full_like(input_std, 0.5) - else: + elif request.param == "false": return 0.5 + elif request.param == "limit-1": + ret = np.full_like(input_std, 0.5) + ret[..., 0] = 1 + return ret + elif request.param == "limit-0": + ret = np.full_like(input_std, 0.5) + ret[..., 0] = 0 + return ret + elif request.param == "limit-all": + return np.full_like(input_std, 0) @pytest.fixture(params=[True, False], ids=["probs-std", "no-probs-std"]) def positive_probs_std(self, positive_probs, request): if 
request.param: - return np.full_like(positive_probs, 0.1) + std = np.full_like(positive_probs, 0.1) + std[positive_probs == 0] = 0 + std[positive_probs == 1] = 0 + return std else: return None @pytest.fixture(params=[None, "importance", "explained"]) - def phi_args(self, request, input_shape): + def phi_args_base(self, request, input_shape): if input_shape[-1] < 2 and request.param is not None: pytest.skip("not compatible") elif request.param is None: @@ -76,6 +90,16 @@ def phi_args(self, request, input_shape): val = np.full(input_shape, 2) return {"variance_explained": val / val.sum(-1, keepdims=True)} + @pytest.fixture(params=["concentration", "no-concentration"]) + def phi_args(self, request, phi_args_base): + if request.param == "concentration": + phi_args_base["importance_concentration"] = 10 + return phi_args_base + + @pytest.mark.skipif( + pytensor.config.floatX == "float32", + reason="pytensor.config.floatX == 'float32', https://github.com/pymc-devs/pymc/issues/6779", + ) def test_init( self, dims, @@ -101,17 +125,20 @@ def test_init( positive_probs=positive_probs, **phi_args ) + assert not np.isnan(beta.eval()).any() assert eps.eval().shape == output_std.shape assert beta.eval().shape == input_std.shape # r2 rv is only created if r2 std is not None + assert "beta" in model.named_vars assert ("beta::r2" in model.named_vars) == (r2_std is not None), set(model.named_vars) - # phi is only created if variable importances is not None and there is more than one var - assert ("beta::phi" in model.named_vars) == ("variables_importance" in phi_args), set( - model.named_vars - ) - assert ("beta::psi" in model.named_vars) == (positive_probs_std is not None), set( - model.named_vars - ) + # phi is only created if variable importance is not None and there is more than one var + assert ("beta::phi" in model.named_vars) == ( + "variables_importance" in phi_args or "importance_concentration" in phi_args + ), set(model.named_vars) + assert ("beta::psi" in model.named_vars) 
== ( + positive_probs_std is not None and positive_probs_std.any() + ), set(model.named_vars) + assert np.isfinite(sum(model.point_logps().values())) def test_failing_importance(self, dims, input_shape, output_std, input_std): if input_shape[-1] < 2: @@ -163,3 +190,92 @@ def test_failing_mutual_exclusive(self, model: pm.Model): variance_explained=[0.5, 0.5], variables_importance=[1, 1], ) + + def test_limit_case_requires_std_0(self, model: pm.Model): + model.add_coord("a", range(2)) + with pytest.raises(ValueError, match="Can't have both positive_probs"): + pmx.distributions.R2D2M2CP( + "beta", + 1, + [1, 1], + dims="a", + r2=0.8, + positive_probs=[0.5, 0], + positive_probs_std=[0.3, 0.1], + ) + with pytest.raises(ValueError, match="Can't have both positive_probs"): + pmx.distributions.R2D2M2CP( + "beta", + 1, + [1, 1], + dims="a", + r2=0.8, + positive_probs=[0.5, 1], + positive_probs_std=[0.3, 0.1], + ) + + def test_limit_case_creates_masked_vars(self, model: pm.Model, centered: bool): + model.add_coord("a", range(2)) + pmx.distributions.R2D2M2CP( + "beta0", + 1, + [1, 1], + dims="a", + r2=0.8, + positive_probs=[0.5, 1], + positive_probs_std=[0.3, 0], + centered=centered, + ) + pmx.distributions.R2D2M2CP( + "beta1", + 1, + [1, 1], + dims="a", + r2=0.8, + positive_probs=[0.5, 0], + positive_probs_std=[0.3, 0], + centered=centered, + ) + if not centered: + assert "beta0::raw::masked" in model.named_vars, model.named_vars + assert "beta1::raw::masked" in model.named_vars, model.named_vars + else: + assert "beta0::masked" in model.named_vars, model.named_vars + assert "beta1::masked" in model.named_vars, model.named_vars + assert "beta1::psi::masked" in model.named_vars + assert "beta0::psi::masked" in model.named_vars + + def test_zero_length_rvs_not_created(self, model: pm.Model): + model.add_coord("a", range(2)) + # deterministic case which should not have any new variables + b = pmx.distributions.R2D2M2CP("b1", 1, [1, 1], r2=0.5, positive_probs=[1, 1], dims="a") + 
assert not model.free_RVs, model.free_RVs + + b = pmx.distributions.R2D2M2CP( + "b2", 1, [1, 1], r2=0.5, positive_probs=[1, 1], positive_probs_std=[0, 0], dims="a" + ) + assert not model.free_RVs, model.free_RVs + + def test_immutable_dims(self, model: pm.Model): + model.add_coord("a", range(2), mutable=True) + model.add_coord("b", range(2), mutable=False) + with pytest.raises(ValueError, match="should be constant length immutable dims"): + pmx.distributions.R2D2M2CP( + "beta0", + 1, + [1, 1], + dims="a", + r2=0.8, + positive_probs=[0.5, 1], + positive_probs_std=[0.3, 0], + ) + with pytest.raises(ValueError, match="should be constant length immutable dims"): + pmx.distributions.R2D2M2CP( + "beta0", + 1, + [1, 1], + dims=("a", "b"), + r2=0.8, + positive_probs=[0.5, 1], + positive_probs_std=[0.3, 0], + )