Skip to content

Commit 2f3c3e1

Browse files
mvargas33mvargas33
authored and
mvargas33
committed
Generalized toy tests
1 parent ec49397 commit 2f3c3e1

File tree

2 files changed

+23
-7
lines changed

2 files changed

+23
-7
lines changed

metric_learn/oasis.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,5 +16,11 @@ def fit(self, X, y):
1616
y : (n) data labels
1717
"""
1818
X = self._prepare_inputs(X, y, ensure_min_samples=2)
19-
self.components_ = np.identity(np.shape(X[0])[-1]) # Identity matrix
19+
20+
# Handmade dummy fit
21+
#self.components_ = np.identity(np.shape(X[0])[-1]) # Identity matrix
22+
#self.components_ = np.array([[2,4,6], [6,4,2], [1, 2, 3]])
23+
24+
# Dummy fit
25+
self.components_ = np.random.rand(np.shape(X[0])[-1], np.shape(X[0])[-1])
2026
return self

test_bilinear.py

Lines changed: 16 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,25 @@
11
from metric_learn.oasis import OASIS
22
import numpy as np
3+
from numpy.testing import assert_array_almost_equal
34

45
def test_toy_distance():
5-
u = np.array([0, 1, 2])
6-
v = np.array([3, 4, 5])
6+
d = 100
7+
8+
u = np.random.rand(d)
9+
v = np.random.rand(d)
710

811
mixin = OASIS()
9-
mixin.fit([u, v], [0, 0])
10-
#mixin.components_ = np.array([[1, 0, 0],[0, 1, 0],[0, 0, 1]])
12+
mixin.fit([u, v], [0, 0]) # Dummy fit
1113

12-
dist = mixin.score_pairs([[u, v],[v, u]])
13-
print(dist)
14+
# The distances must match, whether calc with get_metric() or score_pairs()
15+
dist1 = mixin.score_pairs([[u, v], [v, u]])
16+
dist2 = [mixin.get_metric()(u, v), mixin.get_metric()(v, u)]
17+
18+
u_v = (np.dot(np.dot(u.T, mixin.get_bilinear_matrix()), v))
19+
v_u = (np.dot(np.dot(v.T, mixin.get_bilinear_matrix()), u))
20+
desired = [u_v, v_u]
21+
22+
assert_array_almost_equal(dist1, desired)
23+
assert_array_almost_equal(dist2, desired)
1424

1525
test_toy_distance()

0 commit comments

Comments
 (0)