-
-
Notifications
You must be signed in to change notification settings - Fork 2.1k
Allow Stationary covariance functions to use user defined distance functions #6965
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -510,6 +510,9 @@ class Stationary(Covariance): | |
ls: Lengthscale. If input_dim > 1, a list or array of scalars or PyMC random | ||
variables. If input_dim == 1, a scalar or PyMC random variable. | ||
ls_inv: Inverse lengthscale. 1 / ls. One of ls or ls_inv must be provided. | ||
square_dist: An optional (squared) distance function. If None is supplied, the | ||
default is the square of the Euclidean distance. The signature of this | ||
function is `square_dist(X: TensorLike, Xs: Tensorlike, ls: Tensorlike)`. | ||
""" | ||
|
||
def __init__( | ||
|
@@ -518,6 +521,7 @@ def __init__( | |
ls=None, | ||
ls_inv=None, | ||
active_dims: Optional[IntSequence] = None, | ||
square_dist: Optional[Callable] = None, | ||
): | ||
super().__init__(input_dim, active_dims) | ||
if (ls is None and ls_inv is None) or (ls is not None and ls_inv is not None): | ||
|
@@ -529,23 +533,24 @@ def __init__( | |
ls = 1.0 / ls_inv | ||
self.ls = pt.as_tensor_variable(ls) | ||
|
||
def square_dist(self, X, Xs): | ||
X = pt.mul(X, 1.0 / self.ls) | ||
X2 = pt.sum(pt.square(X), 1) | ||
if Xs is None: | ||
sqd = -2.0 * pt.dot(X, pt.transpose(X)) + ( | ||
pt.reshape(X2, (-1, 1)) + pt.reshape(X2, (1, -1)) | ||
) | ||
if square_dist is None: | ||
self.square_dist = self.default_square_dist | ||
else: | ||
Xs = pt.mul(Xs, 1.0 / self.ls) | ||
Xs2 = pt.sum(pt.square(Xs), 1) | ||
sqd = -2.0 * pt.dot(X, pt.transpose(Xs)) + ( | ||
pt.reshape(X2, (-1, 1)) + pt.reshape(Xs2, (1, -1)) | ||
) | ||
self.square_dist = square_dist | ||
|
||
@staticmethod | ||
def default_square_dist(X, Xs, ls): | ||
X = pt.mul(X, 1.0 / ls) | ||
X2 = pt.sum(pt.square(X), 1) | ||
Xs = pt.mul(Xs, 1.0 / ls) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Why did you take out the There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I moved it down to line 564 (and your one-liner is nicer). I moved it because I thought it'd be easier for someone making a different distance function if the inputs didn't have defaults. |
||
Xs2 = pt.sum(pt.square(Xs), 1) | ||
sqd = -2.0 * pt.dot(X, pt.transpose(Xs)) + ( | ||
pt.reshape(X2, (-1, 1)) + pt.reshape(Xs2, (1, -1)) | ||
) | ||
return pt.clip(sqd, 0.0, np.inf) | ||
|
||
def euclidean_dist(self, X, Xs): | ||
r2 = self.square_dist(X, Xs) | ||
r2 = self.square_dist(X, Xs, self.ls) | ||
return self._sqrt(r2) | ||
|
||
def _sqrt(self, r2): | ||
|
@@ -556,7 +561,9 @@ def diag(self, X: TensorLike) -> TensorVariable: | |
|
||
def full(self, X: TensorLike, Xs: Optional[TensorLike] = None) -> TensorVariable: | ||
X, Xs = self._slice(X, Xs) | ||
r2 = self.square_dist(X, Xs) | ||
if Xs is None: | ||
Xs = X | ||
r2 = self.square_dist(X, Xs, self.ls) | ||
return self.full_from_distance(r2, squared=True) | ||
|
||
def full_from_distance(self, dist: TensorLike, squared: bool = False) -> TensorVariable: | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -408,10 +408,36 @@ def test_stable(self): | |
X = np.random.uniform(low=320.0, high=400.0, size=[2000, 2]) | ||
with pm.Model() as model: | ||
cov = pm.gp.cov.ExpQuad(2, 0.1) | ||
dists = cov.square_dist(X, X).eval() | ||
dists = cov.square_dist(X, X, ls=cov.ls).eval() | ||
assert not np.any(dists < 0) | ||
|
||
|
||
class TestDistance: | ||
def test_alt_distance(self): | ||
"""square_dist below is the same as the default. Check if we get the same | ||
result by passing it as an argument to covariance func that inherets from | ||
Stationary. | ||
""" | ||
|
||
def square_dist(X, Xs, ls): | ||
X = pt.mul(X, 1.0 / ls) | ||
Xs = pt.mul(Xs, 1.0 / ls) | ||
X2 = pt.sum(pt.square(X), 1) | ||
Xs2 = pt.sum(pt.square(Xs), 1) | ||
sqd = -2.0 * pt.dot(X, pt.transpose(Xs)) + ( | ||
pt.reshape(X2, (-1, 1)) + pt.reshape(Xs2, (1, -1)) | ||
) | ||
return pt.clip(sqd, 0.0, np.inf) | ||
|
||
X = np.linspace(-5, 5, 100)[:, None] | ||
with pm.Model() as model: | ||
cov1 = pm.gp.cov.Matern32(1, ls=1) | ||
cov2 = pm.gp.cov.Matern32(1, ls=1, square_dist=square_dist) | ||
K1 = cov1(X).eval() | ||
K2 = cov2(X, X).eval() | ||
npt.assert_allclose(K1, K2, atol=1e-5) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. maybe also test you can get something different from the default? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. good call There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I agree with this |
||
|
||
|
||
class TestExpQuad: | ||
def test_1d(self): | ||
X = np.linspace(0, 1, 10)[:, None] | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I can't see where the
square_dist
argument is being used later. Do you want to attach it to the covariance instance at the end of__init__
? As it stands now, it looks like this isn't doing anything. Furthermore, even if you attach it to the object, would that be called instead of the__class__
square_dist
accessor?