Skip to content

Commit 4eeac62

Browse files
committed
Convert _LKJCholeskyCov into a SymbolicRandomVariable
This is necessary because the distribution is in fact a "Factory distribution" that needs to be able to resize `sd_dist` dynamically in order to work correctly.
1 parent a74d193 commit 4eeac62

File tree

2 files changed

+138
-90
lines changed

2 files changed

+138
-90
lines changed

pymc/distributions/multivariate.py

Lines changed: 115 additions & 85 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
import numpy as np
2525
import scipy
2626

27+
from aeppl.logprob import _logprob
2728
from aesara.graph.basic import Apply, Constant, Variable
2829
from aesara.graph.op import Op
2930
from aesara.raise_op import Assert
@@ -32,7 +33,7 @@
3233
from aesara.tensor.nlinalg import det, eigh, matrix_inverse, trace
3334
from aesara.tensor.random.basic import dirichlet, multinomial, multivariate_normal
3435
from aesara.tensor.random.op import RandomVariable, default_supp_shape_from_params
35-
from aesara.tensor.random.utils import broadcast_params, normalize_size_param
36+
from aesara.tensor.random.utils import broadcast_params
3637
from aesara.tensor.slinalg import Cholesky, SolveTriangular
3738
from aesara.tensor.type import TensorType
3839
from scipy import linalg, stats
@@ -49,9 +50,17 @@
4950
logpow,
5051
multigammaln,
5152
)
52-
from pymc.distributions.distribution import Continuous, Discrete, moment
53+
from pymc.distributions.distribution import (
54+
Continuous,
55+
Discrete,
56+
Distribution,
57+
SymbolicRandomVariable,
58+
_moment,
59+
moment,
60+
)
5361
from pymc.distributions.logprob import ignore_logprob
5462
from pymc.distributions.shape_utils import (
63+
_change_dist_size,
5564
broadcast_dist_samples_to,
5665
change_dist_size,
5766
rv_size_is_none,
@@ -1097,12 +1106,12 @@ def _lkj_normalizing_constant(eta, n):
10971106
return result
10981107

10991108

1100-
class _LKJCholeskyCovRV(RandomVariable):
1101-
name = "_lkjcholeskycov"
1109+
class _LKJCholeskyCovBaseRV(RandomVariable):
1110+
name = "_lkjcholeskycovbase"
11021111
ndim_supp = 1
11031112
ndims_params = [0, 0, 1]
11041113
dtype = "floatX"
1105-
_print_name = ("_lkjcholeskycov", "\\operatorname{_lkjcholeskycov}")
1114+
_print_name = ("_lkjcholeskycovbase", "\\operatorname{_lkjcholeskycovbase}")
11061115

11071116
def make_node(self, rng, size, dtype, n, eta, D):
11081117
n = at.as_tensor_variable(n)
@@ -1115,35 +1124,19 @@ def make_node(self, rng, size, dtype, n, eta, D):
11151124

11161125
D = at.as_tensor_variable(D)
11171126

1118-
# We resize the sd_dist `D` automatically so that it has (size x n) independent
1119-
# draws which is what the `_LKJCholeskyCovRV.rng_fn` expects. This makes the
1120-
# random and logp methods equivalent, as the latter also assumes a unique value
1121-
# for each diagonal element.
1122-
# Since `eta` and `n` are forced to be scalars we don't need to worry about
1123-
# implied batched dimensions for the time being.
1124-
size = normalize_size_param(size)
1125-
if D.owner.op.ndim_supp == 0:
1126-
D = change_dist_size(D, at.concatenate((size, (n,))))
1127-
else:
1128-
# The support shape must be `n` but we have no way of controlling it
1129-
D = change_dist_size(D, size)
1130-
11311127
return super().make_node(rng, size, dtype, n, eta, D)
11321128

1133-
def _infer_shape(self, size, dist_params, param_shapes=None):
1129+
def _supp_shape_from_params(self, dist_params, param_shapes):
11341130
n = dist_params[0]
1135-
dist_shape = tuple(size) + ((n * (n + 1)) // 2,)
1136-
return dist_shape
1131+
return ((n * (n + 1)) // 2,)
11371132

11381133
def rng_fn(self, rng, n, eta, D, size):
11391134
# We flatten the size to make operations easier, and then rebuild it
11401135
if size is None:
1141-
flat_size = 1
1142-
else:
1143-
flat_size = np.prod(size)
1144-
1145-
C = LKJCorrRV._random_corr_matrix(rng, n, eta, flat_size)
1136+
size = D.shape[:-1]
1137+
flat_size = np.prod(size).astype(int)
11461138

1139+
C = LKJCorrRV._random_corr_matrix(rng=rng, n=n, eta=eta, flat_size=flat_size)
11471140
D = D.reshape(flat_size, n)
11481141
C *= D[..., :, np.newaxis] * D[..., np.newaxis, :]
11491142

@@ -1159,23 +1152,30 @@ def rng_fn(self, rng, n, eta, D, size):
11591152
return samples
11601153

11611154

1162-
_ljk_cholesky_cov = _LKJCholeskyCovRV()
1155+
_ljk_cholesky_cov_base = _LKJCholeskyCovBaseRV()
11631156

11641157

1165-
class _LKJCholeskyCov(Continuous):
1158+
# _LKJCholeskyCovBaseRV requires a properly shaped `D`, which means the variable can't
1159+
# be safely resized. Because of this, we add the thin SymbolicRandomVariable wrapper below.
1160+
class _LKJCholeskyCovRV(SymbolicRandomVariable):
1161+
default_output = 1
1162+
_print_name = ("_lkjcholeskycov", "\\operatorname{_lkjcholeskycov}")
1163+
1164+
def update(self, node):
1165+
return {node.inputs[0]: node.outputs[0]}
1166+
1167+
1168+
class _LKJCholeskyCov(Distribution):
11661169
r"""Underlying class for covariance matrix with LKJ distributed correlations.
11671170
See docs for LKJCholeskyCov function for more details on how to use it in models.
11681171
"""
1169-
rv_op = _ljk_cholesky_cov
11701172

1171-
def __new__(cls, name, eta, n, sd_dist, **kwargs):
1172-
check_dist_not_registered(sd_dist)
1173-
return super().__new__(cls, name, eta, n, sd_dist, **kwargs)
1173+
rv_type = _LKJCholeskyCovRV
11741174

11751175
@classmethod
1176-
def dist(cls, eta, n, sd_dist, **kwargs):
1177-
eta = at.as_tensor_variable(floatX(eta))
1176+
def dist(cls, n, eta, sd_dist, **kwargs):
11781177
n = at.as_tensor_variable(intX(n))
1178+
eta = at.as_tensor_variable(floatX(eta))
11791179

11801180
if not (
11811181
isinstance(sd_dist, Variable)
@@ -1185,75 +1185,105 @@ def dist(cls, eta, n, sd_dist, **kwargs):
11851185
):
11861186
raise TypeError("sd_dist must be a scalar or vector distribution variable")
11871187

1188+
check_dist_not_registered(sd_dist)
11881189
# sd_dist is part of the generative graph, but should be completely ignored
11891190
# by the logp graph, since the LKJ logp explicitly includes these terms.
1190-
# TODO: Things could be simplified a bit if we managed to extract the
1191-
# sd_dist prior components from the logp expression.
11921191
sd_dist = ignore_logprob(sd_dist)
1193-
11941192
return super().dist([n, eta, sd_dist], **kwargs)
11951193

1196-
def moment(rv, size, n, eta, sd_dists):
1197-
diag_idxs = (at.cumsum(at.arange(1, n + 1)) - 1).astype("int32")
1198-
moment = at.zeros_like(rv)
1199-
moment = at.set_subtensor(moment[..., diag_idxs], 1)
1200-
return moment
1194+
@classmethod
1195+
def rv_op(cls, n, eta, sd_dist, size=None):
1196+
# We resize the sd_dist automatically so that it has (size x n) independent
1197+
# draws which is what the `_LKJCholeskyCovBaseRV.rng_fn` expects. This makes the
1198+
# random and logp methods equivalent, as the latter also assumes a unique value
1199+
# for each diagonal element.
1200+
# Since `eta` and `n` are forced to be scalars we don't need to worry about
1201+
# implied batched dimensions from those for the time being.
1202+
if size is None:
1203+
size = sd_dist.shape[:-1]
1204+
shape = tuple(size) + (n,)
1205+
if sd_dist.owner.op.ndim_supp == 0:
1206+
sd_dist = change_dist_size(sd_dist, shape)
1207+
else:
1208+
# The support shape must be `n` but we have no way of controlling it
1209+
sd_dist = change_dist_size(sd_dist, shape[:-1])
12011210

1202-
def logp(value, n, eta, sd_dist):
1203-
"""
1204-
Calculate log-probability of Covariance matrix with LKJ
1205-
distributed correlations at specified value.
1211+
# Create new rng for the _lkjcholeskycov internal RV
1212+
rng = aesara.shared(np.random.default_rng())
12061213

1207-
Parameters
1208-
----------
1209-
value: numeric
1210-
Value for which log-probability is calculated.
1214+
rng_, n_, eta_, sd_dist_ = rng.type(), n.type(), eta.type(), sd_dist.type()
1215+
next_rng_, lkjcov_ = _ljk_cholesky_cov_base(n_, eta_, sd_dist_, rng=rng_).owner.outputs
12111216

1212-
Returns
1213-
-------
1214-
TensorVariable
1215-
"""
1217+
return _LKJCholeskyCovRV(
1218+
inputs=[rng_, n_, eta_, sd_dist_],
1219+
outputs=[next_rng_, lkjcov_],
1220+
ndim_supp=1,
1221+
)(rng, n, eta, sd_dist)
12161222

1217-
if value.ndim > 1:
1218-
raise ValueError("LKJCholeskyCov logp is only implemented for vector values (ndim=1)")
12191223

1220-
diag_idxs = at.cumsum(at.arange(1, n + 1)) - 1
1221-
cumsum = at.cumsum(value**2)
1222-
variance = at.zeros(at.atleast_1d(n))
1223-
variance = at.inc_subtensor(variance[0], value[0] ** 2)
1224-
variance = at.inc_subtensor(variance[1:], cumsum[diag_idxs[1:]] - cumsum[diag_idxs[:-1]])
1225-
sd_vals = at.sqrt(variance)
1224+
@_change_dist_size.register(_LKJCholeskyCovRV)
1225+
def change_LKJCholeksyCovRV_size(op, dist, new_size, expand=False):
1226+
n, eta, sd_dist = dist.owner.inputs[1:]
12261227

1227-
logp_sd = pm.logp(sd_dist, sd_vals).sum()
1228-
corr_diag = value[diag_idxs] / sd_vals
1228+
if expand:
1229+
old_size = sd_dist.shape[:-1]
1230+
new_size = tuple(new_size) + tuple(old_size)
12291231

1230-
logp_lkj = (2 * eta - 3 + n - at.arange(n)) * at.log(corr_diag)
1231-
logp_lkj = at.sum(logp_lkj)
1232+
return _LKJCholeskyCov.rv_op(n, eta, sd_dist, size=new_size)
12321233

1233-
# Compute the log det jacobian of the second transformation
1234-
# described in the docstring.
1235-
idx = at.arange(n)
1236-
det_invjac = at.log(corr_diag) - idx * at.log(sd_vals)
1237-
det_invjac = det_invjac.sum()
12381234

1239-
# TODO: _lkj_normalizing_constant currently requires `eta` and `n` to be constants
1240-
if not isinstance(n, Constant):
1241-
raise NotImplementedError("logp only implemented for constant `n`")
1242-
n = int(n.data)
1235+
@_moment.register(_LKJCholeskyCovRV)
1236+
def _LKJCholeksyCovRV_moment(op, rv, rng, n, eta, sd_dist):
1237+
diag_idxs = (at.cumsum(at.arange(1, n + 1)) - 1).astype("int32")
1238+
moment = at.zeros_like(rv)
1239+
moment = at.set_subtensor(moment[..., diag_idxs], 1)
1240+
return moment
12431241

1244-
if not isinstance(eta, Constant):
1245-
raise NotImplementedError("logp only implemented for constant `eta`")
1246-
eta = float(eta.data)
12471242

1248-
norm = _lkj_normalizing_constant(eta, n)
1243+
@_default_transform.register(_LKJCholeskyCovRV)
1244+
def _LKJCholeksyCovRV_default_transform(op, rv):
1245+
_, n, _, _ = rv.owner.inputs
1246+
return transforms.CholeskyCovPacked(n)
12491247

1250-
return norm + logp_lkj + logp_sd + det_invjac
12511248

1249+
@_logprob.register(_LKJCholeskyCovRV)
1250+
def _LKJCholeksyCovRV_logp(op, values, rng, n, eta, sd_dist, **kwargs):
1251+
(value,) = values
12521252

1253-
@_default_transform.register(_LKJCholeskyCov)
1254-
def lkjcholeskycov_default_transform(op, rv):
1255-
_, _, _, n, _, _ = rv.owner.inputs
1256-
return transforms.CholeskyCovPacked(n)
1253+
if value.ndim > 1:
1254+
raise ValueError("_LKJCholeskyCov logp is only implemented for vector values (ndim=1)")
1255+
1256+
diag_idxs = at.cumsum(at.arange(1, n + 1)) - 1
1257+
cumsum = at.cumsum(value**2)
1258+
variance = at.zeros(at.atleast_1d(n))
1259+
variance = at.inc_subtensor(variance[0], value[0] ** 2)
1260+
variance = at.inc_subtensor(variance[1:], cumsum[diag_idxs[1:]] - cumsum[diag_idxs[:-1]])
1261+
sd_vals = at.sqrt(variance)
1262+
1263+
logp_sd = pm.logp(sd_dist, sd_vals).sum()
1264+
corr_diag = value[diag_idxs] / sd_vals
1265+
1266+
logp_lkj = (2 * eta - 3 + n - at.arange(n)) * at.log(corr_diag)
1267+
logp_lkj = at.sum(logp_lkj)
1268+
1269+
# Compute the log det jacobian of the second transformation
1270+
# described in the docstring.
1271+
idx = at.arange(n)
1272+
det_invjac = at.log(corr_diag) - idx * at.log(sd_vals)
1273+
det_invjac = det_invjac.sum()
1274+
1275+
# TODO: _lkj_normalizing_constant currently requires `eta` and `n` to be constants
1276+
if not isinstance(n, Constant):
1277+
raise NotImplementedError("logp only implemented for constant `n`")
1278+
n = int(n.data)
1279+
1280+
if not isinstance(eta, Constant):
1281+
raise NotImplementedError("logp only implemented for constant `eta`")
1282+
eta = float(eta.data)
1283+
1284+
norm = _lkj_normalizing_constant(eta, n)
1285+
1286+
return norm + logp_lkj + logp_sd + det_invjac
12571287

12581288

12591289
class LKJCholeskyCov:
@@ -1462,7 +1492,7 @@ def rng_fn(cls, rng, n, eta, size):
14621492
else:
14631493
flat_size = np.prod(size)
14641494

1465-
C = cls._random_corr_matrix(rng, n, eta, flat_size)
1495+
C = cls._random_corr_matrix(rng=rng, n=n, eta=eta, flat_size=flat_size)
14661496

14671497
triu_idx = np.triu_indices(n, k=1)
14681498
samples = C[..., triu_idx[0], triu_idx[1]]

pymc/tests/distributions/test_multivariate.py

Lines changed: 23 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -38,8 +38,9 @@
3838
_OrderedMultinomial,
3939
quaddist_matrix,
4040
)
41-
from pymc.distributions.shape_utils import to_tuple
41+
from pymc.distributions.shape_utils import change_dist_size, to_tuple
4242
from pymc.math import kronecker
43+
from pymc.sampling import draw
4344
from pymc.tests.distributions.util import (
4445
BaseTestDistributionRandom,
4546
Domain,
@@ -881,6 +882,18 @@ def test_sd_dist_automatically_resized(self, sd_dist, size, shape):
881882
# LKJCov has support shape `(n * (n+1)) // 2`
882883
assert x.eval().shape == (10, 6)
883884

885+
def test_change_dist_size(self):
886+
x1 = pm.LKJCholeskyCov.dist(
887+
n=3, eta=1, sd_dist=pm.Dirichlet.dist(np.ones(3)), size=(5,), compute_corr=False
888+
)
889+
x2 = change_dist_size(x1, new_size=(10, 3), expand=False)
890+
x3 = change_dist_size(x2, new_size=(3,), expand=True)
891+
892+
draw_x1, draw_x2, draw_x3 = pm.draw([x1, x2, x3])
893+
assert draw_x1.shape == (5, 6)
894+
assert draw_x2.shape == (10, 3, 6)
895+
assert draw_x3.shape == (3, 10, 3, 6)
896+
884897

885898
# Used for MvStudentT moment test
886899
rand1d = np.random.rand(2)
@@ -1783,7 +1796,7 @@ def ref_rand(size, n, eta):
17831796

17841797
class TestLKJCholeskyCov(BaseTestDistributionRandom):
17851798
pymc_dist = _LKJCholeskyCov
1786-
pymc_dist_params = {"eta": 1.0, "n": 3, "sd_dist": pm.DiracDelta.dist([0.5, 1.0, 2.0])}
1799+
pymc_dist_params = {"n": 3, "eta": 1.0, "sd_dist": pm.DiracDelta.dist([0.5, 1.0, 2.0])}
17871800
expected_rv_op_params = {"n": 3, "eta": 1.0, "sd_dist": pm.DiracDelta.dist([0.5, 1.0, 2.0])}
17881801
size = None
17891802

@@ -1803,6 +1816,11 @@ class TestLKJCholeskyCov(BaseTestDistributionRandom):
18031816
"check_draws_match_expected",
18041817
]
18051818

1819+
def _instantiate_pymc_rv(self, dist_params=None):
1820+
# RNG cannot be passed through the PyMC class
1821+
params = dist_params if dist_params else self.pymc_dist_params
1822+
self.pymc_rv = self.pymc_dist.dist(**params, size=self.size)
1823+
18061824
def check_rv_size(self):
18071825
for size, expected in zip(self.sizes_to_check, self.sizes_expected):
18081826
sd_dist = pm.Exponential.dist(1, size=(*to_tuple(size), 3))
@@ -1813,9 +1831,9 @@ def check_rv_size(self):
18131831

18141832
def check_draws_match_expected(self):
18151833
# TODO: Find a better comparison.
1816-
rng = aesara.shared(self.get_random_state(reset=True))
1817-
x = _LKJCholeskyCov.dist(n=2, eta=10_000, sd_dist=pm.DiracDelta.dist([0.5, 2.0]), rng=rng)
1818-
assert np.all(np.abs(x.eval() - np.array([0.5, 0, 2.0])) < 0.01)
1834+
rng = self.get_random_state(reset=True)
1835+
x = _LKJCholeskyCov.dist(n=2, eta=10_000, sd_dist=pm.DiracDelta.dist([0.5, 2.0]))
1836+
assert np.all(np.abs(draw(x, random_seed=rng) - np.array([0.5, 0, 2.0])) < 0.01)
18191837

18201838

18211839
@pytest.mark.parametrize("sparse", [True, False])

0 commit comments

Comments
 (0)