Skip to content

Commit 4555105

Browse files
author
Björn Barz
committed
Added unit test for transformer-metric conversion
1 parent 9549549 commit 4555105

File tree

1 file changed

+80
-0
lines changed

1 file changed

+80
-0
lines changed
Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,80 @@
1+
import unittest
2+
import numpy as np
3+
from sklearn.datasets import load_iris
4+
from numpy.testing import assert_array_almost_equal
5+
6+
from metric_learn import (
7+
LMNN, NCA, LFDA, Covariance, MLKR,
8+
LSML_Supervised, ITML_Supervised, SDML_Supervised, RCA_Supervised)
9+
10+
11+
class TestTransformerMetricConversion(unittest.TestCase):
12+
@classmethod
13+
def setUpClass(self):
14+
# runs once per test class
15+
iris_data = load_iris()
16+
self.X = iris_data['data']
17+
self.y = iris_data['target']
18+
19+
def test_cov(self):
20+
cov = Covariance()
21+
cov.fit(self.X)
22+
L = cov.transformer()
23+
assert_array_almost_equal(L.T.dot(L), cov.metric())
24+
25+
def test_lsml_supervised(self):
26+
seed = np.random.RandomState(1234)
27+
lsml = LSML_Supervised(num_constraints=200)
28+
lsml.fit(self.X, self.y, random_state=seed)
29+
L = lsml.transformer()
30+
assert_array_almost_equal(L.T.dot(L), lsml.metric())
31+
32+
def test_itml_supervised(self):
33+
seed = np.random.RandomState(1234)
34+
itml = ITML_Supervised(num_constraints=200)
35+
itml.fit(self.X, self.y, random_state=seed)
36+
L = itml.transformer()
37+
assert_array_almost_equal(L.T.dot(L), itml.metric())
38+
39+
def test_lmnn(self):
40+
lmnn = LMNN(k=5, learn_rate=1e-6, verbose=False)
41+
lmnn.fit(self.X, self.y)
42+
L = lmnn.transformer()
43+
assert_array_almost_equal(L.T.dot(L), lmnn.metric())
44+
45+
def test_sdml_supervised(self):
46+
seed = np.random.RandomState(1234)
47+
sdml = SDML_Supervised(num_constraints=1500)
48+
sdml.fit(self.X, self.y, random_state=seed)
49+
L = sdml.transformer()
50+
assert_array_almost_equal(L.T.dot(L), sdml.metric())
51+
52+
def test_nca(self):
53+
n = self.X.shape[0]
54+
nca = NCA(max_iter=(100000//n), learning_rate=0.01)
55+
nca.fit(self.X, self.y)
56+
L = nca.transformer()
57+
assert_array_almost_equal(L.T.dot(L), nca.metric())
58+
59+
def test_lfda(self):
60+
lfda = LFDA(k=2, num_dims=2)
61+
lfda.fit(self.X, self.y)
62+
L = lfda.transformer()
63+
assert_array_almost_equal(L.T.dot(L), lfda.metric())
64+
65+
def test_rca_supervised(self):
66+
seed = np.random.RandomState(1234)
67+
rca = RCA_Supervised(num_dims=2, num_chunks=30, chunk_size=2)
68+
rca.fit(self.X, self.y, random_state=seed)
69+
L = rca.transformer()
70+
assert_array_almost_equal(L.T.dot(L), rca.metric())
71+
72+
def test_mlkr(self):
73+
mlkr = MLKR(num_dims=2)
74+
mlkr.fit(self.X, self.y)
75+
L = mlkr.transformer()
76+
assert_array_almost_equal(L.T.dot(L), mlkr.metric())
77+
78+
79+
if __name__ == '__main__':
80+
unittest.main()

0 commit comments

Comments
 (0)