Skip to content

Commit 80c9085

Browse files
author
mvargas33
committed
All tests have been generalized
1 parent 407f910 commit 80c9085

File tree

1 file changed

+53
-47
lines changed

1 file changed

+53
-47
lines changed

test/test_bilinear_mixin.py

Lines changed: 53 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,10 @@
55
import pytest
66
from metric_learn._util import make_context
77
from sklearn.cluster import DBSCAN
8+
from sklearn.datasets import make_spd_matrix
9+
from sklearn.utils import check_random_state
810

11+
RNG = check_random_state(0)
912

1013
class IdentityBilinearMixin(BilinearMixin):
1114
"""A simple Identity bilinear mixin that returns an identity matrix
@@ -15,14 +18,17 @@ class IdentityBilinearMixin(BilinearMixin):
1518
def __init__(self, preprocessor=None):
1619
super().__init__(preprocessor=preprocessor)
1720

18-
def fit(self, X, y):
21+
def fit(self, X, y, random=False):
1922
"""
2023
Checks input's format. Sets M matrix to identity of shape (d,d)
2124
where d is the dimension of the input.
2225
"""
2326
X, y = self._prepare_inputs(X, y, ensure_min_samples=2)
2427
self.d = np.shape(X[0])[-1]
25-
self.components_ = np.identity(self.d)
28+
if random:
29+
self.components_ = np.random.rand(self.d, self.d)
30+
else:
31+
self.components_ = np.identity(self.d)
2632
return self
2733

2834
def random_M(self):
@@ -32,29 +38,34 @@ def random_M(self):
3238
self.components_ = np.random.rand(self.d, self.d)
3339

3440

35-
def identity_fit(d=100):
41+
def identity_fit(d=100, n=100, n_pairs=None, random=False):
3642
"""
37-
Creates two d-dimentional arrays. Fits an IdentityBilinearMixin()
38-
and then returns the two arrays and the mixin. Testing purposes
43+
Creates 'n' d-dimentional arrays. Also generates 'n_pairs'
44+
sampled from the 'n' arrays. Fits an IdentityBilinearMixin()
45+
and then returns the arrays, the pairs and the mixin. Only
46+
generates the pairs if n_pairs is not None
3947
"""
40-
d = 100
41-
u = np.random.rand(d)
42-
v = np.random.rand(d)
48+
X = np.array([np.random.rand(d) for _ in range(n)])
4349
mixin = IdentityBilinearMixin()
44-
mixin.fit([u, v], [0, 0])
45-
return u, v, mixin
50+
mixin.fit(X, [0 for _ in range(n)], random=random)
51+
if n_pairs is not None:
52+
random_pairs = [[X[RNG.randint(0, n)], X[RNG.randint(0, n)]]
53+
for _ in range(n_pairs)]
54+
else:
55+
random_pairs = None
56+
return X, random_pairs, mixin
4657

4758

4859
def test_same_similarity_with_two_methods():
4960
""""
5061
Tests that score_pairs() and get_metric() give consistent results.
5162
In both cases, the results must match for the same input.
63+
Tests it for 'n_pairs' sampled from 'n' d-dimentional arrays.
5264
"""
53-
u, v, mixin = identity_fit()
54-
mixin.random_M() # Dummy fit
55-
# The distances must match, whether calc with get_metric() or score_pairs()
56-
dist1 = mixin.score_pairs([[u, v], [v, u]])
57-
dist2 = [mixin.get_metric()(u, v), mixin.get_metric()(v, u)]
65+
d, n, n_pairs= 100, 100, 1000
66+
_, random_pairs, mixin = identity_fit(d=d, n=n, n_pairs=n_pairs, random=True)
67+
dist1 = mixin.score_pairs(random_pairs)
68+
dist2 = [mixin.get_metric()(p[0], p[1]) for p in random_pairs]
5869

5970
assert_array_almost_equal(dist1, dist2)
6071

@@ -65,14 +76,12 @@ def test_check_correctness_similarity():
6576
get_metric(). Results are compared with the real bilinear similarity
6677
calculated in-place.
6778
"""
68-
d = 100
69-
u, v, mixin = identity_fit(d)
70-
dist1 = mixin.score_pairs([[u, v], [v, u]])
71-
dist2 = [mixin.get_metric()(u, v), mixin.get_metric()(v, u)]
72-
73-
u_v = np.dot(np.dot(u.T, np.identity(d)), v)
74-
v_u = np.dot(np.dot(v.T, np.identity(d)), u)
75-
desired = [u_v, v_u]
79+
d, n, n_pairs= 100, 100, 1000
80+
_, random_pairs, mixin = identity_fit(d=d, n=n, n_pairs=n_pairs, random=True)
81+
dist1 = mixin.score_pairs(random_pairs)
82+
dist2 = [mixin.get_metric()(p[0], p[1]) for p in random_pairs]
83+
desired = [np.dot(np.dot(p[0].T, mixin.components_), p[1]) for p in random_pairs]
84+
7685
assert_array_almost_equal(dist1, desired) # score_pairs
7786
assert_array_almost_equal(dist2, desired) # get_metric
7887

