Fix ordering Transformation for batched dimensions #6255

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged 13 commits on Nov 22, 2022
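
Summary of the change: `Ordered` and `SumTo1` now accept an `ndim_supp` argument, and the module exposes `univariate_ordered`/`multivariate_ordered` and `univariate_sum_to_1`/`multivariate_sum_to_1`, so batched univariate variables and multivariate variables each get a Jacobian term of the right shape. A minimal usage sketch, mirroring the new tests in this PR (not part of the diff itself):

```python
import numpy as np
import pymc as pm
import pymc.distributions.transforms as tr

with pm.Model():
    # Batched univariate RV: each of the 10 rows is an independently ordered length-4 vector.
    x = pm.Normal(
        "x",
        mu=[-3, -1, 1, 2],
        sigma=1,
        size=(10, 4),
        transform=tr.univariate_ordered,
    )
    # Multivariate RV: the ordering applies to the core (support) dimension of each draw.
    y = pm.MvNormal(
        "y",
        mu=[-1, 1],
        cov=np.eye(2),
        size=2,
        initval=[[-1, 1], [-1, 1]],
        transform=tr.multivariate_ordered,
    )
```

The old `ordered` and `sum_to_1` instances are kept as aliases of the multivariate variants for backwards compatibility.
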
64 changes: 58 additions & 6 deletions pymc/distributions/transforms.py
@@ -38,8 +38,12 @@
"Interval",
"log_exp_m1",
"ordered",
"univariate_ordered",
"multivariate_ordered",
"log",
"sum_to_1",
"univariate_sum_to_1",
"multivariate_sum_to_1",
"circular",
"CholeskyCovPacked",
"Chain",
@@ -74,6 +78,14 @@ def log_jac_det(self, value, *inputs):
class Ordered(RVTransform):
name = "ordered"

def __init__(self, ndim_supp=0):
if ndim_supp > 1:
raise ValueError(
"For the Ordered transformation, the number of core dimensions "
f"(ndim_supp) must not exceed 1, but is {ndim_supp}"
)
self.ndim_supp = ndim_supp

def backward(self, value, *inputs):
x = at.zeros(value.shape)
x = at.inc_subtensor(x[..., 0], value[..., 0])
@@ -87,7 +99,10 @@ def forward(self, value, *inputs):
return y

def log_jac_det(self, value, *inputs):
return at.sum(value[..., 1:], axis=-1)
if self.ndim_supp == 0:
return at.sum(value[..., 1:], axis=-1, keepdims=True)
else:
return at.sum(value[..., 1:], axis=-1)


class SumTo1(RVTransform):
@@ -98,6 +113,14 @@ class SumTo1(RVTransform):

name = "sumto1"

def __init__(self, ndim_supp=0):
if ndim_supp > 1:
raise ValueError(
"For the SumTo1 transformation, the number of core dimensions "
f"(ndim_supp) must not exceed 1, but is {ndim_supp}"
)
self.ndim_supp = ndim_supp

def backward(self, value, *inputs):
remaining = 1 - at.sum(value[..., :], axis=-1, keepdims=True)
return at.concatenate([value[..., :], remaining], axis=-1)
@@ -107,7 +130,10 @@ def forward(self, value, *inputs):

def log_jac_det(self, value, *inputs):
y = at.zeros(value.shape)
return at.sum(y, axis=-1)
if self.ndim_supp == 0:
return at.sum(y, axis=-1, keepdims=True)
else:
return at.sum(y, axis=-1)


class CholeskyCovPacked(RVTransform):
@@ -330,20 +356,46 @@ def extend_axis_rev(array, axis):
Instantiation of :class:`pymc.distributions.transforms.LogExpM1`
for use in the ``transform`` argument of a random variable."""

ordered = Ordered()
univariate_ordered = Ordered(ndim_supp=0)
univariate_ordered.__doc__ = """
Instantiation of :class:`pymc.distributions.transforms.Ordered`
for use in the ``transform`` argument of a univariate random variable."""

multivariate_ordered = Ordered(ndim_supp=1)
multivariate_ordered.__doc__ = """
Instantiation of :class:`pymc.distributions.transforms.Ordered`
for use in the ``transform`` argument of a multivariate random variable."""

# backwards compatibility
ordered = Ordered(ndim_supp=1)
ordered.__doc__ = """
Instantiation of :class:`pymc.distributions.transforms.Ordered`
for use in the ``transform`` argument of a random variable."""
for use in the ``transform`` argument of a random variable.
This instantiation is for backwards compatibility only.
Please use `univariate_ordered` or `multivariate_ordered` instead."""

log = LogTransform()
log.__doc__ = """
Instantiation of :class:`aeppl.transforms.LogTransform`
for use in the ``transform`` argument of a random variable."""

sum_to_1 = SumTo1()
univariate_sum_to_1 = SumTo1(ndim_supp=0)
univariate_sum_to_1.__doc__ = """
Instantiation of :class:`pymc.distributions.transforms.SumTo1`
for use in the ``transform`` argument of a univariate random variable."""

multivariate_sum_to_1 = SumTo1(ndim_supp=1)
multivariate_sum_to_1.__doc__ = """
Instantiation of :class:`pymc.distributions.transforms.SumTo1`
for use in the ``transform`` argument of a multivariate random variable."""

# backwards compatibility
sum_to_1 = SumTo1(ndim_supp=1)
sum_to_1.__doc__ = """
Instantiation of :class:`pymc.distributions.transforms.SumTo1`
for use in the ``transform`` argument of a random variable."""
for use in the ``transform`` argument of a random variable.
This instantiation is for backwards compatibility only.
Please use `univariate_sum_to_1` or `multivariate_sum_to_1` instead."""

circular = CircularTransform()
circular.__doc__ = """
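
The behavioural difference introduced above is in `log_jac_det`: the univariate variants (`ndim_supp=0`) keep the reduced axis (`keepdims=True`), so the Jacobian term broadcasts against the elementwise logp of a batched univariate variable, while the multivariate variants (`ndim_supp=1`) still collapse it. A small shape check, sketched against the code in this diff (the `(10, 4)` value shape is only an example):

```python
import aesara
import aesara.tensor as at
import numpy as np
import pymc.distributions.transforms as tr

value = at.dmatrix("value")  # e.g. a batch of 10 ordered vectors of length 4
jac_uni = tr.univariate_ordered.log_jac_det(value)   # keepdims=True -> shape (10, 1)
jac_mv = tr.multivariate_ordered.log_jac_det(value)  # reduced       -> shape (10,)

f = aesara.function([value], [jac_uni, jac_mv])
j_uni, j_mv = f(np.zeros((10, 4)))
print(j_uni.shape, j_mv.shape)  # expected: (10, 1) (10,)
```
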
157 changes: 139 additions & 18 deletions pymc/tests/distributions/test_transform.py
@@ -13,6 +13,8 @@
# limitations under the License.


import aesara
import aesara.tensor as at
import numpy as np
@@ -139,10 +141,18 @@ def test_simplex_accuracy():


def test_sum_to_1():
check_vector_transform(tr.sum_to_1, Simplex(2))
check_vector_transform(tr.sum_to_1, Simplex(4))
check_vector_transform(tr.univariate_sum_to_1, Simplex(2))
check_vector_transform(tr.univariate_sum_to_1, Simplex(4))

check_jacobian_det(tr.sum_to_1, Vector(Unit, 2), at.dvector, np.array([0, 0]), lambda x: x[:-1])
with pytest.raises(ValueError, match=r"\(ndim_supp\) must not exceed 1"):
tr.SumTo1(2)

check_jacobian_det(
tr.univariate_sum_to_1, Vector(Unit, 2), at.dvector, np.array([0, 0]), lambda x: x[:-1]
)
check_jacobian_det(
tr.multivariate_sum_to_1, Vector(Unit, 2), at.dvector, np.array([0, 0]), lambda x: x[:-1]
)


def test_log():
@@ -241,28 +251,36 @@ def test_circular():


def test_ordered():
check_vector_transform(tr.ordered, SortedVector(6))
check_vector_transform(tr.univariate_ordered, SortedVector(6))

with pytest.raises(ValueError, match=r"\(ndim_supp\) must not exceed 1"):
tr.Ordered(2)

check_jacobian_det(tr.ordered, Vector(R, 2), at.dvector, np.array([0, 0]), elemwise=False)
check_jacobian_det(
tr.univariate_ordered, Vector(R, 2), at.dvector, np.array([0, 0]), elemwise=False
)
check_jacobian_det(
tr.multivariate_ordered, Vector(R, 2), at.dvector, np.array([0, 0]), elemwise=False
)

vals = get_values(tr.ordered, Vector(R, 3), at.dvector, np.zeros(3))
vals = get_values(tr.univariate_ordered, Vector(R, 3), at.dvector, np.zeros(3))
close_to_logical(np.diff(vals) >= 0, True, tol)


def test_chain_values():
chain_tranf = tr.Chain([tr.logodds, tr.ordered])
chain_tranf = tr.Chain([tr.logodds, tr.univariate_ordered])
vals = get_values(chain_tranf, Vector(R, 5), at.dvector, np.zeros(5))
close_to_logical(np.diff(vals) >= 0, True, tol)


def test_chain_vector_transform():
chain_tranf = tr.Chain([tr.logodds, tr.ordered])
chain_tranf = tr.Chain([tr.logodds, tr.univariate_ordered])
check_vector_transform(chain_tranf, UnitSortedVector(3))


@pytest.mark.xfail(reason="Fails due to precision issue. Values just close to expected.")
def test_chain_jacob_det():
chain_tranf = tr.Chain([tr.logodds, tr.ordered])
chain_tranf = tr.Chain([tr.logodds, tr.univariate_ordered])
check_jacobian_det(chain_tranf, Vector(R, 4), at.dvector, np.zeros(4), elemwise=False)


@@ -327,7 +345,14 @@ def check_vectortransform_elementwise_logp(self, model):
jacob_det = transform.log_jac_det(test_array_transf, *x.owner.inputs)
# Original distribution is univariate
if x.owner.op.ndim_supp == 0:
assert model.logp(x, sum=False)[0].ndim == x.ndim == (jacob_det.ndim + 1)
tr_steps = getattr(transform, "transform_list", [transform])
transform_keeps_dim = any(isinstance(ts, (tr.SumTo1, tr.Ordered)) for ts in tr_steps)
if transform_keeps_dim:
assert model.logp(x, sum=False)[0].ndim == x.ndim == jacob_det.ndim
else:
assert model.logp(x, sum=False)[0].ndim == x.ndim == (jacob_det.ndim + 1)
# Original distribution is multivariate
else:
assert model.logp(x, sum=False)[0].ndim == (x.ndim - 1) == jacob_det.ndim
@@ -449,7 +474,7 @@ def test_normal_ordered(self):
{"mu": 0.0, "sigma": 1.0},
size=3,
initval=np.asarray([-1.0, 1.0, 4.0]),
transform=tr.ordered,
transform=tr.univariate_ordered,
)
self.check_vectortransform_elementwise_logp(model)

@@ -467,7 +492,7 @@ def test_half_normal_ordered(self, sigma, size):
{"sigma": sigma},
size=size,
initval=initval,
transform=tr.Chain([tr.log, tr.ordered]),
transform=tr.Chain([tr.log, tr.univariate_ordered]),
)
self.check_vectortransform_elementwise_logp(model)

@@ -479,7 +504,7 @@ def test_exponential_ordered(self, lam, size):
{"lam": lam},
size=size,
initval=initval,
transform=tr.Chain([tr.log, tr.ordered]),
transform=tr.Chain([tr.log, tr.univariate_ordered]),
)
self.check_vectortransform_elementwise_logp(model)

@@ -501,7 +526,7 @@ def test_beta_ordered(self, a, b, size):
{"alpha": a, "beta": b},
size=size,
initval=initval,
transform=tr.Chain([tr.logodds, tr.ordered]),
transform=tr.Chain([tr.logodds, tr.univariate_ordered]),
)
self.check_vectortransform_elementwise_logp(model)

@@ -524,7 +549,7 @@ def transform_params(*inputs):
{"lower": lower, "upper": upper},
size=size,
initval=initval,
transform=tr.Chain([interval, tr.ordered]),
transform=tr.Chain([interval, tr.univariate_ordered]),
)
self.check_vectortransform_elementwise_logp(model)

@@ -536,7 +561,7 @@ def test_vonmises_ordered(self, mu, kappa, size):
{"mu": mu, "kappa": kappa},
size=size,
initval=initval,
transform=tr.Chain([tr.circular, tr.ordered]),
transform=tr.Chain([tr.circular, tr.univariate_ordered]),
)
self.check_vectortransform_elementwise_logp(model)

@@ -545,7 +570,7 @@ def test_vonmises_ordered(self, mu, kappa, size):
[
(0.0, 1.0, (2,), tr.simplex),
(0.5, 5.5, (2, 3), tr.simplex),
(np.zeros(3), np.ones(3), (4, 3), tr.Chain([tr.sum_to_1, tr.logodds])),
(np.zeros(3), np.ones(3), (4, 3), tr.Chain([tr.univariate_sum_to_1, tr.logodds])),
],
)
def test_uniform_other(self, lower, upper, size, transform):
@@ -569,7 +594,11 @@ def test_uniform_other(self, lower, upper, size, transform):
def test_mvnormal_ordered(self, mu, cov, size, shape):
initval = np.sort(np.random.randn(*shape))
model = self.build_model(
pm.MvNormal, {"mu": mu, "cov": cov}, size=size, initval=initval, transform=tr.ordered
pm.MvNormal,
{"mu": mu, "cov": cov},
size=size,
initval=initval,
transform=tr.multivariate_ordered,
)
self.check_vectortransform_elementwise_logp(model)

@@ -598,3 +627,95 @@ def test_discrete_trafo():
with pytest.raises(ValueError) as err:
pm.Binomial("a", n=5, p=0.5, transform="log")
err.match("Transformations for discrete distributions")


def test_2d_univariate_ordered():
with pm.Model() as model:
x_1d = pm.Normal(
"x_1d",
mu=[-3, -1, 1, 2],
sigma=1,
size=(4,),
transform=tr.univariate_ordered,
)
x_2d = pm.Normal(
"x_2d",
mu=[-3, -1, 1, 2],
sigma=1,
size=(10, 4),
transform=tr.univariate_ordered,
)

log_p = model.compile_logp(sum=False)(
{"x_1d_ordered__": np.zeros((4,)), "x_2d_ordered__": np.zeros((10, 4))}
)
np.testing.assert_allclose(np.tile(log_p[0], (10, 1)), log_p[1])


def test_2d_multivariate_ordered():
with pm.Model() as model:
x_1d = pm.MvNormal(
"x_1d",
mu=[-1, 1],
cov=np.eye(2),
initval=[-1, 1],
transform=tr.multivariate_ordered,
)
x_2d = pm.MvNormal(
"x_2d",
mu=[-1, 1],
cov=np.eye(2),
size=2,
initval=[[-1, 1], [-1, 1]],
transform=tr.multivariate_ordered,
)

log_p = model.compile_logp(sum=False)(
{"x_1d_ordered__": np.zeros((2,)), "x_2d_ordered__": np.zeros((2, 2))}
)
np.testing.assert_allclose(log_p[0], log_p[1])


def test_2d_univariate_sum_to_1():
with pm.Model() as model:
x_1d = pm.Normal(
"x_1d",
mu=[-3, -1, 1, 2],
sigma=1,
size=(4,),
transform=tr.univariate_sum_to_1,
)
x_2d = pm.Normal(
"x_2d",
mu=[-3, -1, 1, 2],
sigma=1,
size=(10, 4),
transform=tr.univariate_sum_to_1,
)

log_p = model.compile_logp(sum=False)(
{"x_1d_sumto1__": np.zeros(3), "x_2d_sumto1__": np.zeros((10, 3))}
)
np.testing.assert_allclose(np.tile(log_p[0], (10, 1)), log_p[1])


def test_2d_multivariate_sum_to_1():
with pm.Model() as model:
x_1d = pm.MvNormal(
"x_1d",
mu=[-1, 1],
cov=np.eye(2),
transform=tr.multivariate_sum_to_1,
)
x_2d = pm.MvNormal(
"x_2d",
mu=[-1, 1],
cov=np.eye(2),
size=2,
transform=tr.multivariate_sum_to_1,
)

log_p = model.compile_logp(sum=False)(
{"x_1d_sumto1__": np.zeros(1), "x_2d_sumto1__": np.zeros((2, 1))}
)
np.testing.assert_allclose(log_p[0], log_p[1])
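
A note on the value shapes in the `sum_to_1` tests above: the free (transformed) value has one entry fewer than the constrained vector, which is why `x_1d_sumto1__` is length 3 for a size-4 Normal and length 1 for a 2-dimensional MvNormal. A small sketch, assuming the collapsed `forward` simply drops the last entry (that part of the file is not shown in this diff):

```python
import aesara
import aesara.tensor as at
import numpy as np
import pymc.distributions.transforms as tr

v = at.dvector("v")
free = tr.univariate_sum_to_1.forward(v)      # length n -> n - 1 (assumed: drops the last entry)
full = tr.univariate_sum_to_1.backward(free)  # length n - 1 -> n, last entry = 1 - sum(rest)

f = aesara.function([v], [free, full])
free_np, full_np = f(np.array([0.2, 0.3, 0.5]))
print(free_np)        # [0.2 0.3]
print(full_np.sum())  # 1.0
```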