Skip to content

Commit ae562e6

Browse files
author
mvargas33
committed
Following the correct testing structure
1 parent b1edc46 commit ae562e6

File tree

3 files changed

+64
-121
lines changed

3 files changed

+64
-121
lines changed

metric_learn/oasis.py

Lines changed: 0 additions & 24 deletions
This file was deleted.

test/test_bilinear_mixin.py

Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,64 @@
1+
from metric_learn.base_metric import BilinearMixin
2+
import numpy as np
3+
from numpy.testing import assert_array_almost_equal
4+
5+
class IdentityBilinearMixin(BilinearMixin):
6+
"""A simple Identity bilinear mixin that returns an identity matrix
7+
M as learned. Can change M for a random matrix calling random_M.
8+
Class for testing purposes.
9+
"""
10+
def __init__(self, preprocessor=None):
11+
super().__init__(preprocessor=preprocessor)
12+
13+
def fit(self, X, y):
14+
X, y = self._prepare_inputs(X, y, ensure_min_samples=2)
15+
self.d = np.shape(X[0])[-1]
16+
self.components_ = np.identity(self.d)
17+
return self
18+
19+
def random_M(self):
20+
self.components_ = np.random.rand(self.d, self.d)
21+
22+
def test_same_similarity_with_two_methods():
23+
d = 100
24+
u = np.random.rand(d)
25+
v = np.random.rand(d)
26+
mixin = IdentityBilinearMixin()
27+
mixin.fit([u, v], [0, 0]) # Dummy fit
28+
mixin.random_M()
29+
30+
# The distances must match, whether calc with get_metric() or score_pairs()
31+
dist1 = mixin.score_pairs([[u, v], [v, u]])
32+
dist2 = [mixin.get_metric()(u, v), mixin.get_metric()(v, u)]
33+
34+
assert_array_almost_equal(dist1, dist2)
35+
36+
def test_check_correctness_similarity():
37+
d = 100
38+
u = np.random.rand(d)
39+
v = np.random.rand(d)
40+
mixin = IdentityBilinearMixin()
41+
mixin.fit([u, v], [0, 0]) # Dummy fit
42+
dist1 = mixin.score_pairs([[u, v], [v, u]])
43+
u_v = np.dot(np.dot(u.T, np.identity(d)), v)
44+
v_u = np.dot(np.dot(v.T, np.identity(d)), u)
45+
desired = [u_v, v_u]
46+
assert_array_almost_equal(dist1, desired)
47+
48+
def test_check_handmade_example():
49+
u = np.array([0, 1, 2])
50+
v = np.array([3, 4, 5])
51+
mixin = IdentityBilinearMixin()
52+
mixin.fit([u, v], [0, 0])
53+
c = np.array([[2, 4, 6], [6, 4, 2], [1, 2, 3]])
54+
mixin.components_ = c # Force a components_
55+
dists = mixin.score_pairs([[u, v], [v, u]])
56+
assert_array_almost_equal(dists, [96, 120])
57+
58+
def test_check_handmade_symmetric_example():
59+
u = np.array([0, 1, 2])
60+
v = np.array([3, 4, 5])
61+
mixin = IdentityBilinearMixin()
62+
mixin.fit([u, v], [0, 0])
63+
dists = mixin.score_pairs([[u, v], [v, u]])
64+
assert_array_almost_equal(dists, [14, 14])

test_bilinear.py

Lines changed: 0 additions & 97 deletions
This file was deleted.

0 commit comments

Comments
 (0)