@@ -98,27 +107,31 @@ def test_check_handmade_symmetric_example():
98107
between two arrays must be equal: S(u,v) = S(v,u). Also
99108
checks the random case: when the matrix is pd and symetric.
100109
"""
101-
u = np.array([0, 1, 2])
102-
v = np.array([3, 4, 5])
103-
mixin = IdentityBilinearMixin()
104-
mixin.fit([u, v], [0, 0]) # Identity fit
105-
dists = mixin.score_pairs([[u, v], [v, u]])
106-
assert_array_almost_equal(dists, [14, 14])
110+
# Random pairs for M = Identity
111+
d, n, n_pairs= 100, 100, 1000
112+
_, random_pairs, mixin = identity_fit(d=d, n=n, n_pairs=n_pairs)
113+
pairs_reverse = [[p[1], p[0]] for p in random_pairs]
114+
dist1 = mixin.score_pairs(random_pairs)
115+
dist2 = mixin.score_pairs(pairs_reverse)
116+
assert_array_almost_equal(dist1, dist2)
107117

118+
# Random pairs for M = spd Matrix
119+
spd_matrix = make_spd_matrix(d, random_state=RNG)
120+
mixin.components_ = spd_matrix
121+
dist1 = mixin.score_pairs(random_pairs)
122+
dist2 = mixin.score_pairs(pairs_reverse)
123+
assert_array_almost_equal(dist1, dist2)
108124

109125
def test_score_pairs_finite():
110126
"""
111127
Checks for 'n' score_pairs() of 'd' dimentions, that all
112128
similarities are finite numbers, not NaN, +inf or -inf.
113129
Considering a random M for bilinear similarity.
114130
"""
115-
d = 100
116-
u, v, mixin = identity_fit(d)
117-
mixin.random_M() # Dummy fit
118-
n = 100
119-
X = np.array([np.random.rand(d) for i in range(n)])
120-
pairs = np.array(list(product(X, X)))
121-
assert np.isfinite(mixin.score_pairs(pairs)).all()
131+
d, n, n_pairs= 100, 100, 1000
132+
_, random_pairs, mixin = identity_fit(d=d, n=n, n_pairs=n_pairs, random=True)
133+
dist1 = mixin.score_pairs(random_pairs)
134+
assert np.isfinite(dist1).all()
122135

123136

124137
def test_score_pairs_dim():
@@ -127,11 +140,8 @@ def test_score_pairs_dim():
127140
and scoring of 2D arrays (one tuple) should return an error (like
128141
scikit-learn's error when scoring 1D arrays)
129142
"""
130-
d = 100
131-
u, v, mixin = identity_fit()
132-
mixin.random_M() # Dummy fit
133-
n = 100
134-
X = np.array([np.random.rand(d) for i in range(n)])
143+
d, n, n_pairs= 100, 100, 1000
144+
X, _, mixin = identity_fit(d=d, n=n, n_pairs=None, random=True)
135145
tuples = np.array(list(product(X, X)))
136146
assert mixin.score_pairs(tuples).shape == (tuples.shape[0],)
137147
context = make_context(mixin)
@@ -146,11 +156,7 @@ def test_score_pairs_dim():
146156
def test_check_scikitlearn_compatibility():
147157
"""Check that the similarity returned by get_metric() is compatible with
148158
scikit-learn's algorithms using a custom metric, DBSCAN for instance"""
149-
d = 100
150-
u, v, mixin = identity_fit(d)
151-
mixin.random_M() # Dummy fit
152-
153-
n = 100
154-
X = np.array([np.random.rand(d) for i in range(n)])
159+
d, n= 100, 100
160+
X, _, mixin = identity_fit(d=d, n=n, n_pairs=None, random=True)
155161
clustering = DBSCAN(metric=mixin.get_metric())
156162
clustering.fit(X)

0 commit comments

Comments
 (0)