Skip to content

Commit 5a76df5

Browse files
bwengalstwiecki
authored andcommitted
adds clip fix for numerical instability
1 parent fecf1f4 commit 5a76df5

File tree

2 files changed

+19
-8
lines changed

2 files changed

+19
-8
lines changed

pymc3/gp/cov.py

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -126,13 +126,14 @@ def square_dist(self, X, Z):
126126
X = tt.mul(X, 1.0 / self.lengthscales)
127127
Xs = tt.sum(tt.square(X), 1)
128128
if Z is None:
129-
return -2.0 * tt.dot(X, tt.transpose(X)) +\
130-
(tt.reshape(Xs, (-1, 1)) + tt.reshape(Xs, (1, -1)))
129+
sqd = -2.0 * tt.dot(X, tt.transpose(X)) +\
130+
(tt.reshape(Xs, (-1, 1)) + tt.reshape(Xs, (1, -1)))
131131
else:
132132
Z = tt.mul(Z, 1.0 / self.lengthscales)
133133
Zs = tt.sum(tt.square(Z), 1)
134-
return -2.0 * tt.dot(X, tt.transpose(Z)) +\
135-
(tt.reshape(Xs, (-1, 1)) + tt.reshape(Zs, (1, -1)))
134+
sqd = -2.0 * tt.dot(X, tt.transpose(Z)) +\
135+
(tt.reshape(Xs, (-1, 1)) + tt.reshape(Zs, (1, -1)))
136+
return tt.clip(sqd, 0.0, np.inf)
136137

137138
def euclidean_dist(self, X, Z):
138139
r2 = self.square_dist(X, Z)
@@ -337,13 +338,14 @@ def square_dist(self, X, Z):
337338
X = tt.as_tensor_variable(X)
338339
Xs = tt.sum(tt.square(X), 1)
339340
if Z is None:
340-
return -2.0 * tt.dot(X, tt.transpose(X)) +\
341-
(tt.reshape(Xs, (-1, 1)) + tt.reshape(Xs, (1, -1)))
341+
sqd = -2.0 * tt.dot(X, tt.transpose(X)) +\
342+
(tt.reshape(Xs, (-1, 1)) + tt.reshape(Xs, (1, -1)))
342343
else:
343344
Z = tt.as_tensor_variable(Z)
344345
Zs = tt.sum(tt.square(Z), 1)
345-
return -2.0 * tt.dot(X, tt.transpose(Z)) +\
346-
(tt.reshape(Xs, (-1, 1)) + tt.reshape(Zs, (1, -1)))
346+
sqd = -2.0 * tt.dot(X, tt.transpose(Z)) +\
347+
(tt.reshape(Xs, (-1, 1)) + tt.reshape(Zs, (1, -1)))
348+
return tt.clip(sqd, 0.0, np.inf)
347349

348350
def __call__(self, X, Z=None):
349351
X, Z = self._slice(X, Z)

pymc3/tests/test_gp.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -169,6 +169,15 @@ def test_raises(self):
169169
gp.cov.ExpQuad(2, lengthscales, [True])
170170

171171

172+
class TestStability(unittest.TestCase):
173+
def test_stable(self):
174+
X = np.random.uniform(low=320., high=400., size=[2000,2])
175+
with Model() as model:
176+
cov = gp.cov.ExpQuad(2, 0.1)
177+
dists = theano.function([], cov.square_dist(X, X))()
178+
self.assertFalse(np.any(dists < 0))
179+
180+
172181
class TestExpQuad(unittest.TestCase):
173182
def test_1d(self):
174183
X = np.linspace(0,1,10)[:,None]

0 commit comments

Comments
 (0)