diff --git a/pymc/distributions/transforms.py b/pymc/distributions/transforms.py index ee142b46fe..2f3d7cb626 100644 --- a/pymc/distributions/transforms.py +++ b/pymc/distributions/transforms.py @@ -38,8 +38,12 @@ "Interval", "log_exp_m1", "ordered", + "univariate_ordered", + "multivariate_ordered", "log", "sum_to_1", + "univariate_sum_to_1", + "multivariate_sum_to_1", "circular", "CholeskyCovPacked", "Chain", @@ -74,6 +78,14 @@ def log_jac_det(self, value, *inputs): class Ordered(RVTransform): name = "ordered" + def __init__(self, ndim_supp=0): + if ndim_supp > 1: + raise ValueError( + f"For Ordered transformation number of core dimensions " + f"(ndim_supp) must not exceed 1 but is {ndim_supp}" + ) + self.ndim_supp = ndim_supp + def backward(self, value, *inputs): x = at.zeros(value.shape) x = at.inc_subtensor(x[..., 0], value[..., 0]) @@ -87,7 +99,10 @@ def forward(self, value, *inputs): return y def log_jac_det(self, value, *inputs): - return at.sum(value[..., 1:], axis=-1) + if self.ndim_supp == 0: + return at.sum(value[..., 1:], axis=-1, keepdims=True) + else: + return at.sum(value[..., 1:], axis=-1) class SumTo1(RVTransform): @@ -98,6 +113,14 @@ class SumTo1(RVTransform): name = "sumto1" + def __init__(self, ndim_supp=0): + if ndim_supp > 1: + raise ValueError( + f"For SumTo1 transformation number of core dimensions " + f"(ndim_supp) must not exceed 1 but is {ndim_supp}" + ) + self.ndim_supp = ndim_supp + def backward(self, value, *inputs): remaining = 1 - at.sum(value[..., :], axis=-1, keepdims=True) return at.concatenate([value[..., :], remaining], axis=-1) @@ -107,7 +130,10 @@ def forward(self, value, *inputs): def log_jac_det(self, value, *inputs): y = at.zeros(value.shape) - return at.sum(y, axis=-1) + if self.ndim_supp == 0: + return at.sum(y, axis=-1, keepdims=True) + else: + return at.sum(y, axis=-1) class CholeskyCovPacked(RVTransform): @@ -330,20 +356,46 @@ def extend_axis_rev(array, axis): Instantiation of 
:class:`pymc.distributions.transforms.LogExpM1` for use in the ``transform`` argument of a random variable.""" -ordered = Ordered() +univariate_ordered = Ordered(ndim_supp=0) +univariate_ordered.__doc__ = """ +Instantiation of :class:`pymc.distributions.transforms.Ordered` +for use in the ``transform`` argument of a univariate random variable.""" + +multivariate_ordered = Ordered(ndim_supp=1) +multivariate_ordered.__doc__ = """ +Instantiation of :class:`pymc.distributions.transforms.Ordered` +for use in the ``transform`` argument of a multivariate random variable.""" + +# backwards compatibility +ordered = Ordered(ndim_supp=1) ordered.__doc__ = """ Instantiation of :class:`pymc.distributions.transforms.Ordered` -for use in the ``transform`` argument of a random variable.""" +for use in the ``transform`` argument of a random variable. +This instantiation is for backwards compatibility only. +Please use `univariate_ordered` or `multivariate_ordered` instead.""" log = LogTransform() log.__doc__ = """ Instantiation of :class:`aeppl.transforms.LogTransform` for use in the ``transform`` argument of a random variable.""" -sum_to_1 = SumTo1() +univariate_sum_to_1 = SumTo1(ndim_supp=0) +univariate_sum_to_1.__doc__ = """ +Instantiation of :class:`pymc.distributions.transforms.SumTo1` +for use in the ``transform`` argument of a univariate random variable.""" + +multivariate_sum_to_1 = SumTo1(ndim_supp=1) +multivariate_sum_to_1.__doc__ = """ +Instantiation of :class:`pymc.distributions.transforms.SumTo1` +for use in the ``transform`` argument of a multivariate random variable.""" + +# backwards compatibility +sum_to_1 = SumTo1(ndim_supp=1) sum_to_1.__doc__ = """ Instantiation of :class:`pymc.distributions.transforms.SumTo1` -for use in the ``transform`` argument of a random variable.""" +for use in the ``transform`` argument of a random variable. +This instantiation is for backwards compatibility only. 
+Please use `univariate_sum_to_1` or `multivariate_sum_to_1` instead.""" circular = CircularTransform() circular.__doc__ = """ diff --git a/pymc/tests/distributions/test_transform.py b/pymc/tests/distributions/test_transform.py index 850b285f4d..5bed2a9a48 100644 --- a/pymc/tests/distributions/test_transform.py +++ b/pymc/tests/distributions/test_transform.py @@ -13,6 +13,8 @@ # limitations under the License. +from typing import Union + import aesara import aesara.tensor as at import numpy as np @@ -139,10 +141,18 @@ def test_simplex_accuracy(): def test_sum_to_1(): - check_vector_transform(tr.sum_to_1, Simplex(2)) - check_vector_transform(tr.sum_to_1, Simplex(4)) + check_vector_transform(tr.univariate_sum_to_1, Simplex(2)) + check_vector_transform(tr.univariate_sum_to_1, Simplex(4)) - check_jacobian_det(tr.sum_to_1, Vector(Unit, 2), at.dvector, np.array([0, 0]), lambda x: x[:-1]) + with pytest.raises(ValueError, match=r"\(ndim_supp\) must not exceed 1"): + tr.SumTo1(2) + + check_jacobian_det( + tr.univariate_sum_to_1, Vector(Unit, 2), at.dvector, np.array([0, 0]), lambda x: x[:-1] + ) + check_jacobian_det( + tr.multivariate_sum_to_1, Vector(Unit, 2), at.dvector, np.array([0, 0]), lambda x: x[:-1] + ) def test_log(): @@ -241,28 +251,36 @@ def test_circular(): def test_ordered(): - check_vector_transform(tr.ordered, SortedVector(6)) + check_vector_transform(tr.univariate_ordered, SortedVector(6)) + + with pytest.raises(ValueError, match=r"\(ndim_supp\) must not exceed 1"): + tr.Ordered(2) - check_jacobian_det(tr.ordered, Vector(R, 2), at.dvector, np.array([0, 0]), elemwise=False) + check_jacobian_det( + tr.univariate_ordered, Vector(R, 2), at.dvector, np.array([0, 0]), elemwise=False + ) + check_jacobian_det( + tr.multivariate_ordered, Vector(R, 2), at.dvector, np.array([0, 0]), elemwise=False + ) - vals = get_values(tr.ordered, Vector(R, 3), at.dvector, np.zeros(3)) + vals = get_values(tr.univariate_ordered, Vector(R, 3), at.dvector, np.zeros(3)) 
close_to_logical(np.diff(vals) >= 0, True, tol) def test_chain_values(): - chain_tranf = tr.Chain([tr.logodds, tr.ordered]) + chain_tranf = tr.Chain([tr.logodds, tr.univariate_ordered]) vals = get_values(chain_tranf, Vector(R, 5), at.dvector, np.zeros(5)) close_to_logical(np.diff(vals) >= 0, True, tol) def test_chain_vector_transform(): - chain_tranf = tr.Chain([tr.logodds, tr.ordered]) + chain_tranf = tr.Chain([tr.logodds, tr.univariate_ordered]) check_vector_transform(chain_tranf, UnitSortedVector(3)) @pytest.mark.xfail(reason="Fails due to precision issue. Values just close to expected.") def test_chain_jacob_det(): - chain_tranf = tr.Chain([tr.logodds, tr.ordered]) + chain_tranf = tr.Chain([tr.logodds, tr.univariate_ordered]) check_jacobian_det(chain_tranf, Vector(R, 4), at.dvector, np.zeros(4), elemwise=False) @@ -327,7 +345,14 @@ def check_vectortransform_elementwise_logp(self, model): jacob_det = transform.log_jac_det(test_array_transf, *x.owner.inputs) # Original distribution is univariate if x.owner.op.ndim_supp == 0: - assert model.logp(x, sum=False)[0].ndim == x.ndim == (jacob_det.ndim + 1) + tr_steps = getattr(transform, "transform_list", [transform]) + transform_keeps_dim = any( + isinstance(ts, (tr.SumTo1, tr.Ordered)) for ts in tr_steps + ) + if transform_keeps_dim: + assert model.logp(x, sum=False)[0].ndim == x.ndim == jacob_det.ndim + else: + assert model.logp(x, sum=False)[0].ndim == x.ndim == (jacob_det.ndim + 1) # Original distribution is multivariate else: assert model.logp(x, sum=False)[0].ndim == (x.ndim - 1) == jacob_det.ndim @@ -449,7 +474,7 @@ def test_normal_ordered(self): {"mu": 0.0, "sigma": 1.0}, size=3, initval=np.asarray([-1.0, 1.0, 4.0]), - transform=tr.ordered, + transform=tr.univariate_ordered, ) self.check_vectortransform_elementwise_logp(model) @@ -467,7 +492,7 @@ def test_half_normal_ordered(self, sigma, size): {"sigma": sigma}, size=size, initval=initval, - transform=tr.Chain([tr.log, tr.ordered]), + 
transform=tr.Chain([tr.log, tr.univariate_ordered]), ) self.check_vectortransform_elementwise_logp(model) @@ -479,7 +504,7 @@ def test_exponential_ordered(self, lam, size): {"lam": lam}, size=size, initval=initval, - transform=tr.Chain([tr.log, tr.ordered]), + transform=tr.Chain([tr.log, tr.univariate_ordered]), ) self.check_vectortransform_elementwise_logp(model) @@ -501,7 +526,7 @@ def test_beta_ordered(self, a, b, size): {"alpha": a, "beta": b}, size=size, initval=initval, - transform=tr.Chain([tr.logodds, tr.ordered]), + transform=tr.Chain([tr.logodds, tr.univariate_ordered]), ) self.check_vectortransform_elementwise_logp(model) @@ -524,7 +549,7 @@ def transform_params(*inputs): {"lower": lower, "upper": upper}, size=size, initval=initval, - transform=tr.Chain([interval, tr.ordered]), + transform=tr.Chain([interval, tr.univariate_ordered]), ) self.check_vectortransform_elementwise_logp(model) @@ -536,7 +561,7 @@ def test_vonmises_ordered(self, mu, kappa, size): {"mu": mu, "kappa": kappa}, size=size, initval=initval, - transform=tr.Chain([tr.circular, tr.ordered]), + transform=tr.Chain([tr.circular, tr.univariate_ordered]), ) self.check_vectortransform_elementwise_logp(model) @@ -545,7 +570,7 @@ def test_vonmises_ordered(self, mu, kappa, size): [ (0.0, 1.0, (2,), tr.simplex), (0.5, 5.5, (2, 3), tr.simplex), - (np.zeros(3), np.ones(3), (4, 3), tr.Chain([tr.sum_to_1, tr.logodds])), + (np.zeros(3), np.ones(3), (4, 3), tr.Chain([tr.univariate_sum_to_1, tr.logodds])), ], ) def test_uniform_other(self, lower, upper, size, transform): @@ -569,7 +594,11 @@ def test_uniform_other(self, lower, upper, size, transform): def test_mvnormal_ordered(self, mu, cov, size, shape): initval = np.sort(np.random.randn(*shape)) model = self.build_model( - pm.MvNormal, {"mu": mu, "cov": cov}, size=size, initval=initval, transform=tr.ordered + pm.MvNormal, + {"mu": mu, "cov": cov}, + size=size, + initval=initval, + transform=tr.multivariate_ordered, ) 
self.check_vectortransform_elementwise_logp(model) @@ -598,3 +627,95 @@ def test_discrete_trafo(): with pytest.raises(ValueError) as err: pm.Binomial("a", n=5, p=0.5, transform="log") err.match("Transformations for discrete distributions") + + +def test_2d_univariate_ordered(): + with pm.Model() as model: + x_1d = pm.Normal( + "x_1d", + mu=[-3, -1, 1, 2], + sigma=1, + size=(4,), + transform=tr.univariate_ordered, + ) + x_2d = pm.Normal( + "x_2d", + mu=[-3, -1, 1, 2], + sigma=1, + size=(10, 4), + transform=tr.univariate_ordered, + ) + + log_p = model.compile_logp(sum=False)( + {"x_1d_ordered__": np.zeros((4,)), "x_2d_ordered__": np.zeros((10, 4))} + ) + np.testing.assert_allclose(np.tile(log_p[0], (10, 1)), log_p[1]) + + +def test_2d_multivariate_ordered(): + with pm.Model() as model: + x_1d = pm.MvNormal( + "x_1d", + mu=[-1, 1], + cov=np.eye(2), + initval=[-1, 1], + transform=tr.multivariate_ordered, + ) + x_2d = pm.MvNormal( + "x_2d", + mu=[-1, 1], + cov=np.eye(2), + size=2, + initval=[[-1, 1], [-1, 1]], + transform=tr.multivariate_ordered, + ) + + log_p = model.compile_logp(sum=False)( + {"x_1d_ordered__": np.zeros((2,)), "x_2d_ordered__": np.zeros((2, 2))} + ) + np.testing.assert_allclose(log_p[0], log_p[1]) + + +def test_2d_univariate_sum_to_1(): + with pm.Model() as model: + x_1d = pm.Normal( + "x_1d", + mu=[-3, -1, 1, 2], + sigma=1, + size=(4,), + transform=tr.univariate_sum_to_1, + ) + x_2d = pm.Normal( + "x_2d", + mu=[-3, -1, 1, 2], + sigma=1, + size=(10, 4), + transform=tr.univariate_sum_to_1, + ) + + log_p = model.compile_logp(sum=False)( + {"x_1d_sumto1__": np.zeros(3), "x_2d_sumto1__": np.zeros((10, 3))} + ) + np.testing.assert_allclose(np.tile(log_p[0], (10, 1)), log_p[1]) + + +def test_2d_multivariate_sum_to_1(): + with pm.Model() as model: + x_1d = pm.MvNormal( + "x_1d", + mu=[-1, 1], + cov=np.eye(2), + transform=tr.multivariate_sum_to_1, + ) + x_2d = pm.MvNormal( + "x_2d", + mu=[-1, 1], + cov=np.eye(2), + size=2, + 
transform=tr.multivariate_sum_to_1, + ) + + log_p = model.compile_logp(sum=False)( + {"x_1d_sumto1__": np.zeros(1), "x_2d_sumto1__": np.zeros((2, 1))} + ) + np.testing.assert_allclose(log_p[0], log_p[1])