Allow batched parameters in MvNormal and MvStudentT distributions #6897

Merged · 2 commits · Oct 4, 2023
112 changes: 41 additions & 71 deletions pymc/distributions/multivariate.py
@@ -115,40 +115,37 @@ def simplex_cont_transform(op, rv):


def quaddist_matrix(cov=None, chol=None, tau=None, lower=True, *args, **kwargs):
if chol is not None and not lower:
chol = chol.T

if len([i for i in [tau, cov, chol] if i is not None]) != 1:
raise ValueError("Incompatible parameterization. Specify exactly one of tau, cov, or chol.")

if cov is not None:
cov = pt.as_tensor_variable(cov)
if cov.ndim != 2:
raise ValueError("cov must be two dimensional.")
if cov.ndim < 2:
raise ValueError("cov must be at least two dimensional.")
elif tau is not None:
tau = pt.as_tensor_variable(tau)
if tau.ndim != 2:
raise ValueError("tau must be two dimensional.")
# TODO: What's the correct order/approach (in the non-square case)?
# `pytensor.tensor.nlinalg.tensorinv`?
if tau.ndim < 2:
raise ValueError("tau must be at least two dimensional.")
cov = matrix_inverse(tau)
else:
# TODO: What's the correct order/approach (in the non-square case)?
chol = pt.as_tensor_variable(chol)
if chol.ndim != 2:
raise ValueError("chol must be two dimensional.")
if chol.ndim < 2:
raise ValueError("chol must be at least two dimensional.")

if not lower:
chol = pt.swapaxes(chol, -1, -2)

# tag as lower triangular to enable pytensor rewrites of chol(l.l') -> l
chol.tag.lower_triangular = True
cov = chol.dot(chol.T)
cov = pt.matmul(chol, pt.swapaxes(chol, -1, -2))

return cov
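
For context (not part of the diff): a minimal sketch of the batched product this hunk switches to, with made-up shapes and seed. The old `chol.dot(chol.T)` assumed a single 2-D factor; transposing only the last two axes pairs each batch member with its own transpose.

```python
import numpy as np
import pytensor.tensor as pt

# Three 2x2 lower-triangular Cholesky factors stacked along a leading batch axis.
chol_np = np.tril(np.random.default_rng(0).uniform(1.0, 2.0, size=(3, 2, 2)))
chol = pt.as_tensor_variable(chol_np)

# Transpose only the last two axes so each (2, 2) block meets its own transpose.
cov = pt.matmul(chol, pt.swapaxes(chol, -1, -2))
print(cov.eval().shape)  # (3, 2, 2)
```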


def quaddist_parse(value, mu, cov, mat_type="cov"):
def quaddist_chol(value, mu, cov):
"""Compute (x - mu).T @ Sigma^-1 @ (x - mu) and the logdet of Sigma."""
if value.ndim > 2 or value.ndim == 0:
raise ValueError("Invalid dimension for value: %s" % value.ndim)
if value.ndim == 0:
raise ValueError("Value can't be a scalar")
if value.ndim == 1:
onedim = True
value = value[None, :]
@@ -157,42 +154,21 @@ def quaddist_parse(value, mu, cov, mat_type="cov"):

delta = value - mu
chol_cov = nan_lower_cholesky(cov)
if mat_type != "tau":
dist, logdet, ok = quaddist_chol(delta, chol_cov)
else:
dist, logdet, ok = quaddist_tau(delta, chol_cov)
if onedim:
return dist[0], logdet, ok

return dist, logdet, ok


def quaddist_chol(delta, chol_mat):
diag = pt.diag(chol_mat)
diag = pt.diagonal(chol_cov, axis1=-2, axis2=-1)
# Check if the covariance matrix is positive definite.
ok = pt.all(diag > 0)
ok = pt.all(diag > 0, axis=-1)
# If not, replace the diagonal. We return -inf later, but
# need to prevent solve_lower from throwing an exception.
chol_cov = pt.switch(ok, chol_mat, 1)

delta_trans = solve_lower(chol_cov, delta.T).T
chol_cov = pt.switch(ok[..., None, None], chol_cov, 1)
delta_trans = solve_lower(chol_cov, delta, b_ndim=1)
quaddist = (delta_trans**2).sum(axis=-1)
logdet = pt.sum(pt.log(diag))
return quaddist, logdet, ok


def quaddist_tau(delta, chol_mat):
diag = pt.nlinalg.diag(chol_mat)
# Check if the precision matrix is positive definite.
ok = pt.all(diag > 0)
# If not, replace the diagonal. We return -inf later, but
# need to prevent solve_lower from throwing an exception.
chol_tau = pt.switch(ok, chol_mat, 1)
logdet = pt.log(diag).sum(axis=-1)

delta_trans = pt.dot(delta, chol_tau)
quaddist = (delta_trans**2).sum(axis=-1)
logdet = -pt.sum(pt.log(diag))
return quaddist, logdet, ok
if onedim:
return quaddist[0], logdet, ok
else:
return quaddist, logdet, ok
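
An illustrative sketch (assumed inputs, not PR code) of the per-batch positive-definiteness check above: reducing over only the last axis leaves one validity flag per batch member instead of a single scalar.

```python
import numpy as np
import pytensor.tensor as pt

# Two lower-triangular factors; the second has a non-positive diagonal entry,
# i.e. its covariance is not positive definite.
chol_cov = pt.as_tensor_variable(
    np.array(
        [
            [[1.0, 0.0], [0.5, 2.0]],
            [[1.0, 0.0], [0.5, -1.0]],
        ]
    )
)
diag = pt.diagonal(chol_cov, axis1=-2, axis2=-1)
ok = pt.all(diag > 0, axis=-1)
print(ok.eval())  # [ True False] -- one flag per batch member
```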


class MvNormal(Continuous):
@@ -262,14 +238,15 @@ class MvNormal(Continuous):
rv_op = multivariate_normal

@classmethod
def dist(cls, mu, cov=None, tau=None, chol=None, lower=True, **kwargs):
def dist(cls, mu=0, cov=None, *, tau=None, chol=None, lower=True, **kwargs):
mu = pt.as_tensor_variable(mu)
cov = quaddist_matrix(cov, chol, tau, lower)
# PyTensor is stricter about the shape of mu, than PyMC used to be
mu = pt.broadcast_arrays(mu, cov[..., -1])[0]
mu, _ = pt.broadcast_arrays(mu, cov[..., -1])
return super().dist([mu, cov], **kwargs)
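
A hedged usage sketch of what this enables (assumes a PyMC build that already includes this change; the arrays are illustrative):

```python
import numpy as np
import pymc as pm

covs = np.stack([np.eye(3), 2.0 * np.eye(3)])     # batch of two 3x3 covariances
x = pm.MvNormal.dist(mu=np.zeros(3), cov=covs)    # mu broadcasts against the batch
print(pm.draw(x).shape)                           # (2, 3): one draw per covariance
print(pm.logp(x, np.zeros((2, 3))).eval().shape)  # (2,): one logp per batch member
```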

def moment(rv, size, mu, cov):
# mu is broadcasted to the potential length of cov in `dist`
moment = mu
if not rv_size_is_none(size):
moment_size = pt.concatenate([size, [mu.shape[-1]]])
@@ -290,7 +267,7 @@ def logp(value, mu, cov):
-------
TensorVariable
"""
quaddist, logdet, ok = quaddist_parse(value, mu, cov)
quaddist, logdet, ok = quaddist_chol(value, mu, cov)
k = floatX(value.shape[-1])
norm = -0.5 * k * pm.floatX(np.log(2 * np.pi))
return check_parameters(
@@ -307,22 +284,6 @@ class MvStudentTRV(RandomVariable):
dtype = "floatX"
_print_name = ("MvStudentT", "\\operatorname{MvStudentT}")

def make_node(self, rng, size, dtype, nu, mu, cov):
nu = pt.as_tensor_variable(nu)
if not nu.ndim == 0:
raise ValueError("nu must be a scalar (ndim=0).")

return super().make_node(rng, size, dtype, nu, mu, cov)

def __call__(self, nu, mu=None, cov=None, size=None, **kwargs):
dtype = pytensor.config.floatX if self.dtype == "floatX" else self.dtype

if mu is None:
mu = np.array([0.0], dtype=dtype)
if cov is None:
cov = np.array([[1.0]], dtype=dtype)
return super().__call__(nu, mu, cov, size=size, **kwargs)

def _supp_shape_from_params(self, dist_params, param_shapes=None):
return supp_shape_from_ref_param_shape(
ndim_supp=self.ndim_supp,
@@ -333,14 +294,21 @@ def _supp_shape_from_params(self, dist_params, param_shapes=None):

@classmethod
def rng_fn(cls, rng, nu, mu, cov, size):
if size is None:
# When size is implicit, we need to broadcast parameters correctly,
# so that the MvNormal draws and the chisquare draws have the same number of batch dimensions.
# nu broadcasts mu and cov
if np.ndim(nu) > max(mu.ndim - 1, cov.ndim - 2):
_, mu, cov = broadcast_params((nu, mu, cov), ndims_params=cls.ndims_params)
# nu is broadcasted by either mu or cov
elif np.ndim(nu) < max(mu.ndim - 1, cov.ndim - 2):
nu, _, _ = broadcast_params((nu, mu, cov), ndims_params=cls.ndims_params)

mv_samples = multivariate_normal.rng_fn(rng=rng, mean=np.zeros_like(mu), cov=cov, size=size)

# Take chi2 draws and add an axis of length 1 to the right for correct broadcasting below
chi2_samples = np.sqrt(rng.chisquare(nu, size=size) / nu)[..., None]

if size:
mu = np.broadcast_to(mu, size + (mu.shape[-1],))

return (mv_samples / chi2_samples) + mu
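
A small sketch of the broadcasting rule this `rng_fn` relies on; `broadcast_params` is the pytensor helper used above, and the values below are illustrative only.

```python
import numpy as np
from pytensor.tensor.random.utils import broadcast_params

nu = np.array([3.0, 30.0])  # one degrees-of-freedom value per batch member
mu = np.zeros(4)            # shared location (core ndim 1)
cov = np.eye(4)             # shared scale (core ndim 2)

# nu carries the extra batch dimension, so mu and cov are promoted to match it.
_, mu_b, cov_b = broadcast_params((nu, mu, cov), ndims_params=(0, 1, 2))
print(mu_b.shape, cov_b.shape)  # (2, 4) (2, 4, 4)
```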


@@ -390,7 +358,7 @@ class MvStudentT(Continuous):
rv_op = mv_studentt

@classmethod
def dist(cls, nu, Sigma=None, mu=None, scale=None, tau=None, chol=None, lower=True, **kwargs):
def dist(cls, nu, *, Sigma=None, mu=0, scale=None, tau=None, chol=None, lower=True, **kwargs):
cov = kwargs.pop("cov", None)
if cov is not None:
warnings.warn(
@@ -407,11 +375,13 @@ def dist(cls, nu, Sigma=None, mu=None, scale=None, tau=None, chol=None, lower=Tr
mu = pt.as_tensor_variable(floatX(mu))
scale = quaddist_matrix(scale, chol, tau, lower)
# PyTensor is stricter about the shape of mu, than PyMC used to be
mu = pt.broadcast_arrays(mu, scale[..., -1])[0]
mu, _ = pt.broadcast_arrays(mu, scale[..., -1])

return super().dist([nu, mu, scale], **kwargs)
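
The corresponding hedged sketch for MvStudentT (again assuming a build with this change): batched `nu` and `scale` now pair up per batch member.

```python
import numpy as np
import pymc as pm

nu = np.array([3.0, 30.0])                       # batched degrees of freedom
scales = np.stack([np.eye(2), 4.0 * np.eye(2)])  # batch of two 2x2 scale matrices
x = pm.MvStudentT.dist(nu=nu, mu=np.zeros(2), scale=scales)
print(pm.draw(x).shape)  # (2, 2): one 2-vector per (nu, scale) pair
```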

def moment(rv, size, nu, mu, scale):
# mu is broadcasted to the potential length of scale in `dist`
mu, _ = pt.random.utils.broadcast_params([mu, nu], ndims_params=[1, 0])
moment = mu
if not rv_size_is_none(size):
moment_size = pt.concatenate([size, [mu.shape[-1]]])
@@ -432,7 +402,7 @@ def logp(value, nu, mu, scale):
-------
TensorVariable
"""
quaddist, logdet, ok = quaddist_parse(value, mu, scale)
quaddist, logdet, ok = quaddist_chol(value, mu, scale)
k = floatX(value.shape[-1])

norm = gammaln((nu + k) / 2.0) - gammaln(nu / 2.0) - 0.5 * k * pt.log(nu * np.pi)
27 changes: 12 additions & 15 deletions tests/distributions/test_mixture.py
@@ -333,21 +333,18 @@ def test_list_multivariate_components_deterministic_weights(self, weights, compo
assert not repetitions

# Test logp
# MvNormal logp is currently limited to 2d values
expectation = pytest.raises(ValueError) if mix_eval.ndim > 2 else does_not_raise()
with expectation:
mix_logp_eval = logp(mix, mix_eval).eval()
assert mix_logp_eval.shape == expected_shape[:-1]
bcast_weights = np.broadcast_to(weights, (*expected_shape[:-1], 2))
expected_logp = np.stack(
(
logp(components[0], mix_eval).eval(),
logp(components[1], mix_eval).eval(),
),
axis=-1,
)[bcast_weights == 1]
expected_logp = expected_logp.reshape(expected_shape[:-1])
assert np.allclose(mix_logp_eval, expected_logp)
mix_logp_eval = logp(mix, mix_eval).eval()
assert mix_logp_eval.shape == expected_shape[:-1]
bcast_weights = np.broadcast_to(weights, (*expected_shape[:-1], 2))
expected_logp = np.stack(
(
logp(components[0], mix_eval).eval(),
logp(components[1], mix_eval).eval(),
),
axis=-1,
)[bcast_weights == 1]
expected_logp = expected_logp.reshape(expected_shape[:-1])
assert np.allclose(mix_logp_eval, expected_logp)
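
To make the removed restriction concrete, a hedged sketch (illustrative shapes, assumes a build with this change) of evaluating an MvNormal logp on a value array with extra batch dimensions, which previously raised a ValueError:

```python
import numpy as np
import pymc as pm

x = pm.MvNormal.dist(mu=np.zeros(2), cov=np.eye(2))
value = np.zeros((4, 3, 2))            # value with two batch dimensions
print(pm.logp(x, value).eval().shape)  # (4, 3): no ValueError anymore
```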

def test_component_choice_random(self):
"""Test that mixture choices change over evaluations"""