diff --git a/pymc_experimental/tests/test_pivoted_cholesky.py b/pymc_experimental/tests/test_pivoted_cholesky.py
new file mode 100644
index 000000000..88ad75ce5
--- /dev/null
+++ b/pymc_experimental/tests/test_pivoted_cholesky.py
@@ -0,0 +1,24 @@
+# try:
+#     import gpytorch
+#     import torch
+# except ImportError as e:
+#     # print(
+#     #     f"Please install Pytorch and GPyTorch to use this pivoted Cholesky implementation. Error {e}"
+#     # )
+#     pass
+# import numpy as np
+#
+# import pymc_experimental as pmx
+#
+#
+# def test_match_gpytorch_linearcg_output():
+#     N = 10
+#     rank = 5
+#     np.random.seed(1234)  # nans with seed 1234
+#     K = np.random.randn(N, N)
+#     K = K @ K.T + N * np.eye(N)
+#     K_torch = torch.from_numpy(K)
+#
+#     L_gpt = gpytorch.pivoted_cholesky(K_torch, rank=rank, error_tol=1e-3)
+#     L_np, _ = pmx.utils.pivoted_cholesky(K, max_iter=rank, error_tol=1e-3)
+#     assert np.allclose(L_gpt, L_np.T)
diff --git a/pymc_experimental/utils/__init__.py b/pymc_experimental/utils/__init__.py
index ec35cb8d4..db751aa2a 100644
--- a/pymc_experimental/utils/__init__.py
+++ b/pymc_experimental/utils/__init__.py
@@ -15,3 +15,5 @@
 
 from pymc_experimental.utils import prior, spline
 from pymc_experimental.utils.linear_cg import linear_cg
+
+# from pymc_experimental.utils.pivoted_cholesky import pivoted_cholesky
diff --git a/pymc_experimental/utils/pivoted_cholesky.py b/pymc_experimental/utils/pivoted_cholesky.py
new file mode 100644
index 000000000..69ea9cd7b
--- /dev/null
+++ b/pymc_experimental/utils/pivoted_cholesky.py
@@ -0,0 +1,66 @@
+try:
+    import torch
+    from gpytorch.utils.permutation import apply_permutation
+except ImportError as e:
+    raise ImportError("PyTorch and GPyTorch not found.") from e
+
+import numpy as np
+
+pp = lambda x: np.array2string(x, precision=4, floatmode="fixed")
+
+
+def pivoted_cholesky(mat: np.matrix, error_tol=1e-6, max_iter=np.inf):
+    """
+    Pivoted (partial) Cholesky of an N x N matrix, replicating GPyTorch's implementation.
+
+    Returns (L, pi): L holds up to max_iter factor rows; pi is the pivot permutation.
+    """
+    n = mat.shape[-1]
+    max_iter = n if max_iter > n else int(max_iter)  # guard: int(np.inf) raises OverflowError
+
+    d = np.array(np.diag(mat))
+    orig_error = np.max(d)
+    error = np.linalg.norm(d, 1) / orig_error
+    pi = np.arange(n)
+
+    L = np.zeros((max_iter, n))
+
+    m = 0
+    while m < max_iter and error > error_tol:
+        permuted_d = d[pi]
+        max_diag_idx = np.argmax(permuted_d[m:])
+        max_diag_idx = max_diag_idx + m
+        max_diag_val = permuted_d[max_diag_idx]
+        i = max_diag_idx
+
+        # swap pi_m and pi_i
+        pi[m], pi[i] = pi[i], pi[m]
+        pim = pi[m]
+
+        L[m, pim] = np.sqrt(max_diag_val)
+
+        if m + 1 < n:
+            row = apply_permutation(
+                torch.from_numpy(mat), torch.tensor(pim), right_permutation=None
+            )  # left permutation just swaps row
+            row = row.numpy().flatten()
+            pi_i = pi[m + 1 :]
+            L_m_new = row[pi_i]  # remainder of pivot row pim, in permuted order
+
+            if m > 0:
+                L_prev = L[:m, pi_i]
+                update = L[:m, pim]
+                prod = update @ L_prev
+                L_m_new = L_m_new - prod  # subtract projections onto previous pivots
+
+            L_m = L[m, :]
+            L_m_new = L_m_new / L_m[pim]
+            L_m[pi_i] = L_m_new
+
+            matrix_diag_current = d[pi_i]
+            d[pi_i] = matrix_diag_current - L_m_new**2
+
+        # L_m is a view of L[m, :], so the updates above are already stored in L
+        error = np.linalg.norm(d[pi[m + 1 :]], 1) / orig_error  # empty slice -> 0.0 when m + 1 == n
+        m = m + 1
+    return L, pi