From 59eb6b6e3ce19837c9c07f6b1a842432171ca7fe Mon Sep 17 00:00:00 2001 From: Ch0ronomato Date: Mon, 14 Oct 2024 21:50:23 -0700 Subject: [PATCH] Add Cholesky --- pytensor/link/pytorch/dispatch/__init__.py | 1 + pytensor/link/pytorch/dispatch/slinalg.py | 28 ++++++++++++++++++++++ tests/link/pytorch/test_slinalg.py | 23 ++++++++++++++++++ 3 files changed, 52 insertions(+) create mode 100644 pytensor/link/pytorch/dispatch/slinalg.py create mode 100644 tests/link/pytorch/test_slinalg.py diff --git a/pytensor/link/pytorch/dispatch/__init__.py b/pytensor/link/pytorch/dispatch/__init__.py index fddded525a..f732848afc 100644 --- a/pytensor/link/pytorch/dispatch/__init__.py +++ b/pytensor/link/pytorch/dispatch/__init__.py @@ -11,4 +11,5 @@ import pytensor.link.pytorch.dispatch.shape import pytensor.link.pytorch.dispatch.sort import pytensor.link.pytorch.dispatch.subtensor +import pytensor.link.pytorch.dispatch.slinalg # isort: on diff --git a/pytensor/link/pytorch/dispatch/slinalg.py b/pytensor/link/pytorch/dispatch/slinalg.py new file mode 100644 index 0000000000..c13a171094 --- /dev/null +++ b/pytensor/link/pytorch/dispatch/slinalg.py @@ -0,0 +1,28 @@ +import torch.linalg + +from pytensor.link.pytorch.dispatch import pytorch_funcify +from pytensor.tensor.slinalg import Cholesky, SolveTriangular + + +@pytorch_funcify.register(Cholesky) +def pytorch_funcify_Cholesky(op, **kwargs): + lower = op.lower + + def cholesky(a, lower=lower): + return torch.linalg.cholesky(a, upper=not lower) + + return cholesky + + +@pytorch_funcify.register(SolveTriangular) +def pytorch_funcify_SolveTriangular(op, **kwargs): + lower = op.lower + trans = op.trans + unit_diagonal = op.unit_diagonal + + def solve_triangular(A, b): + return torch.linalg.solve_triangular( + A, b, upper=not lower, unit_triangle=unit_diagonal, left=trans == "T" + ) + + return solve_triangular diff --git a/tests/link/pytorch/test_slinalg.py b/tests/link/pytorch/test_slinalg.py new file mode 100644 index 0000000000..9c78cbef21 --- /dev/null +++ b/tests/link/pytorch/test_slinalg.py @@ -0,0 +1,23 @@ +import numpy as np +import pytest + +import pytensor +from pytensor.tensor import tensor +from pytensor.tensor.slinalg import cholesky + + +torch = pytest.importorskip("torch") + + +# @todo: We don't have blockwise yet for torch +def test_batched_mvnormal_logp_and_dlogp(): + rng = np.random.default_rng(sum(map(ord, "mvnormal"))) + + cov = tensor("cov", shape=(10, 10)) + + test_values = np.eye(cov.type.shape[-1]) * np.abs(rng.normal(size=cov.type.shape)) + + chol_cov = cholesky(cov, lower=True, on_error="raise") + + fn = pytensor.function([cov], [chol_cov], mode="PYTORCH") + assert np.all(np.isclose(fn(test_values), np.linalg.cholesky(test_values)))