From 79fa542c54a7648f824aa24bbf2d44dc43bf80fa Mon Sep 17 00:00:00 2001 From: John Cant Date: Fri, 21 Jun 2024 14:35:12 +0100 Subject: [PATCH 1/4] Port TF bijector to ensure posdef LKJCorr samples --- pymc/distributions/multivariate.py | 4 +- pymc/distributions/transforms.py | 111 +++++++++++++++++++++++++++++ 2 files changed, 114 insertions(+), 1 deletion(-) diff --git a/pymc/distributions/multivariate.py b/pymc/distributions/multivariate.py index 359b0743dd..5b1934d3e1 100644 --- a/pymc/distributions/multivariate.py +++ b/pymc/distributions/multivariate.py @@ -1579,7 +1579,9 @@ def logp(value, n, eta): @_default_transform.register(_LKJCorr) def lkjcorr_default_transform(op, rv): - return MultivariateIntervalTransform(-1.0, 1.0) + _, _, _, n, *_ = rv.owner.inputs + n = n.eval() + return transforms.CholeskyCorr(n) class LKJCorr: diff --git a/pymc/distributions/transforms.py b/pymc/distributions/transforms.py index d8998889cf..74a4bfd961 100644 --- a/pymc/distributions/transforms.py +++ b/pymc/distributions/transforms.py @@ -17,6 +17,7 @@ import numpy as np import pytensor.tensor as pt +import pytensor # ignore mypy error because it somehow considers that @@ -45,6 +46,7 @@ "log", "sum_to_1", "circular", + "CholeskyCorr", "CholeskyCovPacked", "Chain", "ZeroSumTransform", @@ -138,6 +140,115 @@ def log_jac_det(self, value, *inputs): return pt.sum(y, axis=-1) +class CholeskyCorr(Transform): + """ + Transforms the off-diagonal elements of a correlation matrix to + unconstrained real numbers. + + Note: This is not particular to the LKJ distribution - it is only a + transform to help generate cholesky decompositions for random valid + correlation matrices. 
+ + Ported from here: https://github.com/tensorflow/probability/blob/94f592af363e13391858b48f785eb4c250912904/tensorflow_probability/python/bijectors/correlation_cholesky.py#L31 + + The backward side of this transformation is the off-diagonal upper + triangular elements of a correlation matrix, specified in row major order. + """ + + name = "cholesky-corr" + + def __init__(self, n): + """ + + Parameters + ---------- + n: int + Size of correlation matrix + """ + self.n = n + self.m = int(n*(n-1)/2) # number of off-diagonal elements + self.tril_r_idxs, self.tril_c_idxs = self._generate_tril_indices() + self.triu_r_idxs, self.triu_c_idxs = self._generate_triu_indices() + + def _generate_tril_indices(self): + row_indices, col_indices = np.tril_indices(self.n, -1) + return ( + pytensor.shared(row_indices), + pytensor.shared(col_indices) + ) + + def _generate_triu_indices(self): + row_indices, col_indices = np.triu_indices(self.n, 1) + return ( + pytensor.shared(row_indices), + pytensor.shared(col_indices) + ) + + def _jacobian(self, value, *inputs): + return pt.jacobian( + self.backward(value), + wrt=value + ) + + def log_jac_det(self, value, *inputs): + """ + Compute log of the determinant of the jacobian. + + There are no clever tricks here - we literally compute the jacobian + then compute its determinant then take log. + """ + jac = self._jacobian(value) + return pt.log(pt.linalg.det(jac)) + + def forward(self, value, *inputs): + """ + Convert the off-diagonal elements of a cholesky decomposition of a + correlation matrix to unconstrained real numbers. + """ + # The correlation matrix is specified via its upper triangular elements + corr = pt.set_subtensor( + pt.zeros((self.n, self.n))[self.triu_r_idxs, self.triu_c_idxs], + value + ) + corr = corr + corr.T + pt.eye(self.n) + + chol = pt.linalg.cholesky(corr) + + # Are the diagonals always guaranteed to be positive? 
+ # I don't know, so we'll use abs + row_norms = 1/pt.abs(pt.diag(chol)) + + # Multiply by the row norms to undo the normalization + unconstrained = chol*row_norms[:, pt.newaxis] + + return unconstrained[self.tril_r_idxs, self.tril_c_idxs] + + def backward(self, value, *inputs, foo=False): + """ + Convert unconstrained real numbers to the off-diagonal elements of the + cholesky decomposition of a correlation matrix. + """ + # The diagonals of this matrix are 1, but these ones are just used for + # computing a denominator. The diagonals of the cholesky factor are not + # returned, but they are not ones. + chol_pre_norm = pt.set_subtensor( + pt.eye(self.n).astype("floatX")[self.tril_r_idxs, self.tril_c_idxs], + value + ) + + # derivative of pt.linalg.norm ended up complex, which caused errors +# row_norm = pt.abs(pt.linalg.norm(chol_pre_norm, axis=1))[:, pt.newaxis].astype("floatX") + + row_norm = pt.pow(pt.abs(pt.pow(chol_pre_norm, 2).sum(1)), 0.5) + chol = chol_pre_norm / row_norm[:, pt.newaxis] + + # Undo the cholesky decomposition + corr = pt.matmul(chol, chol.T) + + # We want the upper triangular indices here. + return corr[self.triu_r_idxs, self.triu_c_idxs] + + class CholeskyCovPacked(Transform): """ Transforms the diagonal elements of the LKJCholeskyCov distribution to be on the From dcd3a8d94dd268a40af5c5fe7e3b07d1ce55ba82 Mon Sep 17 00:00:00 2001 From: Thomas Wiecki Date: Sat, 14 Sep 2024 17:07:12 +0100 Subject: [PATCH 2/4] Use GPT o1 to finish PR. 
--- pymc/distributions/multivariate.py | 4 +- pymc/distributions/transforms.py | 150 ++++++++++++++------------ tests/distributions/test_transform.py | 136 +++++++++++++++++++++++ 3 files changed, 221 insertions(+), 69 deletions(-) diff --git a/pymc/distributions/multivariate.py b/pymc/distributions/multivariate.py index 5b1934d3e1..bef485f1bc 100644 --- a/pymc/distributions/multivariate.py +++ b/pymc/distributions/multivariate.py @@ -1580,8 +1580,8 @@ def logp(value, n, eta): @_default_transform.register(_LKJCorr) def lkjcorr_default_transform(op, rv): _, _, _, n, *_ = rv.owner.inputs - n = n.eval() - return transforms.CholeskyCorr(n) + n = pt.get_scalar_constant_value(n) # Safely extract scalar value without eval + return CholeskyCorr(n) class LKJCorr: diff --git a/pymc/distributions/transforms.py b/pymc/distributions/transforms.py index 74a4bfd961..7856789634 100644 --- a/pymc/distributions/transforms.py +++ b/pymc/distributions/transforms.py @@ -142,111 +142,127 @@ def log_jac_det(self, value, *inputs): class CholeskyCorr(Transform): """ - Transforms the off-diagonal elements of a correlation matrix to - unconstrained real numbers. + Transforms unconstrained real numbers to the off-diagonal elements of + a Cholesky decomposition of a correlation matrix. - Note: This is not particular to the LKJ distribution - it is only a - transform to help generate cholesky decompositions for random valid - correlation matrices. + This ensures that the resulting correlation matrix is positive definite. - Ported from here: https://github.com/tensorflow/probability/blob/94f592af363e13391858b48f785eb4c250912904/tensorflow_probability/python/bijectors/correlation_cholesky.py#L31 + #### Mathematical Details - The backward side of this transformation is the off-diagonal upper - triangular elements of a correlation matrix, specified in row major order. + [Include detailed mathematical explanations similar to the original TFP bijector.] 
+ + #### Examples + + ```python + transform = CholeskyCorr(n=3) + x = pt.as_tensor_variable([0.0, 0.0, 0.0]) + y = transform.forward(x).eval() + # y will be the off-diagonal elements of the Cholesky factor + + x_reconstructed = transform.backward(y).eval() + # x_reconstructed should closely match the original x + ``` + + #### References + - [Stan Manual. Section 24.2. Cholesky LKJ Correlation Distribution.](https://mc-stan.org/docs/2_18/functions-reference/cholesky-lkj-correlation-distribution.html) + - Lewandowski, D., Kurowicka, D., & Joe, H. (2009). "Generating random correlation matrices based on vines and extended onion method." *Journal of Multivariate Analysis, 100*(5), 1989-2001. """ name = "cholesky-corr" - def __init__(self, n): + def __init__(self, n, validate_args=False): """ + Initialize the CholeskyCorr transform. Parameters ---------- - n: int - Size of correlation matrix + n : int + Size of the correlation matrix. + validate_args : bool, default False + Whether to validate input arguments. """ self.n = n - self.m = int(n*(n-1)/2) # number of off-diagonal elements + self.m = int(n * (n - 1) / 2) # Number of off-diagonal elements self.tril_r_idxs, self.tril_c_idxs = self._generate_tril_indices() self.triu_r_idxs, self.triu_c_idxs = self._generate_triu_indices() + super().__init__(validate_args=validate_args) def _generate_tril_indices(self): row_indices, col_indices = np.tril_indices(self.n, -1) - return ( - pytensor.shared(row_indices), - pytensor.shared(col_indices) - ) + return (row_indices, col_indices) def _generate_triu_indices(self): row_indices, col_indices = np.triu_indices(self.n, 1) - return ( - pytensor.shared(row_indices), - pytensor.shared(col_indices) - ) - - def _jacobian(self, value, *inputs): - return pt.jacobian( - self.backward(value), - wrt=value - ) + return (row_indices, col_indices) - def log_jac_det(self, value, *inputs): + def forward(self, x, *inputs): """ - Compute log of the determinant of the jacobian. 
+ Forward transform: Unconstrained real numbers to Cholesky factors. - There are no clever tricks here - we literally compute the jacobian - then compute its determinant then take log. - """ - jac = self._jacobian(value) - return pt.log(pt.linalg.det(jac)) + Parameters + ---------- + x : tensor + Unconstrained real numbers. - def forward(self, value, *inputs): + Returns + ------- + tensor + Transformed Cholesky factors. """ - Convert the off-diagonal elements of a cholesky decomposition of a - correlation matrix to unconstrained real numbers. - """ - # The correlation matrix is specified via its upper triangular elements - corr = pt.set_subtensor( - pt.zeros((self.n, self.n))[self.triu_r_idxs, self.triu_c_idxs], - value + # Initialize a zero matrix + chol = pt.zeros((self.n, self.n), dtype=x.dtype) + + # Assign the unconstrained values to the lower triangular part + chol = pt.set_subtensor( + chol[self.tril_r_idxs, self.tril_c_idxs], + x ) - corr = corr + corr.T + pt.eye(self.n) - chol = pt.linalg.cholesky(corr) + # Normalize each row to have unit L2 norm + row_norms = pt.sqrt(pt.sum(chol ** 2, axis=1, keepdims=True)) + chol = chol / row_norms - # Are the diagonals always guaranteed to be positive? - # I don't know, so we'll use abs - row_norms = 1/pt.abs(pt.diag(chol)) + return chol[self.tril_r_idxs, self.tril_c_idxs] - # Multiply by the row norms to undo the normalization - unconstrained = chol*row_norms[:, pt.newaxis] + def backward(self, y, *inputs): + """ + Backward transform: Cholesky factors to unconstrained real numbers. - return unconstrained[self.tril_r_idxs, self.tril_c_idxs] + Parameters + ---------- + y : tensor + Cholesky factors. - def backward(self, value, *inputs, foo=False): - """ - Convert unconstrained real numbers to the off-diagonal elements of the - cholesky decomposition of a correlation matrix. + Returns + ------- + tensor + Unconstrained real numbers. 
""" - # The diagonals of this matrix are 1, but these ones are just used for - # computing a denominator. The diagonals of the cholesky factor are not - # returned, but they are not ones. - chol_pre_norm = pt.set_subtensor( - pt.eye(self.n).astype("floatX")[self.tril_r_idxs, self.tril_c_idxs], - value + # Reconstruct the full Cholesky matrix + chol = pt.zeros((self.n, self.n), dtype=y.dtype) + chol = pt.set_subtensor( + chol[self.triu_r_idxs, self.triu_c_idxs], + y ) + chol = chol + pt.transpose(chol) + pt.eye(self.n, dtype=y.dtype) + + # Perform Cholesky decomposition + chol = pt.linalg.cholesky(chol) - # derivative of pt.linalg.norm ended up complex, which caused errors -# row_norm = pt.abs(pt.linalg.norm(chol_pre_norm, axis=1))[:, pt.newaxis].astype("floatX") + # Extract the unconstrained parameters by normalizing + row_norms = pt.sqrt(pt.sum(chol ** 2, axis=1)) + unconstrained = chol / row_norms[:, None] - row_norm = pt.pow(pt.abs(pt.pow(chol_pre_norm, 2).sum(1)), 0.5) - chol = chol_pre_norm / row_norm[:, pt.newaxis] + return unconstrained[self.tril_r_idxs, self.tril_c_idxs] - # Undo the cholesky decomposition - corr = pt.matmul(chol, chol.T) + def log_jac_det(self, y, *inputs): + """ + Compute the log determinant of the Jacobian. - # We want the upper triangular indices here. - return corr[self.triu_r_idxs, self.triu_c_idxs] + The Jacobian determinant for normalization is the product of row norms. 
+ """ + row_norms = pt.sqrt(pt.sum(y ** 2, axis=1)) + return -pt.sum(pt.log(row_norms), axis=-1) class CholeskyCovPacked(Transform): diff --git a/tests/distributions/test_transform.py b/tests/distributions/test_transform.py index 8d464f206a..e2d141d9f1 100644 --- a/tests/distributions/test_transform.py +++ b/tests/distributions/test_transform.py @@ -23,6 +23,7 @@ import pymc as pm import pymc.distributions.transforms as tr +from pymc.distributions.transforms import CholeskyCorr from pymc.logprob.basic import transformed_conditional_logp from pymc.logprob.transforms import Transform @@ -673,3 +674,138 @@ def test_deprecated_ndim_supp_transforms(): with pytest.warns(FutureWarning, match="deprecated"): assert tr.multivariate_sum_to_1 == tr.sum_to_1 + + +def test_lkjcorr_transform_round_trip(): + """ + Test that applying the forward transform followed by the backward transform + retrieves the original unconstrained parameters, and that sampled matrices are positive definite. + """ + with pm.Model() as model: + rho = pm.LKJCorr("rho", n=3, eta=2) + + trace = pm.sample(100, tune=100, chains=1, cores=1, progressbar=False, return_inferencedata=False) + + # Extract the sampled correlation matrices + rho_samples = trace["rho"] + num_samples = rho_samples.shape[0] + + for i in range(num_samples): + sample_matrix = rho_samples[i] + + # Check if the sampled matrix is positive definite + try: + np.linalg.cholesky(sample_matrix) + except np.linalg.LinAlgError: + pytest.fail(f"Sampled correlation matrix at index {i} is not positive definite.") + + # Perform round-trip transform: forward and then backward + transform = CholeskyCorr(n=3) + unconstrained = transform.forward(pt.as_tensor_variable(sample_matrix)).eval() + reconstructed = transform.backward(unconstrained).eval() + + # Assert that the original and reconstructed unconstrained parameters are close + assert_allclose(sample_matrix, reconstructed, atol=1e-6) + + +def test_lkjcorr_log_jac_det(): + """ + Verify that the 
computed log determinant of the Jacobian matches the expected closed-form solution. + """ + n = 3 + transform = CholeskyCorr(n=n) + + # Create a sample unconstrained vector (all zeros for simplicity) + x = np.zeros(int(n * (n - 1) / 2), dtype=pytensor.config.floatX) + x_tensor = pt.as_tensor_variable(x) + + # Perform forward transform to obtain Cholesky factors + y = transform.forward(x_tensor).eval() + + # Compute the log determinant using the transform's method + computed_log_jac_det = transform.log_jac_det(y).eval() + + # Expected log determinant: 0 (since row norms are 1) + expected_log_jac_det = 0.0 + + assert_allclose(computed_log_jac_det, expected_log_jac_det, atol=1e-6) + + +@pytest.mark.parametrize("n", [2, 4, 5]) +def test_lkjcorr_transform_various_sizes(n): + """ + Test the CholeskyCorr transform with various sizes of correlation matrices. + """ + transform = CholeskyCorr(n=n) + unconstrained_size = int(n * (n - 1) / 2) + + # Generate random unconstrained real numbers + x = np.random.randn(unconstrained_size).astype(pytensor.config.floatX) + x_tensor = pt.as_tensor_variable(x) + + # Perform forward transform + y = transform.forward(x_tensor).eval() + + # Perform backward transform + reconstructed = transform.backward(y).eval() + + # Assert that the original and reconstructed unconstrained parameters are close + assert_allclose(x, reconstructed, atol=1e-6) + + +def test_lkjcorr_invalid_n(): + """ + Test that initializing CholeskyCorr with invalid 'n' values raises appropriate errors. + """ + with pytest.raises(ValueError): + # 'n' must be an integer greater than 1 + CholeskyCorr(n=1) + + with pytest.raises(TypeError): + # 'n' must be an integer + CholeskyCorr(n='three') + + +def test_lkjcorr_positive_definite(): + """ + Ensure that all sampled correlation matrices are positive definite. 
+ """ + with pm.Model() as model: + rho = pm.LKJCorr("rho", n=4, eta=2) + + trace = pm.sample(100, tune=100, chains=1, cores=1, progressbar=False, return_inferencedata=False) + + # Extract the sampled correlation matrices + rho_samples = trace["rho"] + num_samples = rho_samples.shape[0] + + for i in range(num_samples): + sample_matrix = rho_samples[i] + + # Check if the sampled matrix is positive definite + try: + np.linalg.cholesky(sample_matrix) + except np.linalg.LinAlgError: + pytest.fail(f"Sampled correlation matrix at index {i} is not positive definite.") + + +def test_lkjcorr_round_trip_various_sizes(): + """ + Perform round-trip transformation tests for various sizes of correlation matrices. + """ + for n in [2, 3, 4]: + transform = CholeskyCorr(n=n) + unconstrained_size = int(n * (n - 1) / 2) + + # Generate random unconstrained real numbers + x = np.random.randn(unconstrained_size).astype(pytensor.config.floatX) + x_tensor = pt.as_tensor_variable(x) + + # Perform forward transform + y = transform.forward(x_tensor).eval() + + # Perform backward transform + reconstructed = transform.backward(y).eval() + + # Assert that the original and reconstructed unconstrained parameters are close + assert_allclose(x, reconstructed, atol=1e-6) \ No newline at end of file From 408adba8cbbefcd6bfcea611c8fb823810fd083e Mon Sep 17 00:00:00 2001 From: Thomas Wiecki Date: Sat, 14 Sep 2024 17:17:30 +0100 Subject: [PATCH 3/4] Linter fixes. 
--- pymc/distributions/multivariate.py | 7 ++++++- pymc/distributions/transforms.py | 17 +++++------------ tests/distributions/test_transform.py | 14 +++++++++----- 3 files changed, 20 insertions(+), 18 deletions(-) diff --git a/pymc/distributions/multivariate.py b/pymc/distributions/multivariate.py index bef485f1bc..686b063cb9 100644 --- a/pymc/distributions/multivariate.py +++ b/pymc/distributions/multivariate.py @@ -69,7 +69,12 @@ rv_size_is_none, to_tuple, ) -from pymc.distributions.transforms import Interval, ZeroSumTransform, _default_transform +from pymc.distributions.transforms import ( + CholeskyCorr, + Interval, + ZeroSumTransform, + _default_transform, +) from pymc.logprob.abstract import _logprob from pymc.math import kron_diag, kron_dot from pymc.pytensorf import normalize_rng_param diff --git a/pymc/distributions/transforms.py b/pymc/distributions/transforms.py index 7856789634..51e8bd5dc4 100644 --- a/pymc/distributions/transforms.py +++ b/pymc/distributions/transforms.py @@ -17,7 +17,6 @@ import numpy as np import pytensor.tensor as pt -import pytensor # ignore mypy error because it somehow considers that @@ -213,13 +212,10 @@ def forward(self, x, *inputs): chol = pt.zeros((self.n, self.n), dtype=x.dtype) # Assign the unconstrained values to the lower triangular part - chol = pt.set_subtensor( - chol[self.tril_r_idxs, self.tril_c_idxs], - x - ) + chol = pt.set_subtensor(chol[self.tril_r_idxs, self.tril_c_idxs], x) # Normalize each row to have unit L2 norm - row_norms = pt.sqrt(pt.sum(chol ** 2, axis=1, keepdims=True)) + row_norms = pt.sqrt(pt.sum(chol**2, axis=1, keepdims=True)) chol = chol / row_norms return chol[self.tril_r_idxs, self.tril_c_idxs] @@ -240,17 +236,14 @@ def backward(self, y, *inputs): """ # Reconstruct the full Cholesky matrix chol = pt.zeros((self.n, self.n), dtype=y.dtype) - chol = pt.set_subtensor( - chol[self.triu_r_idxs, self.triu_c_idxs], - y - ) + chol = pt.set_subtensor(chol[self.triu_r_idxs, self.triu_c_idxs], y) chol = 
chol + pt.transpose(chol) + pt.eye(self.n, dtype=y.dtype) # Perform Cholesky decomposition chol = pt.linalg.cholesky(chol) # Extract the unconstrained parameters by normalizing - row_norms = pt.sqrt(pt.sum(chol ** 2, axis=1)) + row_norms = pt.sqrt(pt.sum(chol**2, axis=1)) unconstrained = chol / row_norms[:, None] return unconstrained[self.tril_r_idxs, self.tril_c_idxs] @@ -261,7 +254,7 @@ def log_jac_det(self, y, *inputs): The Jacobian determinant for normalization is the product of row norms. """ - row_norms = pt.sqrt(pt.sum(y ** 2, axis=1)) + row_norms = pt.sqrt(pt.sum(y**2, axis=1)) return -pt.sum(pt.log(row_norms), axis=-1) diff --git a/tests/distributions/test_transform.py b/tests/distributions/test_transform.py index e2d141d9f1..f1c84aabe4 100644 --- a/tests/distributions/test_transform.py +++ b/tests/distributions/test_transform.py @@ -23,8 +23,8 @@ import pymc as pm import pymc.distributions.transforms as tr -from pymc.distributions.transforms import CholeskyCorr +from pymc.distributions.transforms import CholeskyCorr from pymc.logprob.basic import transformed_conditional_logp from pymc.logprob.transforms import Transform from pymc.pytensorf import floatX, jacobian @@ -684,7 +684,9 @@ def test_lkjcorr_transform_round_trip(): with pm.Model() as model: rho = pm.LKJCorr("rho", n=3, eta=2) - trace = pm.sample(100, tune=100, chains=1, cores=1, progressbar=False, return_inferencedata=False) + trace = pm.sample( + 100, tune=100, chains=1, cores=1, progressbar=False, return_inferencedata=False + ) # Extract the sampled correlation matrices rho_samples = trace["rho"] @@ -763,7 +765,7 @@ def test_lkjcorr_invalid_n(): with pytest.raises(TypeError): # 'n' must be an integer - CholeskyCorr(n='three') + CholeskyCorr(n="three") def test_lkjcorr_positive_definite(): @@ -773,7 +775,9 @@ def test_lkjcorr_positive_definite(): with pm.Model() as model: rho = pm.LKJCorr("rho", n=4, eta=2) - trace = pm.sample(100, tune=100, chains=1, cores=1, progressbar=False, 
return_inferencedata=False) + trace = pm.sample( + 100, tune=100, chains=1, cores=1, progressbar=False, return_inferencedata=False + ) # Extract the sampled correlation matrices rho_samples = trace["rho"] @@ -808,4 +812,4 @@ def test_lkjcorr_round_trip_various_sizes(): reconstructed = transform.backward(y).eval() # Assert that the original and reconstructed unconstrained parameters are close - assert_allclose(x, reconstructed, atol=1e-6) \ No newline at end of file + assert_allclose(x, reconstructed, atol=1e-6) From df723bc0a9c165babf0d0b773246e9b6dbb4fc80 Mon Sep 17 00:00:00 2001 From: Thomas Wiecki Date: Sun, 15 Sep 2024 08:48:18 +0100 Subject: [PATCH 4/4] Update doc string. Ask o1-mini to improve test. --- pymc/distributions/transforms.py | 52 ++++++++++++++++++++++++++- tests/distributions/test_transform.py | 27 ++++++++++---- 2 files changed, 71 insertions(+), 8 deletions(-) diff --git a/pymc/distributions/transforms.py b/pymc/distributions/transforms.py index 51e8bd5dc4..fe6010b710 100644 --- a/pymc/distributions/transforms.py +++ b/pymc/distributions/transforms.py @@ -148,7 +148,57 @@ class CholeskyCorr(Transform): #### Mathematical Details - [Include detailed mathematical explanations similar to the original TFP bijector.] + This bijector provides a change of variables from unconstrained reals to a + parameterization of the CholeskyLKJ distribution. The CholeskyLKJ distribution + [1] is a distribution on the set of Cholesky factors of positive definite + correlation matrices. The CholeskyLKJ probability density function is + obtained from the LKJ density on n x n matrices as follows: + + 1 = int p(A | eta) dA + = int Z(eta) * det(A) ** (eta - 1) dA + = int Z(eta) L_ii ** {(n - i - 1) + 2 * (eta - 1)} ^dL_ij (0 <= i < j < n) + + where Z(eta) is the normalizer; the matrix L is the Cholesky factor of the + correlation matrix A; and ^dL_ij denotes the wedge product (or differential) + of the strictly lower triangular entries of L. 
The entries L_ij are + constrained such that each entry lies in [-1, 1] and the norm of each row is + 1. The norm includes the diagonal; which is not included in the wedge product. + To preserve uniqueness, we further specify that the diagonal entries are + positive. + + The image of unconstrained reals under the `CholeskyCorr` transform is + the set of correlation matrices which are positive definite. A [correlation + matrix](https://en.wikipedia.org/wiki/Correlation_and_dependence#Correlation_matrices) + can be characterized as a symmetric positive semidefinite matrix with 1s on + the main diagonal. + + For a lower triangular matrix `L` to be a valid Cholesky-factor of a positive + definite correlation matrix, it is necessary and sufficient that each row of + `L` have unit Euclidean norm [1]. To see this, observe that if `L_i` is the + `i`th row of the Cholesky factor corresponding to the correlation matrix `R`, + then the `i`th diagonal entry of `R` satisfies: + + 1 = R_i,i = L_i . L_i = ||L_i||^2 + + where '.' is the dot product of vectors and `||...||` denotes the Euclidean + norm. + + Furthermore, observe that `R_i,j` lies in the interval `[-1, 1]`. By the + Cauchy-Schwarz inequality: + + |R_i,j| = |L_i . L_j| <= ||L_i|| ||L_j|| = 1 + + This is a consequence of the fact that `R` is symmetric positive definite with + 1s on the main diagonal. + + We choose the mapping from x in `R^{m}` to `R^{n^2}` where `m` is the + `(n - 1)`th triangular number; i.e. `m = 1 + 2 + ... + (n - 1)`. + + L_ij = x_i,j / s_i (for j < i) + L_ii = 1 / s_i + + where s_i = sqrt(1 + x_i,0^2 + x_i,1^2 + ... + x_(i,i-1)^2). We can check that + the required constraints on the image are satisfied. 
#### Examples diff --git a/tests/distributions/test_transform.py b/tests/distributions/test_transform.py index f1c84aabe4..617d2bf134 100644 --- a/tests/distributions/test_transform.py +++ b/tests/distributions/test_transform.py @@ -712,25 +712,38 @@ def test_lkjcorr_transform_round_trip(): def test_lkjcorr_log_jac_det(): """ - Verify that the computed log determinant of the Jacobian matches the expected closed-form solution. + Verify that the computed log determinant of the Jacobian matches the expected value + obtained from PyTensor's automatic differentiation with a non-trivial input. """ n = 3 transform = CholeskyCorr(n=n) - # Create a sample unconstrained vector (all zeros for simplicity) - x = np.zeros(int(n * (n - 1) / 2), dtype=pytensor.config.floatX) + # Create a non-trivial sample unconstrained vector + x = np.random.randn(int(n * (n - 1) / 2)).astype(pytensor.config.floatX) x_tensor = pt.as_tensor_variable(x) # Perform forward transform to obtain Cholesky factors - y = transform.forward(x_tensor).eval() + y = transform.forward(x_tensor) # Compute the log determinant using the transform's method computed_log_jac_det = transform.log_jac_det(y).eval() - # Expected log determinant: 0 (since row norms are 1) - expected_log_jac_det = 0.0 + # Define the backward function + backward = transform.backward + + # Compute the Jacobian matrix using PyTensor's automatic differentiation + backward_transformed = backward(y) + jacobian_matrix = pt.jacobian(backward_transformed, y) + + # Compile the function to compute the Jacobian matrix + jacobian_func = pytensor.function([], jacobian_matrix) + jacobian_val = jacobian_func() + + # Compute the log determinant of the Jacobian matrix + actual_log_jac_det = np.log(np.abs(np.linalg.det(jacobian_val))) - assert_allclose(computed_log_jac_det, expected_log_jac_det, atol=1e-6) + # Compare the two + assert_allclose(computed_log_jac_det, actual_log_jac_det, atol=1e-6) @pytest.mark.parametrize("n", [2, 4, 5])