diff --git a/pymc/distributions/multivariate.py b/pymc/distributions/multivariate.py
index 56bd7c5fe3..6513bfd5f4 100644
--- a/pymc/distributions/multivariate.py
+++ b/pymc/distributions/multivariate.py
@@ -24,6 +24,7 @@
 import pytensor.tensor as pt
 import scipy
 
+from pytensor.graph import node_rewriter
 from pytensor.graph.basic import Apply, Variable
 from pytensor.graph.op import Op
 from pytensor.raise_op import Assert
@@ -39,7 +40,7 @@
 from pytensor.tensor.exceptions import NotScalarConstantError
 from pytensor.tensor.linalg import cholesky, det, eigh, solve_triangular, trace
 from pytensor.tensor.linalg import inv as matrix_inverse
-from pytensor.tensor.random.basic import dirichlet, multinomial, multivariate_normal
+from pytensor.tensor.random.basic import MvNormalRV, dirichlet, multinomial, multivariate_normal
 from pytensor.tensor.random.op import RandomVariable
 from pytensor.tensor.random.utils import (
     broadcast_params,
@@ -77,6 +78,9 @@
 )
 from pymc.distributions.transforms import Interval, ZeroSumTransform, _default_transform
 from pymc.logprob.abstract import _logprob
+from pymc.logprob.rewriting import (
+    specialization_ir_rewrites_db,
+)
 from pymc.math import kron_diag, kron_dot
 from pymc.pytensorf import normalize_rng_param
 from pymc.util import check_dist_not_registered
@@ -157,6 +161,13 @@ def quaddist_matrix(cov=None, chol=None, tau=None, lower=True, *args, **kwargs):
     return cov
 
 
+def _logdet_from_cholesky(chol: TensorVariable) -> tuple[TensorVariable, TensorVariable]:
+    diag = pt.diagonal(chol, axis1=-2, axis2=-1)
+    logdet = pt.log(diag).sum(axis=-1)
+    posdef = pt.all(diag > 0, axis=-1)
+    return logdet, posdef
+
+
 def quaddist_chol(value, mu, cov):
     """Compute (x - mu).T @ Sigma^-1 @ (x - mu) and the logdet of Sigma."""
     if value.ndim == 0:
@@ -167,23 +178,21 @@ def quaddist_chol(value, mu, cov):
     else:
         onedim = False
 
-    delta = value - mu
     chol_cov = nan_lower_cholesky(cov)
+    logdet, posdef = _logdet_from_cholesky(chol_cov)
 
-    diag = pt.diagonal(chol_cov, axis1=-2, axis2=-1)
-    # Check if the covariance matrix is positive definite.
-    ok = pt.all(diag > 0, axis=-1)
-    # If not, replace the diagonal. We return -inf later, but
-    # need to prevent solve_lower from throwing an exception.
-    chol_cov = pt.switch(ok[..., None, None], chol_cov, 1)
+    # If the Cholesky failed (matrix not positive definite), replace it, since
+    # solve_triangular would raise on the nans; we return -inf later anyway.
+    chol_cov = pt.switch(posdef[..., None, None], chol_cov, 1)
+
+    delta = value - mu
     delta_trans = solve_lower(chol_cov, delta, b_ndim=1)
     quaddist = (delta_trans**2).sum(axis=-1)
-    logdet = pt.log(diag).sum(axis=-1)
 
     if onedim:
-        return quaddist[0], logdet, ok
+        return quaddist[0], logdet, posdef
     else:
-        return quaddist, logdet, ok
+        return quaddist, logdet, posdef
 
 
 class MvNormal(Continuous):
@@ -283,16 +292,81 @@ def logp(value, mu, cov):
         -------
         TensorVariable
         """
-        quaddist, logdet, ok = quaddist_chol(value, mu, cov)
+        quaddist, logdet, posdef = quaddist_chol(value, mu, cov)
         k = value.shape[-1].astype("floatX")
         norm = -0.5 * k * np.log(2 * np.pi)
         return check_parameters(
             norm - 0.5 * quaddist - logdet,
-            ok,
-            msg="posdef",
+            posdef,
+            msg="posdef covariance",
         )
 
 
+class PrecisionMvNormalRV(SymbolicRandomVariable):
+    r"""A specialized multivariate normal random variable defined in terms of precision.
+
+    This class is introduced by the specialization logprob rewrites and is not
+    meant to be used directly.
+    """
+
+    name = "precision_multivariate_normal"
+    extended_signature = "[rng],[size],(n),(n,n)->(n)"
+    _print_name = ("PrecisionMultivariateNormal", "\\operatorname{PrecisionMultivariateNormal}")
+
+    @classmethod
+    def rv_op(cls, mean, tau, *, rng=None, size=None):
+        rng = normalize_rng_param(rng)
+        size = normalize_size_param(size)
+        cov = pt.linalg.inv(tau)
+        next_rng, draws = multivariate_normal(mean, cov, size=size, rng=rng).owner.outputs
+        return cls(
+            inputs=[rng, size, mean, tau],
+            outputs=[next_rng, draws],
+        )(rng, size, mean, tau)
+
+
+@_logprob.register
+def precision_mv_normal_logp(op: PrecisionMvNormalRV, value, rng, size, mean, tau, **kwargs):
+    [value] = value
+    k = value.shape[-1].astype("floatX")
+
+    delta = value - mean
+    quadratic_form = delta.T @ tau @ delta
+    logdet, posdef = _logdet_from_cholesky(nan_lower_cholesky(tau))
+    logp = -0.5 * (k * pt.log(2 * np.pi) + quadratic_form) + logdet
+
+    return check_parameters(
+        logp,
+        posdef,
+        msg="posdef precision",
+    )
+
+
+@node_rewriter(tracks=[MvNormalRV])
+def mv_normal_to_precision_mv_normal(fgraph, node):
+    """Replace MvNormal(mu, inv(tau)) with PrecisionMvNormal(mu, tau).
+
+    This is introduced in the logprob rewrites to provide a more efficient logp
+    for an MvNormal that is defined by a precision matrix.
+
+    Note: This won't be introduced when calling `pm.logp`, as that dispatches
+    directly without triggering the logprob rewrites.
+    """
+
+    rng, size, mu, cov = node.inputs
+    if cov.owner and cov.owner.op == matrix_inverse:
+        tau = cov.owner.inputs[0]
+        return PrecisionMvNormalRV.rv_op(mu, tau, size=size, rng=rng).owner.outputs
+    return None
+
+
+specialization_ir_rewrites_db.register(
+    mv_normal_to_precision_mv_normal.__name__,
+    mv_normal_to_precision_mv_normal,
+    "basic",
+)
+
+
 class MvStudentTRV(RandomVariable):
     name = "multivariate_studentt"
     signature = "(),(n),(n,n)->(n)"
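For reference, the density implemented by precision_mv_normal_logp above can be checked numerically against scipy; the sketch below is illustrative and not part of the patch. It relies on the identity sum(log diag(chol(Q))) == 0.5 * log det(Q), which is why logdet enters the precision form with a plus sign while MvNormal.logp subtracts it in the covariance form.

import numpy as np
import scipy.stats as st

rng = np.random.default_rng(42)
n = 4
B = rng.normal(size=(n, n))
Q = B @ B.T + n * np.eye(n)  # a positive definite precision matrix
mu = rng.normal(size=n)
x = rng.normal(size=n)

delta = x - mu
# Half log-determinant of Q, mirroring _logdet_from_cholesky
half_logdet = np.log(np.diagonal(np.linalg.cholesky(Q))).sum()
logp = -0.5 * (n * np.log(2 * np.pi) + delta @ Q @ delta) + half_logdet

np.testing.assert_allclose(
    logp, st.multivariate_normal.logpdf(x, mu, cov=np.linalg.inv(Q))
)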
diff --git a/pymc/logprob/rewriting.py b/pymc/logprob/rewriting.py
index eb3ebd6899..0994a77e68 100644
--- a/pymc/logprob/rewriting.py
+++ b/pymc/logprob/rewriting.py
@@ -60,6 +60,7 @@
     out2in,
 )
 from pytensor.graph.rewriting.db import (
+    EquilibriumDB,
     LocalGroupDB,
     RewriteDatabase,
     RewriteDatabaseQuery,
@@ -379,6 +380,14 @@ def incsubtensor_rv_replace(fgraph, node):
 measurable_ir_rewrites_db.register("subtensor_lift", local_subtensor_rv_lift, "basic")
 measurable_ir_rewrites_db.register("incsubtensor_lift", incsubtensor_rv_replace, "basic")
 
+# These rewrites are used to introduce specialized operations with better logprob graphs
+specialization_ir_rewrites_db = EquilibriumDB()
+specialization_ir_rewrites_db.name = "specialization_ir_rewrites_db"
+logprob_rewrites_db.register(
+    "specialization_ir_rewrites_db", specialization_ir_rewrites_db, "basic"
+)
+
+
 logprob_rewrites_db.register("post-canonicalize", optdb.query("+canonicalize"), "basic")
 
 # Rewrites that remove IR Ops
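Because specialization_ir_rewrites_db is registered inside logprob_rewrites_db, the new rewrite only fires when the full logprob IR pipeline runs, e.g. when a model logp is compiled. The sketch below (illustrative, using only the public pm.logp and pytensor graph helpers) demonstrates the caveat from the rewrite's docstring: direct dispatch keeps the inv(tau) conversion in the graph. The new test at the end of this patch checks the other direction, namely that a compiled model logp is free of MatrixInverse nodes.

import numpy as np
import pymc as pm
import pytensor.tensor as pt

from pytensor.graph.basic import ancestors
from pytensor.tensor.nlinalg import MatrixInverse

tau = pt.matrix("tau")
y = pm.MvNormal.dist(mu=np.zeros(3), tau=tau)
logp = pm.logp(y, np.zeros(3))

# pm.logp dispatches on the original MvNormalRV, so cov = inv(tau) survives
assert any(
    var.owner is not None and isinstance(var.owner.op, MatrixInverse)
    for var in ancestors([logp])
)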
diff --git a/tests/distributions/test_multivariate.py b/tests/distributions/test_multivariate.py
index 2848fa2989..74199fa107 100644
--- a/tests/distributions/test_multivariate.py
+++ b/tests/distributions/test_multivariate.py
@@ -26,11 +26,13 @@
 from pytensor import tensor as pt
 from pytensor.tensor import TensorVariable
 from pytensor.tensor.blockwise import Blockwise
+from pytensor.tensor.nlinalg import MatrixInverse
 from pytensor.tensor.random.utils import broadcast_params
 from pytensor.tensor.slinalg import Cholesky
 
 import pymc as pm
 
+from pymc import Model
 from pymc.distributions.multivariate import (
     MultivariateIntervalTransform,
     _LKJCholeskyCov,
@@ -2468,3 +2470,30 @@ def test_mvstudentt_mu_convenience():
     x = pm.MvStudentT.dist(nu=4, mu=np.ones((10, 1, 1)), scale=np.full((2, 3, 3), np.eye(3)))
     mu = x.owner.inputs[3]
     np.testing.assert_allclose(mu.eval(), np.ones((10, 2, 3)))
+
+
+def test_precision_mv_normal_optimization():
+    rng = np.random.default_rng(sum(map(ord, "be precise")))
+
+    n = 30
+    L = rng.uniform(low=0.1, high=1.0, size=(n, n))
+    Sigma_test = L @ L.T
+    mu_test = np.zeros(n)
+    Q_test = np.linalg.inv(Sigma_test)
+    y_test = rng.normal(size=n)
+
+    with Model() as m:
+        Q = pm.Flat("Q", shape=(n, n))
+        y = pm.MvNormal("y", mu=mu_test, tau=Q)
+
+    y_logp_fn = m.compile_logp(vars=[y]).f
+
+    # Check that the compiled logp graph contains no MatrixInverse nodes
+    assert not any(
+        isinstance(node.op, MatrixInverse) for node in y_logp_fn.maker.fgraph.apply_nodes
+    )
+
+    np.testing.assert_allclose(
+        y_logp_fn(y=y_test, Q=Q_test),
+        st.multivariate_normal.logpdf(y_test, mu_test, cov=Sigma_test),
+    )
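One complementary sanity check worth noting (an illustrative sketch, not part of the patch): since the rewrite lives purely in the logprob pipeline and PrecisionMvNormalRV.rv_op still samples by inverting tau, random draws are unaffected by this change; the tau and cov parameterizations should produce matching samples under the same seed.

import numpy as np
import pymc as pm

Q = np.array([[2.0, 0.5], [0.5, 1.0]])
mu = np.zeros(2)

draws_tau = pm.draw(pm.MvNormal.dist(mu=mu, tau=Q), random_seed=1)
draws_cov = pm.draw(pm.MvNormal.dist(mu=mu, cov=np.linalg.inv(Q)), random_seed=1)
np.testing.assert_allclose(draws_tau, draws_cov)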