From 51d5c72ec9abb9f15c564e06b21eb8c78b0e16a3 Mon Sep 17 00:00:00 2001 From: Bill Engels Date: Fri, 20 Oct 2023 14:34:57 -0700 Subject: [PATCH 1/3] Allow covariances that inherit from Stationary to be passed different pytensor distance functions --- pymc/gp/cov.py | 30 ++++++++++++++++-------------- tests/gp/test_cov.py | 27 ++++++++++++++++++++++++++- 2 files changed, 42 insertions(+), 15 deletions(-) diff --git a/pymc/gp/cov.py b/pymc/gp/cov.py index 905fdef930..846367c78e 100644 --- a/pymc/gp/cov.py +++ b/pymc/gp/cov.py @@ -510,6 +510,9 @@ class Stationary(Covariance): ls: Lengthscale. If input_dim > 1, a list or array of scalars or PyMC random variables. If input_dim == 1, a scalar or PyMC random variable. ls_inv: Inverse lengthscale. 1 / ls. One of ls or ls_inv must be provided. + square_dist: An optional (squared) distance function. If None is supplied, the + default is the square of the Euclidean distance. The signature of this + function is `square_dist(X: TensorLike, Xs: Tensorlike, ls: Tensorlike)`. """ def __init__( @@ -518,6 +521,7 @@ def __init__( ls=None, ls_inv=None, active_dims: Optional[IntSequence] = None, + square_dist: Optional[Callable] = None, ): super().__init__(input_dim, active_dims) if (ls is None and ls_inv is None) or (ls is not None and ls_inv is not None): @@ -529,23 +533,19 @@ def __init__( ls = 1.0 / ls_inv self.ls = pt.as_tensor_variable(ls) - def square_dist(self, X, Xs): - X = pt.mul(X, 1.0 / self.ls) + @staticmethod + def square_dist(X, Xs, ls): + X = pt.mul(X, 1.0 / ls) X2 = pt.sum(pt.square(X), 1) - if Xs is None: - sqd = -2.0 * pt.dot(X, pt.transpose(X)) + ( - pt.reshape(X2, (-1, 1)) + pt.reshape(X2, (1, -1)) - ) - else: - Xs = pt.mul(Xs, 1.0 / self.ls) - Xs2 = pt.sum(pt.square(Xs), 1) - sqd = -2.0 * pt.dot(X, pt.transpose(Xs)) + ( - pt.reshape(X2, (-1, 1)) + pt.reshape(Xs2, (1, -1)) - ) + Xs = pt.mul(Xs, 1.0 / ls) + Xs2 = pt.sum(pt.square(Xs), 1) + sqd = -2.0 * pt.dot(X, pt.transpose(Xs)) + ( + pt.reshape(X2, (-1, 1)) + pt.reshape(Xs2, (1, -1)) + ) return pt.clip(sqd, 0.0, np.inf) def euclidean_dist(self, X, Xs): - r2 = self.square_dist(X, Xs) + r2 = self.square_dist(X, Xs, self.ls) return self._sqrt(r2) def _sqrt(self, r2): @@ -556,7 +556,9 @@ def diag(self, X: TensorLike) -> TensorVariable: def full(self, X: TensorLike, Xs: Optional[TensorLike] = None) -> TensorVariable: X, Xs = self._slice(X, Xs) - r2 = self.square_dist(X, Xs) + if Xs is None: + Xs = X + r2 = self.square_dist(X, Xs, self.ls) return self.full_from_distance(r2, squared=True) def full_from_distance(self, dist: TensorLike, squared: bool = False) -> TensorVariable: diff --git a/tests/gp/test_cov.py b/tests/gp/test_cov.py index 750d5fe6a1..ad308a1866 100644 --- a/tests/gp/test_cov.py +++ b/tests/gp/test_cov.py @@ -408,10 +408,35 @@ def test_stable(self): X = np.random.uniform(low=320.0, high=400.0, size=[2000, 2]) with pm.Model() as model: cov = pm.gp.cov.ExpQuad(2, 0.1) - dists = cov.square_dist(X, X).eval() + dists = cov.square_dist(X, X, ls=cov.ls).eval() assert not np.any(dists < 0) +class TestDistance: + def test_alt_distance(self): + """ square_dist below is the same as the default. Check if we get the same + result by passing it as an argument to covariance func that inherets from + Stationary. + """ + def square_dist(X, Xs, ls): + X = pt.mul(X, 1.0 / ls) + Xs = pt.mul(Xs, 1.0 / ls) + X2 = pt.sum(pt.square(X), 1) + Xs2 = pt.sum(pt.square(Xs), 1) + sqd = -2.0 * pt.dot(X, pt.transpose(Xs)) + ( + pt.reshape(X2, (-1, 1)) + pt.reshape(Xs2, (1, -1)) + ) + return pt.clip(sqd, 0.0, np.inf) + + X = np.linspace(-5, 5, 100)[:, None] + with pm.Model() as model: + cov1 = pm.gp.cov.Matern32(1, ls=1) + cov2 = pm.gp.cov.Matern32(1, ls=1, square_dist=square_dist) + K1 = cov1(X).eval() + K2 = cov2(X, X).eval() + npt.assert_allclose(K1, K2, atol=1e-5) + + class TestExpQuad: def test_1d(self): X = np.linspace(0, 1, 10)[:, None] From bf34f8fc6ed1ce5061353f68d1ad2601b97bafe3 Mon Sep 17 00:00:00 2001 From: Bill Engels Date: Fri, 20 Oct 2023 14:38:51 -0700 Subject: [PATCH 2/3] run precommit --- pymc/gp/cov.py | 4 ++-- tests/gp/test_cov.py | 5 +++-- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/pymc/gp/cov.py b/pymc/gp/cov.py index 846367c78e..3e1325a405 100644 --- a/pymc/gp/cov.py +++ b/pymc/gp/cov.py @@ -510,8 +510,8 @@ class Stationary(Covariance): ls: Lengthscale. If input_dim > 1, a list or array of scalars or PyMC random variables. If input_dim == 1, a scalar or PyMC random variable. ls_inv: Inverse lengthscale. 1 / ls. One of ls or ls_inv must be provided. - square_dist: An optional (squared) distance function. If None is supplied, the - default is the square of the Euclidean distance. The signature of this + square_dist: An optional (squared) distance function. If None is supplied, the + default is the square of the Euclidean distance. The signature of this function is `square_dist(X: TensorLike, Xs: Tensorlike, ls: Tensorlike)`. """ diff --git a/tests/gp/test_cov.py b/tests/gp/test_cov.py index ad308a1866..27345620e5 100644 --- a/tests/gp/test_cov.py +++ b/tests/gp/test_cov.py @@ -414,10 +414,11 @@ def test_stable(self): class TestDistance: def test_alt_distance(self): - """ square_dist below is the same as the default. Check if we get the same - result by passing it as an argument to covariance func that inherets from + """square_dist below is the same as the default. Check if we get the same + result by passing it as an argument to covariance func that inherets from Stationary. """ + def square_dist(X, Xs, ls): X = pt.mul(X, 1.0 / ls) Xs = pt.mul(Xs, 1.0 / ls) From d0cd26c94cad46867d69b30b445bba7699af9f14 Mon Sep 17 00:00:00 2001 From: Bill Engels Date: Tue, 24 Oct 2023 13:18:36 -0700 Subject: [PATCH 3/3] actually change the dist func --- pymc/gp/cov.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/pymc/gp/cov.py b/pymc/gp/cov.py index 3e1325a405..b9ce03150c 100644 --- a/pymc/gp/cov.py +++ b/pymc/gp/cov.py @@ -533,8 +533,13 @@ def __init__( ls = 1.0 / ls_inv self.ls = pt.as_tensor_variable(ls) + if square_dist is None: + self.square_dist = self.default_square_dist + else: + self.square_dist = square_dist + @staticmethod - def square_dist(X, Xs, ls): + def default_square_dist(X, Xs, ls): X = pt.mul(X, 1.0 / ls) X2 = pt.sum(pt.square(X), 1) Xs = pt.mul(Xs, 1.0 / ls)