From df05be621212b5c0affd8edc001df85e220aa8fc Mon Sep 17 00:00:00 2001 From: Max Kochurov Date: Wed, 31 May 2023 23:42:53 +0300 Subject: [PATCH 01/19] add helper function to initialize masked psi --- .../distributions/multivariate/r2d2m2cp.py | 39 ++++++++++++++++++- 1 file changed, 38 insertions(+), 1 deletion(-) diff --git a/pymc_experimental/distributions/multivariate/r2d2m2cp.py b/pymc_experimental/distributions/multivariate/r2d2m2cp.py index fbdb01af9..4d24d3b55 100644 --- a/pymc_experimental/distributions/multivariate/r2d2m2cp.py +++ b/pymc_experimental/distributions/multivariate/r2d2m2cp.py @@ -15,6 +15,7 @@ from typing import Sequence, Union +import numpy as np import pymc as pm import pytensor.tensor as pt @@ -69,6 +70,41 @@ def _R2D2M2CP_beta( return beta +def _broadcast_as_dims(*values, dims): + model = pm.modelcontext(None) + shape = [len(model.coords[d]) for d in dims] + return tuple(np.broadcast_to(v, shape) for v in values) + + +def _psi_masked(positive_probs, positive_probs_std, *, dims): + if not ( + isinstance(positive_probs, pt.Constant) and isinstance(positive_probs_std, pt.Constant) + ): + raise TypeError( + "Only constant values for positive_probs and positive_probs_std are accepted" + ) + positive_probs, positive_probs_std = _broadcast_as_dims( + positive_probs.data, positive_probs_std.data, dims=dims + ) + mask = ~np.bitwise_or(positive_probs == 1, positive_probs == 0) + if np.bitwise_and(~mask, positive_probs_std != 0).any(): + raise ValueError("Can't have both positive_probs == '1 or 0' and positive_probs_std != 0") + if (~mask).any(): + r_idx = mask.nonzero() + with pm.Model("psi"): + psi = pm.Beta( + "masked", + mu=positive_probs[r_idx], + sigma=positive_probs_std[r_idx], + shape=len(r_idx[0]), + ) + psi = pt.set_subtensor(pt.as_tensor(positive_probs)[r_idx], psi) + psi = pm.Deterministic("psi", psi, dims=dims) + else: + psi = pm.Beta("psi", mu=positive_probs, sigma=positive_probs_std, dims=dims) + return mask, psi + + def R2D2M2CP( name, output_sigma, @@ -258,6 +294,7 @@ def R2D2M2CP( *broadcast_dims, dim = dims input_sigma = pt.as_tensor(input_sigma) output_sigma = pt.as_tensor(output_sigma) + positive_probs = pt.as_tensor(positive_probs) with pm.Model(name) as model: if variables_importance is not None: if variance_explained is not None: @@ -272,7 +309,7 @@ def R2D2M2CP( raise TypeError("Can't use variance explained with less than two variables") phi = pt.as_tensor(variance_explained) else: - phi = 1 / len(model.coords[dim]) + phi = pt.as_tensor(1 / len(model.coords[dim])) if r2_std is not None: r2 = pm.Beta("r2", mu=r2, sigma=r2_std, dims=broadcast_dims) if positive_probs_std is not None: From 32d4b7ac097700930562215671cf75c7350d4d12 Mon Sep 17 00:00:00 2001 From: Max Kochurov Date: Thu, 1 Jun 2023 11:49:02 +0300 Subject: [PATCH 02/19] split initialization functions into helpers --- .../distributions/multivariate/r2d2m2cp.py | 99 +++++++++++++------ 1 file changed, 69 insertions(+), 30 deletions(-) diff --git a/pymc_experimental/distributions/multivariate/r2d2m2cp.py b/pymc_experimental/distributions/multivariate/r2d2m2cp.py index 4d24d3b55..06e6072de 100644 --- a/pymc_experimental/distributions/multivariate/r2d2m2cp.py +++ b/pymc_experimental/distributions/multivariate/r2d2m2cp.py @@ -22,12 +22,18 @@ __all__ = ["R2D2M2CP"] -def _psivar2musigma(psi: pt.TensorVariable, explained_var: pt.TensorVariable): +def _psivar2musigma(psi: pt.TensorVariable, explained_var: pt.TensorVariable, psi_mask): pi = pt.erfinv(2 * psi - 1) f = (1 / (2 * pi**2 + 1)) ** 0.5 sigma = explained_var**0.5 * f mu = sigma * pi * 2**0.5 - return mu, sigma + if psi_mask is not None: + return ( + pt.where(psi_mask, mu, explained_var**0.5), + pt.where(psi_mask, sigma, 0), + ) + else: + return mu, sigma def _R2D2M2CP_beta( @@ -38,6 +44,7 @@ def _R2D2M2CP_beta( phi: pt.TensorVariable, psi: pt.TensorVariable, *, + psi_mask, dims: Union[str, Sequence[str]], centered=False, ): @@ -60,7 +67,7 @@ def _R2D2M2CP_beta( """ tau2 = r2 / (1 - r2) explained_variance = phi * pt.expand_dims(tau2 * output_sigma**2, -1) - mu_param, std_param = _psivar2musigma(psi, explained_variance) + mu_param, std_param = _psivar2musigma(psi, explained_variance, psi_mask=psi_mask) if not centered: with pm.Model(name): raw = pm.Normal("raw", dims=dims) @@ -102,9 +109,53 @@ def _psi_masked(positive_probs, positive_probs_std, *, dims): psi = pm.Deterministic("psi", psi, dims=dims) else: psi = pm.Beta("psi", mu=positive_probs, sigma=positive_probs_std, dims=dims) + mask = None return mask, psi +def _psi(positive_probs, positive_probs_std, *, dims): + if positive_probs_std is not None: + mask, psi = _psi_masked( + "psi", + mu=pt.as_tensor(positive_probs), + sigma=pt.as_tensor(positive_probs_std), + dims=dims, + ) + else: + positive_probs = pt.as_tensor(positive_probs) + if not isinstance(positive_probs, pt.Constant): + raise TypeError("Only constant values for positive_probs are allowed") + psi = _broadcast_as_dims(positive_probs.data, dims=dims) + mask = psi != 1 + if (mask).all(): + mask = None + return mask, psi + + +def _phi( + variables_importance, + variance_explained, + variance_explained_concentration, + *, + dims, +): + *broadcast_dims, dim = dims + model = pm.modelcontext(None) + if variables_importance is not None: + if variance_explained is not None: + raise TypeError("Can't use variable importance with variance explained") + if len(model.coords[dim]) <= 1: + raise TypeError("Can't use variable importance with less than two variables") + phi = pm.Dirichlet("phi", pt.as_tensor(variables_importance), dims=broadcast_dims + [dim]) + elif variance_explained is not None: + if len(model.coords[dim]) <= 1: + raise TypeError("Can't use variance explained with less than two variables") + phi = pt.as_tensor(variance_explained) + else: + phi = pt.as_tensor(1 / len(model.coords[dim])) + return phi + + def R2D2M2CP( name, output_sigma, @@ -114,6 +165,7 @@ def R2D2M2CP( r2, variables_importance=None, variance_explained=None, + variance_explained_concentration=None, r2_std=None, positive_probs=0.5, positive_probs_std=None, @@ -138,6 +190,8 @@ def R2D2M2CP( variance_explained : tensor, optional Alternative estimate for variables importance which is point estimate of variance explained, should sum up to one, by default None + variance_explained_concentration : tensor, optional + Confidence around variance explained estimate r2_std : tensor, optional Optional uncertainty over :math:`R^2`, by default None positive_probs : tensor, optional @@ -161,8 +215,8 @@ def R2D2M2CP( ----- The R2D2M2CP prior is a modification of R2D2M2 prior. - - ``(R2D2M2)``CP is taken from https://arxiv.org/abs/2208.07132 - - R2D2M2``(CP)``, (Correlation Probability) is proposed and implemented by Max Kochurov (@ferrine) + - ``(R2D2M2)`` CP is taken from https://arxiv.org/abs/2208.07132 + - R2D2M2 ``(CP)``, (Correlation Probability) is proposed and implemented by Max Kochurov (@ferrine) Examples -------- @@ -294,33 +348,17 @@ def R2D2M2CP( *broadcast_dims, dim = dims input_sigma = pt.as_tensor(input_sigma) output_sigma = pt.as_tensor(output_sigma) - positive_probs = pt.as_tensor(positive_probs) - with pm.Model(name) as model: - if variables_importance is not None: - if variance_explained is not None: - raise TypeError("Can't use variable importance with variance explained") - if len(model.coords[dim]) <= 1: - raise TypeError("Can't use variable importance with less than two variables") - phi = pm.Dirichlet( - "phi", pt.as_tensor(variables_importance), dims=broadcast_dims + [dim] - ) - elif variance_explained is not None: - if len(model.coords[dim]) <= 1: - raise TypeError("Can't use variance explained with less than two variables") - phi = pt.as_tensor(variance_explained) - else: - phi = pt.as_tensor(1 / len(model.coords[dim])) + with pm.Model(name): if r2_std is not None: r2 = pm.Beta("r2", mu=r2, sigma=r2_std, dims=broadcast_dims) - if positive_probs_std is not None: - psi = pm.Beta( - "psi", - mu=pt.as_tensor(positive_probs), - sigma=pt.as_tensor(positive_probs_std), - dims=broadcast_dims + [dim], - ) - else: - psi = pt.as_tensor(positive_probs) + phi = _phi( + variables_importance=variables_importance, + variance_explained=variance_explained, + ) + mask, psi = _psi( + positive_probs=positive_probs, positive_probs_std=positive_probs_std, dims=dims + ) + beta = _R2D2M2CP_beta( name, output_sigma, @@ -330,6 +368,7 @@ def R2D2M2CP( psi, dims=broadcast_dims + [dim], centered=centered, + psi_mask=mask, ) resid_sigma = (1 - r2) ** 0.5 * output_sigma return resid_sigma, beta From 7f811de045588852f9ebc7bd8ef9e9c3868c7a0c Mon Sep 17 00:00:00 2001 From: Max Kochurov Date: Thu, 1 Jun 2023 14:11:57 +0300 Subject: [PATCH 03/19] add importance concentration parameter --- .../distributions/multivariate/r2d2m2cp.py | 18 ++++++++++++------ 1 file changed, 12 insertions(+), 6 deletions(-) diff --git a/pymc_experimental/distributions/multivariate/r2d2m2cp.py b/pymc_experimental/distributions/multivariate/r2d2m2cp.py index 06e6072de..7d6162f2e 100644 --- a/pymc_experimental/distributions/multivariate/r2d2m2cp.py +++ b/pymc_experimental/distributions/multivariate/r2d2m2cp.py @@ -135,7 +135,7 @@ def _psi(positive_probs, positive_probs_std, *, dims): def _phi( variables_importance, variance_explained, - variance_explained_concentration, + importance_concentration, *, dims, ): @@ -146,14 +146,20 @@ def _phi( raise TypeError("Can't use variable importance with variance explained") if len(model.coords[dim]) <= 1: raise TypeError("Can't use variable importance with less than two variables") - phi = pm.Dirichlet("phi", pt.as_tensor(variables_importance), dims=broadcast_dims + [dim]) + variables_importance = pt.as_tensor(variables_importance) + if importance_concentration is not None: + variables_importance *= importance_concentration + return pm.Dirichlet("phi", variables_importance, dims=broadcast_dims + [dim]) elif variance_explained is not None: if len(model.coords[dim]) <= 1: raise TypeError("Can't use variance explained with less than two variables") phi = pt.as_tensor(variance_explained) else: phi = pt.as_tensor(1 / len(model.coords[dim])) - return phi + if importance_concentration is not None: + return pm.Dirichlet("phi", importance_concentration * phi, dims=broadcast_dims + [dim]) + else: + return phi def R2D2M2CP( @@ -165,7 +171,7 @@ def R2D2M2CP( r2, variables_importance=None, variance_explained=None, - variance_explained_concentration=None, + importance_concentration=None, r2_std=None, positive_probs=0.5, positive_probs_std=None, @@ -190,8 +196,8 @@ def R2D2M2CP( variance_explained : tensor, optional Alternative estimate for variables importance which is point estimate of variance explained, should sum up to one, by default None - variance_explained_concentration : tensor, optional - Confidence around variance explained estimate + importance_concentration : tensor, optional + Confidence around variance explained or variable importance estimate r2_std : tensor, optional Optional uncertainty over :math:`R^2`, by default None positive_probs : tensor, optional From 51b81fadb3e1a50a88eef00061f1f613c6153a6b Mon Sep 17 00:00:00 2001 From: Max Kochurov Date: Thu, 1 Jun 2023 14:18:11 +0300 Subject: [PATCH 04/19] rework non centered case for beta init --- pymc_experimental/distributions/multivariate/r2d2m2cp.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/pymc_experimental/distributions/multivariate/r2d2m2cp.py b/pymc_experimental/distributions/multivariate/r2d2m2cp.py index 7d6162f2e..552801d40 100644 --- a/pymc_experimental/distributions/multivariate/r2d2m2cp.py +++ b/pymc_experimental/distributions/multivariate/r2d2m2cp.py @@ -70,7 +70,14 @@ def _R2D2M2CP_beta( mu_param, std_param = _psivar2musigma(psi, explained_variance, psi_mask=psi_mask) if not centered: with pm.Model(name): - raw = pm.Normal("raw", dims=dims) + if psi_mask is not None: + r_idx = psi_mask.nonzero() + with pm.Model("raw"): + raw = pm.Normal("masked", shape=len(r_idx[0])) + raw = pt.set_subtensor(pt.zeros_like(mu_param)[r_idx], psi) + raw = pm.Deterministic("raw", raw, dims=dims) + else: + raw = pm.Normal("raw", dims=dims) beta = pm.Deterministic(name, (raw * std_param + mu_param) / input_sigma, dims=dims) else: beta = pm.Normal(name, mu_param / input_sigma, std_param / input_sigma, dims=dims) From b84541555ed27c25bca93997e2ca055e2109cdf4 Mon Sep 17 00:00:00 2001 From: Max Kochurov Date: Thu, 1 Jun 2023 14:21:32 +0300 Subject: [PATCH 05/19] fix sign for mean sigma helper --- pymc_experimental/distributions/multivariate/r2d2m2cp.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pymc_experimental/distributions/multivariate/r2d2m2cp.py b/pymc_experimental/distributions/multivariate/r2d2m2cp.py index 552801d40..cbe5ff9af 100644 --- a/pymc_experimental/distributions/multivariate/r2d2m2cp.py +++ b/pymc_experimental/distributions/multivariate/r2d2m2cp.py @@ -29,7 +29,7 @@ def _psivar2musigma(psi: pt.TensorVariable, explained_var: pt.TensorVariable, ps mu = sigma * pi * 2**0.5 if psi_mask is not None: return ( - pt.where(psi_mask, mu, explained_var**0.5), + pt.where(psi_mask, mu, pt.sign(mu) * explained_var**0.5), pt.where(psi_mask, sigma, 0), ) else: From bea2248d0df8d4e32170df22ee25be323a4583ca Mon Sep 17 00:00:00 2001 From: Max Kochurov Date: Thu, 1 Jun 2023 14:26:29 +0300 Subject: [PATCH 06/19] add centered parametrization for the limit case --- .../distributions/multivariate/r2d2m2cp.py | 16 +++++++++++++++- 1 file changed, 15 insertions(+), 1 deletion(-) diff --git a/pymc_experimental/distributions/multivariate/r2d2m2cp.py b/pymc_experimental/distributions/multivariate/r2d2m2cp.py index cbe5ff9af..3b1d0f646 100644 --- a/pymc_experimental/distributions/multivariate/r2d2m2cp.py +++ b/pymc_experimental/distributions/multivariate/r2d2m2cp.py @@ -80,7 +80,21 @@ def _R2D2M2CP_beta( raw = pm.Normal("raw", dims=dims) beta = pm.Deterministic(name, (raw * std_param + mu_param) / input_sigma, dims=dims) else: - beta = pm.Normal(name, mu_param / input_sigma, std_param / input_sigma, dims=dims) + if psi_mask is not None: + r_idx = psi_mask.nonzero() + with pm.Model(name): + mean = (mu_param / input_sigma)[r_idx] + sigma = (std_param / input_sigma)[r_idx] + masked = pm.Normal( + "masked", + mean, + sigma, + shape=len(r_idx[0]), + ) + beta = pt.set_subtensor(mean, masked) + beta = pm.Deterministic(name, beta, dims=dims) + else: + beta = pm.Normal(name, mu_param / input_sigma, std_param / input_sigma, dims=dims) return beta From 3b82fe67f7abd73020387d26b3d58c4483bc7bd1 Mon Sep 17 00:00:00 2001 From: Max Kochurov Date: Thu, 1 Jun 2023 14:28:20 +0300 Subject: [PATCH 07/19] fix typo --- pymc_experimental/distributions/multivariate/r2d2m2cp.py | 1 + 1 file changed, 1 insertion(+) diff --git a/pymc_experimental/distributions/multivariate/r2d2m2cp.py b/pymc_experimental/distributions/multivariate/r2d2m2cp.py index 3b1d0f646..a219d0667 100644 --- a/pymc_experimental/distributions/multivariate/r2d2m2cp.py +++ b/pymc_experimental/distributions/multivariate/r2d2m2cp.py @@ -381,6 +381,7 @@ def R2D2M2CP( phi = _phi( variables_importance=variables_importance, variance_explained=variance_explained, + importance_concentration=importance_concentration, ) mask, psi = _psi( positive_probs=positive_probs, positive_probs_std=positive_probs_std, dims=dims From 3a0c1bd09d0021657558630d2ba49aa772589d0b Mon Sep 17 00:00:00 2001 From: Max Kochurov Date: Thu, 1 Jun 2023 14:31:22 +0300 Subject: [PATCH 08/19] fix typo --- pymc_experimental/distributions/multivariate/r2d2m2cp.py | 1 + 1 file changed, 1 insertion(+) diff --git a/pymc_experimental/distributions/multivariate/r2d2m2cp.py b/pymc_experimental/distributions/multivariate/r2d2m2cp.py index a219d0667..c17e566e5 100644 --- a/pymc_experimental/distributions/multivariate/r2d2m2cp.py +++ b/pymc_experimental/distributions/multivariate/r2d2m2cp.py @@ -382,6 +382,7 @@ def R2D2M2CP( variables_importance=variables_importance, variance_explained=variance_explained, importance_concentration=importance_concentration, + dims=dims, ) mask, psi = _psi( positive_probs=positive_probs, positive_probs_std=positive_probs_std, dims=dims From bab5149598539632be865d8e523c87d98177291c Mon Sep 17 00:00:00 2001 From: Max Kochurov Date: Thu, 1 Jun 2023 14:40:27 +0300 Subject: [PATCH 09/19] fix implementation for non limit cases --- .../distributions/multivariate/r2d2m2cp.py | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/pymc_experimental/distributions/multivariate/r2d2m2cp.py b/pymc_experimental/distributions/multivariate/r2d2m2cp.py index c17e566e5..4fe2fc5d3 100644 --- a/pymc_experimental/distributions/multivariate/r2d2m2cp.py +++ b/pymc_experimental/distributions/multivariate/r2d2m2cp.py @@ -101,7 +101,11 @@ def _R2D2M2CP_beta( def _broadcast_as_dims(*values, dims): model = pm.modelcontext(None) shape = [len(model.coords[d]) for d in dims] - return tuple(np.broadcast_to(v, shape) for v in values) + ret = tuple(np.broadcast_to(v, shape) for v in values) + # strip output + if len(values) == 1: + ret = ret[0] + return ret def _psi_masked(positive_probs, positive_probs_std, *, dims): @@ -137,9 +141,8 @@ def _psi_masked(positive_probs, positive_probs_std, *, dims): def _psi(positive_probs, positive_probs_std, *, dims): if positive_probs_std is not None: mask, psi = _psi_masked( - "psi", - mu=pt.as_tensor(positive_probs), - sigma=pt.as_tensor(positive_probs_std), + positive_probs=pt.as_tensor(positive_probs), + positive_probs_std=pt.as_tensor(positive_probs_std), dims=dims, ) else: @@ -147,7 +150,7 @@ def _psi(positive_probs, positive_probs_std, *, dims): if not isinstance(positive_probs, pt.Constant): raise TypeError("Only constant values for positive_probs are allowed") psi = _broadcast_as_dims(positive_probs.data, dims=dims) - mask = psi != 1 + mask = np.atleast_1d(psi != 1) if (mask).all(): mask = None return mask, psi From 36ff67ef8f86844bfa32093cf49c5d9c86f42c1f Mon Sep 17 00:00:00 2001 From: Max Kochurov Date: Thu, 1 Jun 2023 19:17:32 +0300 Subject: [PATCH 10/19] make positive tests pass --- .../distributions/multivariate/r2d2m2cp.py | 9 +++-- .../tests/distributions/test_multivariate.py | 38 +++++++++++++++---- 2 files changed, 35 insertions(+), 12 deletions(-) diff --git a/pymc_experimental/distributions/multivariate/r2d2m2cp.py b/pymc_experimental/distributions/multivariate/r2d2m2cp.py index 4fe2fc5d3..5befe7f87 100644 --- a/pymc_experimental/distributions/multivariate/r2d2m2cp.py +++ b/pymc_experimental/distributions/multivariate/r2d2m2cp.py @@ -74,7 +74,7 @@ def _R2D2M2CP_beta( r_idx = psi_mask.nonzero() with pm.Model("raw"): raw = pm.Normal("masked", shape=len(r_idx[0])) - raw = pt.set_subtensor(pt.zeros_like(mu_param)[r_idx], psi) + raw = pt.set_subtensor(pt.zeros_like(mu_param)[r_idx], raw) raw = pm.Deterministic("raw", raw, dims=dims) else: raw = pm.Normal("raw", dims=dims) @@ -150,8 +150,8 @@ def _psi(positive_probs, positive_probs_std, *, dims): if not isinstance(positive_probs, pt.Constant): raise TypeError("Only constant values for positive_probs are allowed") psi = _broadcast_as_dims(positive_probs.data, dims=dims) - mask = np.atleast_1d(psi != 1) - if (mask).all(): + mask = np.atleast_1d(~np.bitwise_or(psi == 1, psi == 0)) + if mask.all(): mask = None return mask, psi @@ -179,7 +179,8 @@ def _phi( raise TypeError("Can't use variance explained with less than two variables") phi = pt.as_tensor(variance_explained) else: - phi = pt.as_tensor(1 / len(model.coords[dim])) + phi = 1 / len(model.coords[dim]) + phi = _broadcast_as_dims(phi, dims=dims) if importance_concentration is not None: return pm.Dirichlet("phi", importance_concentration * phi, dims=broadcast_dims + [dim]) else: diff --git a/pymc_experimental/tests/distributions/test_multivariate.py b/pymc_experimental/tests/distributions/test_multivariate.py index ee28c3e33..5dbac29e7 100644 --- a/pymc_experimental/tests/distributions/test_multivariate.py +++ b/pymc_experimental/tests/distributions/test_multivariate.py @@ -50,22 +50,33 @@ def r2(self): def r2_std(self, request): return request.param - @pytest.fixture(params=[True, False], ids=["probs", "no-probs"]) + @pytest.fixture(params=["true", "false", "limit-1", "limit-0"]) def positive_probs(self, input_std, request): - if request.param: + if request.param == "true": return np.full_like(input_std, 0.5) - else: + elif request.param == "false": return 0.5 + elif request.param == "limit-1": + ret = np.full_like(input_std, 0.5) + ret[..., 0] = 1 + return ret + elif request.param == "limit-0": + ret = np.full_like(input_std, 0.5) + ret[..., 0] = 0 + return ret @pytest.fixture(params=[True, False], ids=["probs-std", "no-probs-std"]) def positive_probs_std(self, positive_probs, request): if request.param: - return np.full_like(positive_probs, 0.1) + std = np.full_like(positive_probs, 0.1) + std[positive_probs == 0] = 0 + std[positive_probs == 1] = 0 + return std else: return None @pytest.fixture(params=[None, "importance", "explained"]) - def phi_args(self, request, input_shape): + def phi_args_base(self, request, input_shape): if input_shape[-1] < 2 and request.param is not None: pytest.skip("not compatible") elif request.param is None: @@ -76,6 +87,12 @@ def phi_args(self, request, input_shape): val = np.full(input_shape, 2) return {"variance_explained": val / val.sum(-1, keepdims=True)} + @pytest.fixture(params=["concentration", "no-concentration"]) + def phi_args(self, request, phi_args_base): + if request.param == "concentration": + phi_args_base["importance_concentration"] = 10 + return phi_args_base + def test_init( self, dims, @@ -106,12 +123,13 @@ def test_init( # r2 rv is only created if r2 std is not None assert ("beta::r2" in model.named_vars) == (r2_std is not None), set(model.named_vars) # phi is only created if variable importances is not None and there is more than one var - assert ("beta::phi" in model.named_vars) == ("variables_importance" in phi_args), set( - model.named_vars - ) + assert ("beta::phi" in model.named_vars) == ( + "variables_importance" in phi_args or "importance_concentration" in phi_args + ), set(model.named_vars) assert ("beta::psi" in model.named_vars) == (positive_probs_std is not None), set( model.named_vars ) + assert np.isfinite(sum(model.point_logps().values())), model.point_logps() def test_failing_importance(self, dims, input_shape, output_std, input_std): if input_shape[-1] < 2: @@ -163,3 +181,7 @@ def test_failing_mutual_exclusive(self, model: pm.Model): variance_explained=[0.5, 0.5], variables_importance=[1, 1], ) + + def test_limit_case_requires_std_0(self): + # TODO: add test for the limit cases that assertions work as expected + return From 8468667f8d44522bd227a52dd1c39015fcf566ff Mon Sep 17 00:00:00 2001 From: Max Kochurov Date: Thu, 1 Jun 2023 20:37:27 +0300 Subject: [PATCH 11/19] check limit case requires std 0 --- .../tests/distributions/test_multivariate.py | 25 ++++++++++++++++--- 1 file changed, 22 insertions(+), 3 deletions(-) diff --git a/pymc_experimental/tests/distributions/test_multivariate.py b/pymc_experimental/tests/distributions/test_multivariate.py index 5dbac29e7..e4ef4532e 100644 --- a/pymc_experimental/tests/distributions/test_multivariate.py +++ b/pymc_experimental/tests/distributions/test_multivariate.py @@ -182,6 +182,25 @@ def test_failing_mutual_exclusive(self, model: pm.Model): variables_importance=[1, 1], ) - def test_limit_case_requires_std_0(self): - # TODO: add test for the limit cases that assertions work as expected - return + def test_limit_case_requires_std_0(self, model: pm.Model): + model.add_coord("a", range(2)) + with pytest.raises(ValueError, match="Can't have both positive_probs"): + pmx.distributions.R2D2M2CP( + "beta", + 1, + [1, 1], + dims="a", + r2=0.8, + positive_probs=[0.5, 0], + positive_probs_std=[0.3, 0.1], + ) + with pytest.raises(ValueError, match="Can't have both positive_probs"): + pmx.distributions.R2D2M2CP( + "beta", + 1, + [1, 1], + dims="a", + r2=0.8, + positive_probs=[0.5, 1], + positive_probs_std=[0.3, 0.1], + ) From be84689914f8e2f77075637cf3b77e18ed197c8a Mon Sep 17 00:00:00 2001 From: Max Kochurov Date: Thu, 1 Jun 2023 20:43:06 +0300 Subject: [PATCH 12/19] assert masked variables are created --- .../tests/distributions/test_multivariate.py | 33 ++++++++++++++++++- 1 file changed, 32 insertions(+), 1 deletion(-) diff --git a/pymc_experimental/tests/distributions/test_multivariate.py b/pymc_experimental/tests/distributions/test_multivariate.py index e4ef4532e..1444a0841 100644 --- a/pymc_experimental/tests/distributions/test_multivariate.py +++ b/pymc_experimental/tests/distributions/test_multivariate.py @@ -122,7 +122,7 @@ def test_init( assert beta.eval().shape == input_std.shape # r2 rv is only created if r2 std is not None assert ("beta::r2" in model.named_vars) == (r2_std is not None), set(model.named_vars) - # phi is only created if variable importances is not None and there is more than one var + # phi is only created if variable importance is not None and there is more than one var assert ("beta::phi" in model.named_vars) == ( "variables_importance" in phi_args or "importance_concentration" in phi_args ), set(model.named_vars) @@ -204,3 +204,34 @@ def test_limit_case_requires_std_0(self, model: pm.Model): positive_probs=[0.5, 1], positive_probs_std=[0.3, 0.1], ) + + def test_limit_case_creates_masked_vars(self, model: pm.Model, centered: bool): + model.add_coord("a", range(2)) + pmx.distributions.R2D2M2CP( + "beta0", + 1, + [1, 1], + dims="a", + r2=0.8, + positive_probs=[0.5, 1], + positive_probs_std=[0.3, 0], + centered=centered, + ) + pmx.distributions.R2D2M2CP( + "beta1", + 1, + [1, 1], + dims="a", + r2=0.8, + positive_probs=[0.5, 0], + positive_probs_std=[0.3, 0], + centered=centered, + ) + if not centered: + assert "beta0::raw::masked" in model.named_vars, model.named_vars + assert "beta1::raw::masked" in model.named_vars, model.named_vars + else: + assert "beta0::masked" in model.named_vars, model.named_vars + assert "beta1::masked" in model.named_vars, model.named_vars + assert "beta1::psi::masked" in model.named_vars + assert "beta0::psi::masked" in model.named_vars From c766659bb6109f89154fa0ef44b4c8c02222c54a Mon Sep 17 00:00:00 2001 From: Max Kochurov Date: Thu, 1 Jun 2023 21:10:20 +0300 Subject: [PATCH 13/19] add failing test --- pymc_experimental/tests/distributions/test_multivariate.py | 1 + 1 file changed, 1 insertion(+) diff --git a/pymc_experimental/tests/distributions/test_multivariate.py b/pymc_experimental/tests/distributions/test_multivariate.py index 1444a0841..565f17f82 100644 --- a/pymc_experimental/tests/distributions/test_multivariate.py +++ b/pymc_experimental/tests/distributions/test_multivariate.py @@ -118,6 +118,7 @@ def test_init( positive_probs=positive_probs, **phi_args ) + assert not np.isnan(beta.eval()).any() assert eps.eval().shape == output_std.shape assert beta.eval().shape == input_std.shape # r2 rv is only created if r2 std is not None From 59c8efb134542491da1edd2dfea91ccd699a3c93 Mon Sep 17 00:00:00 2001 From: Max Kochurov Date: Thu, 1 Jun 2023 21:19:00 +0300 Subject: [PATCH 14/19] fix nans --- pymc_experimental/distributions/multivariate/r2d2m2cp.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pymc_experimental/distributions/multivariate/r2d2m2cp.py b/pymc_experimental/distributions/multivariate/r2d2m2cp.py index 5befe7f87..398a09b9e 100644 --- a/pymc_experimental/distributions/multivariate/r2d2m2cp.py +++ b/pymc_experimental/distributions/multivariate/r2d2m2cp.py @@ -29,7 +29,7 @@ def _psivar2musigma(psi: pt.TensorVariable, explained_var: pt.TensorVariable, ps mu = sigma * pi * 2**0.5 if psi_mask is not None: return ( - pt.where(psi_mask, mu, pt.sign(mu) * explained_var**0.5), + pt.where(psi_mask, mu, pt.sign(pi) * explained_var**0.5), pt.where(psi_mask, sigma, 0), ) else: From 9c65190bb4a5ac04c92df6526ae3af158360b637 Mon Sep 17 00:00:00 2001 From: Max Kochurov Date: Mon, 5 Jun 2023 19:08:46 +0300 Subject: [PATCH 15/19] fix corner cases --- .../distributions/multivariate/r2d2m2cp.py | 15 ++++++++++++--- .../tests/distributions/test_multivariate.py | 11 +++++++++++ 2 files changed, 23 insertions(+), 3 deletions(-) diff --git a/pymc_experimental/distributions/multivariate/r2d2m2cp.py b/pymc_experimental/distributions/multivariate/r2d2m2cp.py index 398a09b9e..cdfbe6039 100644 --- a/pymc_experimental/distributions/multivariate/r2d2m2cp.py +++ b/pymc_experimental/distributions/multivariate/r2d2m2cp.py @@ -70,17 +70,19 @@ def _R2D2M2CP_beta( mu_param, std_param = _psivar2musigma(psi, explained_variance, psi_mask=psi_mask) if not centered: with pm.Model(name): - if psi_mask is not None: + if psi_mask is not None and psi_mask.any(): r_idx = psi_mask.nonzero() with pm.Model("raw"): raw = pm.Normal("masked", shape=len(r_idx[0])) raw = pt.set_subtensor(pt.zeros_like(mu_param)[r_idx], raw) raw = pm.Deterministic("raw", raw, dims=dims) + elif psi_mask is not None: + raw = pt.zeros_like(mu_param) else: raw = pm.Normal("raw", dims=dims) beta = pm.Deterministic(name, (raw * std_param + mu_param) / input_sigma, dims=dims) else: - if psi_mask is not None: + if psi_mask is not None and psi_mask.any(): r_idx = psi_mask.nonzero() with pm.Model(name): mean = (mu_param / input_sigma)[r_idx] @@ -93,6 +95,8 @@ def _R2D2M2CP_beta( ) beta = pt.set_subtensor(mean, masked) beta = pm.Deterministic(name, beta, dims=dims) + elif psi_mask is not None: + beta = mean else: beta = pm.Normal(name, mu_param / input_sigma, std_param / input_sigma, dims=dims) return beta @@ -121,7 +125,9 @@ def _psi_masked(positive_probs, positive_probs_std, *, dims): mask = ~np.bitwise_or(positive_probs == 1, positive_probs == 0) if np.bitwise_and(~mask, positive_probs_std != 0).any(): raise ValueError("Can't have both positive_probs == '1 or 0' and positive_probs_std != 0") - if (~mask).any(): + if (~mask).any() and mask.any(): + # limit case where some probs are not 1 or 0 + # setsubtensor is required r_idx = mask.nonzero() with pm.Model("psi"): psi = pm.Beta( @@ -132,6 +138,9 @@ def _psi_masked(positive_probs, positive_probs_std, *, dims): ) psi = pt.set_subtensor(pt.as_tensor(positive_probs)[r_idx], psi) psi = pm.Deterministic("psi", psi, dims=dims) + elif (~mask).all(): + # limit case where all the probs are limit case + psi = pt.as_tensor(positive_probs) else: psi = pm.Beta("psi", mu=positive_probs, sigma=positive_probs_std, dims=dims) mask = None diff --git a/pymc_experimental/tests/distributions/test_multivariate.py b/pymc_experimental/tests/distributions/test_multivariate.py index 565f17f82..6aa257340 100644 --- a/pymc_experimental/tests/distributions/test_multivariate.py +++ b/pymc_experimental/tests/distributions/test_multivariate.py @@ -236,3 +236,14 @@ def test_limit_case_creates_masked_vars(self, model: pm.Model, centered: bool): assert "beta1::masked" in model.named_vars, model.named_vars assert "beta1::psi::masked" in model.named_vars assert "beta0::psi::masked" in model.named_vars + + def test_zero_length_rvs_not_created(self, model: pm.Model): + model.add_coord("a", range(2)) + # deterministic case which should not have any new variables + b = pmx.distributions.R2D2M2CP("b1", 1, [1, 1], r2=0.5, positive_probs=[1, 1], dims="a") + assert not model.free_RVs, model.free_RVs + + b = pmx.distributions.R2D2M2CP( + "b2", 1, [1, 1], r2=0.5, positive_probs=[1, 1], positive_probs_std=[0, 0], dims="a" + ) + assert not model.free_RVs, model.free_RVs From c0252e89489c783e37db4ec318d03389fb5c6625 Mon Sep 17 00:00:00 2001 From: Max Kochurov Date: Mon, 5 Jun 2023 19:18:08 +0300 Subject: [PATCH 16/19] fix bug with missed variable --- .../distributions/multivariate/r2d2m2cp.py | 8 +++++++- .../tests/distributions/test_multivariate.py | 11 +++++++---- 2 files changed, 14 insertions(+), 5 deletions(-) diff --git a/pymc_experimental/distributions/multivariate/r2d2m2cp.py b/pymc_experimental/distributions/multivariate/r2d2m2cp.py index cdfbe6039..cc38edc27 100644 --- a/pymc_experimental/distributions/multivariate/r2d2m2cp.py +++ b/pymc_experimental/distributions/multivariate/r2d2m2cp.py @@ -71,18 +71,23 @@ def _R2D2M2CP_beta( if not centered: with pm.Model(name): if psi_mask is not None and psi_mask.any(): + # limit case where some probs are not 1 or 0 + # setsubtensor is required r_idx = psi_mask.nonzero() with pm.Model("raw"): raw = pm.Normal("masked", shape=len(r_idx[0])) raw = pt.set_subtensor(pt.zeros_like(mu_param)[r_idx], raw) raw = pm.Deterministic("raw", raw, dims=dims) elif psi_mask is not None: + # all variables are deterministic raw = pt.zeros_like(mu_param) else: raw = pm.Normal("raw", dims=dims) beta = pm.Deterministic(name, (raw * std_param + mu_param) / input_sigma, dims=dims) else: if psi_mask is not None and psi_mask.any(): + # limit case where some probs are not 1 or 0 + # setsubtensor is required r_idx = psi_mask.nonzero() with pm.Model(name): mean = (mu_param / input_sigma)[r_idx] @@ -96,7 +101,8 @@ def _R2D2M2CP_beta( beta = pt.set_subtensor(mean, masked) beta = pm.Deterministic(name, beta, dims=dims) elif psi_mask is not None: - beta = mean + # all variables are deterministic + beta = pm.Deterministic(name, (mu_param / input_sigma), dims=dims) else: beta = pm.Normal(name, mu_param / input_sigma, std_param / input_sigma, dims=dims) return beta diff --git a/pymc_experimental/tests/distributions/test_multivariate.py b/pymc_experimental/tests/distributions/test_multivariate.py index 6aa257340..a1e92669f 100644 --- a/pymc_experimental/tests/distributions/test_multivariate.py +++ b/pymc_experimental/tests/distributions/test_multivariate.py @@ -50,7 +50,7 @@ def r2(self): def r2_std(self, request): return request.param - @pytest.fixture(params=["true", "false", "limit-1", "limit-0"]) + @pytest.fixture(params=["true", "false", "limit-1", "limit-0", "limit-all"]) def positive_probs(self, input_std, request): if request.param == "true": return np.full_like(input_std, 0.5) @@ -64,6 +64,8 @@ def positive_probs(self, input_std, request): ret = np.full_like(input_std, 0.5) ret[..., 0] = 0 return ret + elif request.param == "limit-all": + return np.full_like(input_std, 0) @pytest.fixture(params=[True, False], ids=["probs-std", "no-probs-std"]) def positive_probs_std(self, positive_probs, request): @@ -122,14 +124,15 @@ def test_init( assert eps.eval().shape == output_std.shape assert beta.eval().shape == input_std.shape # r2 rv is only created if r2 std is not None + assert "beta" in model.named_vars assert ("beta::r2" in model.named_vars) == (r2_std is not None), set(model.named_vars) # phi is only created if variable importance is not None and there is more than one var assert ("beta::phi" in model.named_vars) == ( "variables_importance" in phi_args or "importance_concentration" in phi_args ), set(model.named_vars) - assert ("beta::psi" in model.named_vars) == (positive_probs_std is not None), set( - model.named_vars - ) + assert ("beta::psi" in model.named_vars) == ( + positive_probs_std is not None and positive_probs_std.any() + ), set(model.named_vars) assert np.isfinite(sum(model.point_logps().values())), model.point_logps() def test_failing_importance(self, dims, input_shape, output_std, input_std): From a030ddaab841083e2ac1264e8bec40b35dc1a0df Mon Sep 17 00:00:00 2001 From: Max Kochurov Date: Thu, 15 Jun 2023 17:40:53 +0300 Subject: [PATCH 17/19] add skipif float32 --- pymc_experimental/tests/distributions/test_multivariate.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/pymc_experimental/tests/distributions/test_multivariate.py b/pymc_experimental/tests/distributions/test_multivariate.py index a1e92669f..b43b18d35 100644 --- a/pymc_experimental/tests/distributions/test_multivariate.py +++ b/pymc_experimental/tests/distributions/test_multivariate.py @@ -1,5 +1,6 @@ import numpy as np import pymc as pm +import pytensor import pytest import pymc_experimental as pmx @@ -95,6 +96,10 @@ def phi_args(self, request, phi_args_base): phi_args_base["importance_concentration"] = 10 return phi_args_base + @pytest.mark.skipif( + pytensor.config.floatX == "float32", + reason="pytensor.config.floatX == 'float32', https://github.com/pymc-devs/pymc/issues/6779", + ) def test_init( self, dims, @@ -133,7 +138,7 @@ def test_init( assert ("beta::psi" in model.named_vars) == ( positive_probs_std is not None and positive_probs_std.any() ), set(model.named_vars) - assert np.isfinite(sum(model.point_logps().values())), model.point_logps() + assert np.isfinite(sum(model.point_logps().values())) def test_failing_importance(self, dims, input_shape, output_std, input_std): if input_shape[-1] < 2: From ed5d36c998341a97e0bf2254b740273a67c475f8 Mon Sep 17 00:00:00 2001 From: Max Kochurov Date: Thu, 22 Jun 2023 12:08:16 +0000 Subject: [PATCH 18/19] add requirement for dims to be immutable for the prior as it is required for the limit case masking --- .../distributions/multivariate/r2d2m2cp.py | 4 +++- .../tests/distributions/test_multivariate.py | 24 +++++++++++++++++++ 2 files changed, 27 insertions(+), 1 deletion(-) diff --git a/pymc_experimental/distributions/multivariate/r2d2m2cp.py b/pymc_experimental/distributions/multivariate/r2d2m2cp.py index cc38edc27..f90f05e7e 100644 --- a/pymc_experimental/distributions/multivariate/r2d2m2cp.py +++ b/pymc_experimental/distributions/multivariate/r2d2m2cp.py @@ -394,7 +394,9 @@ def R2D2M2CP( *broadcast_dims, dim = dims input_sigma = pt.as_tensor(input_sigma) output_sigma = pt.as_tensor(output_sigma) - with pm.Model(name): + with pm.Model(name) as model: + if not all(isinstance(model.dim_lengths[d], pt.TensorConstant) for d in dims): + raise ValueError(f"{dims!r} should be constant length imutable dims") if r2_std is not None: r2 = pm.Beta("r2", mu=r2, sigma=r2_std, dims=broadcast_dims) phi = _phi( diff --git a/pymc_experimental/tests/distributions/test_multivariate.py b/pymc_experimental/tests/distributions/test_multivariate.py index b43b18d35..4dd81b2b3 100644 --- a/pymc_experimental/tests/distributions/test_multivariate.py +++ b/pymc_experimental/tests/distributions/test_multivariate.py @@ -255,3 +255,27 @@ def test_zero_length_rvs_not_created(self, model: pm.Model): "b2", 1, [1, 1], r2=0.5, positive_probs=[1, 1], positive_probs_std=[0, 0], dims="a" ) assert not model.free_RVs, model.free_RVs + + def test_immutable_dims(self, model: pm.Model): + model.add_coord("a", range(2), mutable=True) + model.add_coord("b", range(2), mutable=False) + with pytest.raises(ValueError, match="should be constant length immutable dims"): + pmx.distributions.R2D2M2CP( + "beta0", + 1, + [1, 1], + dims="a", + r2=0.8, + positive_probs=[0.5, 1], + positive_probs_std=[0.3, 0], + ) + with pytest.raises(ValueError, match="should be constant length immutable dims"): + pmx.distributions.R2D2M2CP( + "beta0", + 1, + [1, 1], + dims=("a", "b"), + r2=0.8, + positive_probs=[0.5, 1], + positive_probs_std=[0.3, 0], + ) From 745c71ddffdc70982a969d06fbae6c5b841a7c6a Mon Sep 17 00:00:00 2001 From: Max Kochurov Date: Thu, 22 Jun 2023 14:49:38 +0000 Subject: [PATCH 19/19] fix error message --- pymc_experimental/distributions/multivariate/r2d2m2cp.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pymc_experimental/distributions/multivariate/r2d2m2cp.py b/pymc_experimental/distributions/multivariate/r2d2m2cp.py index f90f05e7e..e7b43e679 100644 --- a/pymc_experimental/distributions/multivariate/r2d2m2cp.py +++ b/pymc_experimental/distributions/multivariate/r2d2m2cp.py @@ -396,7 +396,7 @@ def R2D2M2CP( output_sigma = pt.as_tensor(output_sigma) with pm.Model(name) as model: if not all(isinstance(model.dim_lengths[d], pt.TensorConstant) for d in dims): - raise ValueError(f"{dims!r} should be constant length imutable dims") + raise ValueError(f"{dims!r} should be constant length immutable dims") if r2_std is not None: r2 = pm.Beta("r2", mu=r2, sigma=r2_std, dims=broadcast_dims) phi = _phi(