|
| 1 | +import unittest |
| 2 | +import numpy as np |
| 3 | +from sklearn.datasets import load_iris |
| 4 | +from numpy.testing import assert_array_almost_equal |
| 5 | + |
| 6 | +from metric_learn import ( |
| 7 | + LMNN, NCA, LFDA, Covariance, MLKR, |
| 8 | + LSML_Supervised, ITML_Supervised, SDML_Supervised, RCA_Supervised) |
| 9 | + |
| 10 | + |
| 11 | +class TestTransformerMetricConversion(unittest.TestCase): |
| 12 | + @classmethod |
| 13 | + def setUpClass(self): |
| 14 | + # runs once per test class |
| 15 | + iris_data = load_iris() |
| 16 | + self.X = iris_data['data'] |
| 17 | + self.y = iris_data['target'] |
| 18 | + |
| 19 | + def test_cov(self): |
| 20 | + cov = Covariance() |
| 21 | + cov.fit(self.X) |
| 22 | + L = cov.transformer() |
| 23 | + assert_array_almost_equal(L.T.dot(L), cov.metric()) |
| 24 | + |
| 25 | + def test_lsml_supervised(self): |
| 26 | + seed = np.random.RandomState(1234) |
| 27 | + lsml = LSML_Supervised(num_constraints=200) |
| 28 | + lsml.fit(self.X, self.y, random_state=seed) |
| 29 | + L = lsml.transformer() |
| 30 | + assert_array_almost_equal(L.T.dot(L), lsml.metric()) |
| 31 | + |
| 32 | + def test_itml_supervised(self): |
| 33 | + seed = np.random.RandomState(1234) |
| 34 | + itml = ITML_Supervised(num_constraints=200) |
| 35 | + itml.fit(self.X, self.y, random_state=seed) |
| 36 | + L = itml.transformer() |
| 37 | + assert_array_almost_equal(L.T.dot(L), itml.metric()) |
| 38 | + |
| 39 | + def test_lmnn(self): |
| 40 | + lmnn = LMNN(k=5, learn_rate=1e-6, verbose=False) |
| 41 | + lmnn.fit(self.X, self.y) |
| 42 | + L = lmnn.transformer() |
| 43 | + assert_array_almost_equal(L.T.dot(L), lmnn.metric()) |
| 44 | + |
| 45 | + def test_sdml_supervised(self): |
| 46 | + seed = np.random.RandomState(1234) |
| 47 | + sdml = SDML_Supervised(num_constraints=1500) |
| 48 | + sdml.fit(self.X, self.y, random_state=seed) |
| 49 | + L = sdml.transformer() |
| 50 | + assert_array_almost_equal(L.T.dot(L), sdml.metric()) |
| 51 | + |
| 52 | + def test_nca(self): |
| 53 | + n = self.X.shape[0] |
| 54 | + nca = NCA(max_iter=(100000//n), learning_rate=0.01) |
| 55 | + nca.fit(self.X, self.y) |
| 56 | + L = nca.transformer() |
| 57 | + assert_array_almost_equal(L.T.dot(L), nca.metric()) |
| 58 | + |
| 59 | + def test_lfda(self): |
| 60 | + lfda = LFDA(k=2, num_dims=2) |
| 61 | + lfda.fit(self.X, self.y) |
| 62 | + L = lfda.transformer() |
| 63 | + assert_array_almost_equal(L.T.dot(L), lfda.metric()) |
| 64 | + |
| 65 | + def test_rca_supervised(self): |
| 66 | + seed = np.random.RandomState(1234) |
| 67 | + rca = RCA_Supervised(num_dims=2, num_chunks=30, chunk_size=2) |
| 68 | + rca.fit(self.X, self.y, random_state=seed) |
| 69 | + L = rca.transformer() |
| 70 | + assert_array_almost_equal(L.T.dot(L), rca.metric()) |
| 71 | + |
| 72 | + def test_mlkr(self): |
| 73 | + mlkr = MLKR(num_dims=2) |
| 74 | + mlkr.fit(self.X, self.y) |
| 75 | + L = mlkr.transformer() |
| 76 | + assert_array_almost_equal(L.T.dot(L), mlkr.metric()) |
| 77 | + |
| 78 | + |
| 79 | +if __name__ == '__main__': |
| 80 | + unittest.main() |
0 commit comments