From 8d279fc48655291fe23574f3b143d46b73a47fd4 Mon Sep 17 00:00:00 2001 From: Jesse Grabowski Date: Sat, 1 Mar 2025 02:15:30 +0800 Subject: [PATCH 1/3] Basic implementation --- pymc/distributions/multivariate.py | 31 +++++++++--------- pymc/distributions/transforms.py | 50 ++++++++++++++++++++++++++++++ 2 files changed, 67 insertions(+), 14 deletions(-) diff --git a/pymc/distributions/multivariate.py b/pymc/distributions/multivariate.py index 32f9e30f06..8d2badc82c 100644 --- a/pymc/distributions/multivariate.py +++ b/pymc/distributions/multivariate.py @@ -1510,7 +1510,7 @@ def helper_deterministics(cls, n, packed_chol): class LKJCorrRV(RandomVariable): name = "lkjcorr" - signature = "(),()->(n)" + signature = "(),()->(n,n)" dtype = "floatX" _print_name = ("LKJCorrRV", "\\operatorname{LKJCorrRV}") @@ -1527,8 +1527,8 @@ def make_node(self, rng, size, n, eta): def _supp_shape_from_params(self, dist_params, **kwargs): n = dist_params[0].squeeze() - dist_shape = ((n * (n - 1)) // 2,) - return dist_shape + # dist_shape = ((n * (n - 1)) // 2,) + return (n, n) @classmethod def rng_fn(cls, rng, n, eta, size): @@ -1609,23 +1609,26 @@ def logp(value, n, eta): ------- TensorVariable """ - if value.ndim > 1: - raise NotImplementedError("LKJCorr logp is only implemented for vector values (ndim=1)") - - # TODO: PyTensor does not have a `triu_indices`, so we can only work with constant - # n (or else find a different expression) + # if value.ndim > 1: + # raise NotImplementedError("LKJCorr logp is only implemented for vector values (ndim=1)") + # try: n = int(get_underlying_scalar_constant_value(n)) except NotScalarConstantError: raise NotImplementedError("logp only implemented for constant `n`") - shape = n * (n - 1) // 2 - tri_index = np.zeros((n, n), dtype="int32") - tri_index[np.triu_indices(n, k=1)] = np.arange(shape) - tri_index[np.triu_indices(n, k=1)[::-1]] = np.arange(shape) + # shape = n * (n - 1) // 2 + # tri_index = np.zeros((n, n), dtype="int32") + # tri_index[np.triu_indices(n, k=1)] = np.arange(shape) + # tri_index[np.triu_indices(n, k=1)[::-1]] = np.arange(shape) + + # value = pt.take(value, tri_index) + # value = pt.fill_diagonal(value, 1) - value = pt.take(value, tri_index) - value = pt.fill_diagonal(value, 1) + # print(n, type(n)) + # print(value.type.shape) + # value = value @ value.T + # print(value.type.shape) # TODO: _lkj_normalizing_constant currently requires `eta` and `n` to be constants try: diff --git a/pymc/distributions/transforms.py b/pymc/distributions/transforms.py index c8ca8d0554..52c2ce88fe 100644 --- a/pymc/distributions/transforms.py +++ b/pymc/distributions/transforms.py @@ -164,6 +164,56 @@ def log_jac_det(self, value, *inputs): return pt.sum(value[..., self.diag_idxs], axis=-1) +class CholeskyCorr(Transform): + """Get a Cholesky Corr from a packed vector.""" + + name = "cholesky-corr-packed" + + def __init__(self, n): + """Create a CholeskyCorrPack object. + + Parameters + ---------- + n: int + Number of diagonal entries in the LKJCholeskyCov distribution + """ + self.n = n + + def _compute_L_and_logdet(self, value, *inputs): + n = self.n + counter = 0 + L = pt.eye(n) + log_det = 0 + + for i in range(1, n): + y_star = value[counter : counter + i] + dsy = y_star.dot(y_star) + alpha_r = 1 / (dsy + 1) + gamma = pt.sqrt(dsy + 2) * alpha_r + + x = pt.join(0, gamma * y_star, pt.atleast_1d(alpha_r)) + L = L[i, : i + 1].set(x) + log_det += pt.log(2) + 0.5 * (i - 2) * pt.log(dsy + 2) - i * pt.log(1 + dsy) + + counter += i + + # Return whole matrix? Or just lower triangle? + return L, log_det + + def backward(self, value, *inputs): + L, _ = self._compute_L_and_logdet(value, *inputs) + return L + + def forward(self, value, *inputs): + # TODO: This is a placeholder + n = self.n + return pt.as_tensor_variable(np.random.normal(size=(n,))) + + def log_jac_det(self, value, *inputs): + _, log_det = self._compute_L_and_logdet(value, *inputs) + return log_det + + Chain = ChainedTransform simplex = SimplexTransform() From 3e5721c288f5624f53b11a04a9c494fa0860dbdf Mon Sep 17 00:00:00 2001 From: Jesse Grabowski Date: Sat, 1 Mar 2025 02:30:27 +0800 Subject: [PATCH 2/3] fix initial point size --- pymc/distributions/transforms.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/pymc/distributions/transforms.py b/pymc/distributions/transforms.py index 52c2ce88fe..d26f31cd0b 100644 --- a/pymc/distributions/transforms.py +++ b/pymc/distributions/transforms.py @@ -207,7 +207,8 @@ def backward(self, value, *inputs): def forward(self, value, *inputs): # TODO: This is a placeholder n = self.n - return pt.as_tensor_variable(np.random.normal(size=(n,))) + size = n * (n - 1) // 2 + return pt.as_tensor_variable(np.random.normal(size=size)) def log_jac_det(self, value, *inputs): _, log_det = self._compute_L_and_logdet(value, *inputs) From ddbc3fee074c571b20d810dd429c4d26837db0ba Mon Sep 17 00:00:00 2001 From: Jesse Grabowski Date: Sat, 1 Mar 2025 02:54:42 +0800 Subject: [PATCH 3/3] just use scan bro --- pymc/distributions/transforms.py | 32 ++++++++++++++++++++++++++++++-- 1 file changed, 30 insertions(+), 2 deletions(-) diff --git a/pymc/distributions/transforms.py b/pymc/distributions/transforms.py index d26f31cd0b..ad338bf38e 100644 --- a/pymc/distributions/transforms.py +++ b/pymc/distributions/transforms.py @@ -16,6 +16,7 @@ from functools import singledispatch import numpy as np +import pytensor import pytensor.tensor as pt @@ -179,6 +180,33 @@ def __init__(self, n): """ self.n = n + def step(self, i, counter, L, y): + y_star = y[counter : counter + i] + dsy = y_star.dot(y_star) + alpha_r = 1 / (dsy + 1) + gamma = pt.sqrt(dsy + 2) * alpha_r + + x = pt.join(0, gamma * y_star, pt.atleast_1d(alpha_r)) + next_L = L[i, : i + 1].set(x) + log_det = pt.log(2) + 0.5 * (i - 2) * pt.log(dsy + 2) - i * pt.log(1 + dsy) + + return next_L, log_det + + def _compute_L_and_logdet_scan(self, value, *inputs): + L = pt.eye(self.n) + idxs = pt.arange(1, self.n) + counters = pt.arange(0, self.n).cumsum() + + results, _ = pytensor.scan( + self.step, outputs_info=[L, None], sequences=[idxs, counters], non_sequences=[value] + ) + + L_seq, log_det_seq = results + L = L_seq[-1] + log_det = pt.sum(log_det_seq) + + return L, log_det + def _compute_L_and_logdet(self, value, *inputs): n = self.n counter = 0 @@ -201,7 +229,7 @@ def _compute_L_and_logdet(self, value, *inputs): return L, log_det def backward(self, value, *inputs): - L, _ = self._compute_L_and_logdet(value, *inputs) + L, _ = self._compute_L_and_logdet_scan(value, *inputs) return L def forward(self, value, *inputs): @@ -211,7 +239,7 @@ def forward(self, value, *inputs): return pt.as_tensor_variable(np.random.normal(size=size)) def log_jac_det(self, value, *inputs): - _, log_det = self._compute_L_and_logdet(value, *inputs) + _, log_det = self._compute_L_and_logdet_scan(value, *inputs) return log_det