Implement specialized MvNormal density based on precision matrix #7345

Merged · 1 commit · Aug 3, 2024
103 changes: 89 additions & 14 deletions pymc/distributions/multivariate.py
@@ -24,6 +24,7 @@
import pytensor.tensor as pt
import scipy

from pytensor.graph import node_rewriter
from pytensor.graph.basic import Apply, Variable
from pytensor.graph.op import Op
from pytensor.raise_op import Assert
@@ -39,7 +40,7 @@
from pytensor.tensor.exceptions import NotScalarConstantError
from pytensor.tensor.linalg import cholesky, det, eigh, solve_triangular, trace
from pytensor.tensor.linalg import inv as matrix_inverse
from pytensor.tensor.random.basic import MvNormalRV, dirichlet, multinomial, multivariate_normal
from pytensor.tensor.random.op import RandomVariable
from pytensor.tensor.random.utils import (
broadcast_params,
@@ -77,6 +78,9 @@
)
from pymc.distributions.transforms import Interval, ZeroSumTransform, _default_transform
from pymc.logprob.abstract import _logprob
from pymc.logprob.rewriting import (
specialization_ir_rewrites_db,
)
from pymc.math import kron_diag, kron_dot
from pymc.pytensorf import normalize_rng_param
from pymc.util import check_dist_not_registered
@@ -157,6 +161,13 @@ def quaddist_matrix(cov=None, chol=None, tau=None, lower=True, *args, **kwargs):
return cov


def _logdet_from_cholesky(chol: TensorVariable) -> tuple[TensorVariable, TensorVariable]:
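    # For A = L @ L.T, log|A| = 2 * sum(log(diag(L))); this helper returns the
    # log-determinant of the factor L itself, i.e. half of log|A|, which is
    # exactly the term the MvNormal logp needs. A non-positive diagonal flags
    # a failed factorization / non-positive-definite input.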
diag = pt.diagonal(chol, axis1=-2, axis2=-1)
logdet = pt.log(diag).sum(axis=-1)
posdef = pt.all(diag > 0, axis=-1)
return logdet, posdef


def quaddist_chol(value, mu, cov):
"""Compute (x - mu).T @ Sigma^-1 @ (x - mu) and the logdet of Sigma."""
if value.ndim == 0:
@@ -167,23 +178,23 @@ def quaddist_chol(value, mu, cov):
else:
onedim = False

    delta = value - mu
    chol_cov = nan_lower_cholesky(cov)
    logdet, posdef = _logdet_from_cholesky(chol_cov)

    # solve_triangular will raise if there are nans
    # (which happens if the cholesky fails).
    # If the matrix is not positive definite, replace the Cholesky factor
    # with ones; -inf is returned later via the posdef check.
    chol_cov = pt.switch(posdef[..., None, None], chol_cov, 1)

    delta_trans = solve_lower(chol_cov, delta, b_ndim=1)
    quaddist = (delta_trans**2).sum(axis=-1)

    if onedim:
        return quaddist[0], logdet, posdef
    else:
        return quaddist, logdet, posdef


class MvNormal(Continuous):
@@ -283,16 +294,80 @@ def logp(value, mu, cov):
-------
TensorVariable
"""
        quaddist, logdet, posdef = quaddist_chol(value, mu, cov)
        k = value.shape[-1].astype("floatX")
        norm = -0.5 * k * np.log(2 * np.pi)
        return check_parameters(
            norm - 0.5 * quaddist - logdet,
            posdef,
            msg="posdef covariance",
        )


class PrecisionMvNormalRV(SymbolicRandomVariable):
r"""A specialized multivariate normal random variable defined in terms of precision.

    This class is introduced during specialization logprob rewrites and is not meant to be used directly.
"""

name = "precision_multivariate_normal"
extended_signature = "[rng],[size],(n),(n,n)->(n)"
_print_name = ("PrecisionMultivariateNormal", "\\operatorname{PrecisionMultivariateNormal}")

@classmethod
def rv_op(cls, mean, tau, *, rng=None, size=None):
rng = normalize_rng_param(rng)
size = normalize_size_param(size)
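        # Sampling still goes through the covariance parametrization, so the
        # precision matrix is inverted here; only the logp avoids the explicit
        # inverse (see precision_mv_normal_logp below).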
cov = pt.linalg.inv(tau)
next_rng, draws = multivariate_normal(mean, cov, size=size, rng=rng).owner.outputs
return cls(
inputs=[rng, size, mean, tau],
outputs=[next_rng, draws],
)(rng, size, mean, tau)


@_logprob.register
def precision_mv_normal_logp(op: PrecisionMvNormalRV, value, rng, size, mean, tau, **kwargs):
[value] = value
k = value.shape[-1].astype("floatX")

delta = value - mean
quadratic_form = delta.T @ tau @ delta
logdet, posdef = _logdet_from_cholesky(nan_lower_cholesky(tau))
logp = -0.5 * (k * pt.log(2 * np.pi) + quadratic_form) + logdet

return check_parameters(
logp,
posdef,
msg="posdef precision",
)
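For reference, this computes the standard precision-form density, with the log-determinant term obtained from the Cholesky factor of tau (_logdet_from_cholesky returns log|tau| / 2):

log p(x) = -(k / 2) * log(2 * pi) + log|tau| / 2 - (x - mu)' tau (x - mu) / 2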


@node_rewriter(tracks=[MvNormalRV])
def mv_normal_to_precision_mv_normal(fgraph, node):
"""Replaces MvNormal(mu, inv(tau)) -> PrecisionMvNormal(mu, tau)

This is introduced in logprob rewrites to provide a more efficient logp for a MvNormal
that is defined by a precision matrix.

Note: This won't be introduced when calling `pm.logp` as that will dispatch directly
without triggering the logprob rewrites.
"""

rng, size, mu, cov = node.inputs
if cov.owner and cov.owner.op == matrix_inverse:
tau = cov.owner.inputs[0]
return PrecisionMvNormalRV.rv_op(mu, tau, size=size, rng=rng).owner.outputs
return None


specialization_ir_rewrites_db.register(
mv_normal_to_precision_mv_normal.__name__,
mv_normal_to_precision_mv_normal,
"basic",
)
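A minimal sketch of when this rewrite fires, mirroring the test added below: compiling a model's logp runs the logprob IR rewrites, so an MvNormal parametrized by tau (which internally builds cov = inv(tau)) is specialized, while a bare pm.logp call is not.

import numpy as np
import pymc as pm

with pm.Model() as m:
    Q = pm.Flat("Q", shape=(3, 3))  # free precision matrix
    y = pm.MvNormal("y", mu=np.zeros(3), tau=Q)  # internally builds cov = inv(Q)

# In the compiled logp graph, MvNormal(mu, inv(Q)) has been replaced by
# PrecisionMvNormalRV(mu, Q), so no MatrixInverse op remains.
logp_fn = m.compile_logp(vars=[y])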


class MvStudentTRV(RandomVariable):
name = "multivariate_studentt"
signature = "(),(n),(n,n)->(n)"
9 changes: 9 additions & 0 deletions pymc/logprob/rewriting.py
@@ -60,6 +60,7 @@
out2in,
)
from pytensor.graph.rewriting.db import (
EquilibriumDB,
LocalGroupDB,
RewriteDatabase,
RewriteDatabaseQuery,
@@ -379,6 +380,14 @@ def incsubtensor_rv_replace(fgraph, node):
measurable_ir_rewrites_db.register("subtensor_lift", local_subtensor_rv_lift, "basic")
measurable_ir_rewrites_db.register("incsubtensor_lift", incsubtensor_rv_replace, "basic")

# These rewrites are used to introduce specialized operations with better logprob graphs
specialization_ir_rewrites_db = EquilibriumDB()
specialization_ir_rewrites_db.name = "specialization_ir_rewrites_db"
logprob_rewrites_db.register(
"specialization_ir_rewrites_db", specialization_ir_rewrites_db, "basic"
)
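An EquilibriumDB applies its registered rewrites repeatedly until the graph stops changing, so specializations registered here can build on each other before the remaining logprob rewrites run.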


logprob_rewrites_db.register("post-canonicalize", optdb.query("+canonicalize"), "basic")

# Rewrites that remove IR Ops
29 changes: 29 additions & 0 deletions tests/distributions/test_multivariate.py
@@ -26,11 +26,13 @@
from pytensor import tensor as pt
from pytensor.tensor import TensorVariable
from pytensor.tensor.blockwise import Blockwise
from pytensor.tensor.nlinalg import MatrixInverse
from pytensor.tensor.random.utils import broadcast_params
from pytensor.tensor.slinalg import Cholesky

import pymc as pm

from pymc import Model
from pymc.distributions.multivariate import (
MultivariateIntervalTransform,
_LKJCholeskyCov,
@@ -2468,3 +2470,30 @@ def test_mvstudentt_mu_convenience():
x = pm.MvStudentT.dist(nu=4, mu=np.ones((10, 1, 1)), scale=np.full((2, 3, 3), np.eye(3)))
mu = x.owner.inputs[3]
np.testing.assert_allclose(mu.eval(), np.ones((10, 2, 3)))


def test_precision_mv_normal_optimization():
rng = np.random.default_rng(sum(map(ord, "be precise")))

n = 30
L = rng.uniform(low=0.1, high=1.0, size=(n, n))
Sigma_test = L @ L.T
mu_test = np.zeros(n)
Q_test = np.linalg.inv(Sigma_test)
y_test = rng.normal(size=n)

with Model() as m:
Q = pm.Flat("Q", shape=(n, n))
y = pm.MvNormal("y", mu=mu_test, tau=Q)

y_logp_fn = m.compile_logp(vars=[y]).f

# Check we don't have any MatrixInverses in the logp
assert not any(
node for node in y_logp_fn.maker.fgraph.apply_nodes if isinstance(node.op, MatrixInverse)
)

np.testing.assert_allclose(
y_logp_fn(y=y_test, Q=Q_test),
st.multivariate_normal.logpdf(y_test, mu_test, cov=Sigma_test),
)