Skip to content

Commit c5087d7

Browse files
bhargavvaderperimosocordiae
authored andcommitted
[MRG] Adding fit_transform (#26)
1 parent 890b6a8 commit c5087d7

File tree

5 files changed

+157
-3
lines changed

5 files changed

+157
-3
lines changed

metric_learn/base_metric.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,28 @@ def transform(self, X=None):
4545
X = self.X
4646
L = self.transformer()
4747
return X.dot(L.T)
48+
49+
def fit_transform(self, *args, **kwargs):
50+
"""
51+
Function calls .fit() and returns the result of .transform()
52+
Essentially, it runs the relevant Metric Learning algorithm with .fit()
53+
and returns the metric-transformed input data.
54+
55+
Paramters
56+
---------
57+
58+
Since all the parameters passed to fit_transform are passed on to
59+
fit(), the parameters to be passed must be noted from the corresponding
60+
Metric Learning algorithm's fit method.
61+
62+
Returns
63+
-------
64+
transformed : (n x d) matrix
65+
Input data transformed to the metric space by :math:`XL^{\\top}`
66+
67+
"""
68+
self.fit(*args, **kwargs)
69+
return self.transform()
4870

4971
def get_params(self, deep=False):
5072
"""Get parameters for this metric learner.

metric_learn/itml.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -183,5 +183,5 @@ def fit(self, X, labels, random_state=np.random):
183183
num_constraints = 20*(len(num_classes))**2
184184

185185
c = Constraints.random_subset(labels, self.params['num_labeled'], random_state=random_state)
186-
return ITML.fit(self, X, c.positive_negative_pairs(num_constraints),
186+
return ITML.fit(self, X, c.positive_negative_pairs(num_constraints, random_state=random_state),
187187
bounds=self.params['bounds'], A0=self.params['A0'])

metric_learn/lsml.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -172,6 +172,6 @@ def fit(self, X, labels, random_state=np.random):
172172
num_constraints = 20*(len(num_classes))**2
173173

174174
c = Constraints.random_subset(labels, self.params['num_labeled'], random_state=random_state)
175-
pairs = c.positive_negative_pairs(num_constraints, same_length=True)
175+
pairs = c.positive_negative_pairs(num_constraints, same_length=True, random_state=random_state)
176176
return LSML.fit(self, X, pairs, weights=self.params['weights'],
177177
prior=self.params['prior'])

metric_learn/sdml.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -106,4 +106,4 @@ def fit(self, X, labels, random_state=np.random):
106106
num_constraints = 20*(len(num_classes))**2
107107

108108
c = Constraints.random_subset(labels, self.params['num_labeled'], random_state=random_state)
109-
return SDML.fit(self, X, c.adjacency_matrix(num_constraints))
109+
return SDML.fit(self, X, c.adjacency_matrix(num_constraints, random_state=random_state))

test/test_fit_transform.py

Lines changed: 132 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,132 @@
1+
import unittest
2+
import numpy as np
3+
from sklearn.datasets import load_iris
4+
from numpy.testing import assert_array_almost_equal
5+
6+
from metric_learn import (
7+
LMNN, NCA, LFDA, Covariance,
8+
LSML_Supervised, ITML_Supervised, SDML_Supervised, RCA_Supervised)
9+
10+
11+
12+
class MetricTestCase(unittest.TestCase):
13+
@classmethod
14+
def setUpClass(self):
15+
# runs once per test class
16+
iris_data = load_iris()
17+
self.iris_points = iris_data['data']
18+
self.iris_labels = iris_data['target']
19+
20+
21+
class TestCovariance(MetricTestCase):
22+
def test_cov(self):
23+
cov = Covariance()
24+
cov.fit(self.iris_points)
25+
res_1 = cov.transform()
26+
27+
cov = Covariance()
28+
res_2 = cov.fit_transform(self.iris_points)
29+
# deterministic result
30+
assert_array_almost_equal(res_1, res_2)
31+
32+
33+
class TestLSML(MetricTestCase):
34+
def test_lsml(self):
35+
36+
seed = np.random.RandomState(1234)
37+
lsml = LSML_Supervised(num_constraints=200)
38+
lsml.fit(self.iris_points, self.iris_labels, random_state=seed)
39+
res_1 = lsml.transform()
40+
41+
seed = np.random.RandomState(1234)
42+
lsml = LSML_Supervised(num_constraints=200)
43+
res_2 = lsml.fit_transform(self.iris_points, self.iris_labels, random_state=seed)
44+
45+
assert_array_almost_equal(res_1, res_2)
46+
47+
class TestITML(MetricTestCase):
48+
def test_itml(self):
49+
50+
seed = np.random.RandomState(1234)
51+
itml = ITML_Supervised(num_constraints=200)
52+
itml.fit(self.iris_points, self.iris_labels, random_state=seed)
53+
res_1 = itml.transform()
54+
55+
seed = np.random.RandomState(1234)
56+
itml = ITML_Supervised(num_constraints=200)
57+
res_2 = itml.fit_transform(self.iris_points, self.iris_labels, random_state=seed)
58+
59+
assert_array_almost_equal(res_1, res_2)
60+
61+
class TestLMNN(MetricTestCase):
62+
def test_lmnn(self):
63+
64+
lmnn = LMNN(k=5, learn_rate=1e-6, verbose=False)
65+
lmnn.fit(self.iris_points, self.iris_labels)
66+
res_1 = lmnn.transform()
67+
68+
lmnn = LMNN(k=5, learn_rate=1e-6, verbose=False)
69+
res_2 = lmnn.fit_transform(self.iris_points, self.iris_labels)
70+
71+
assert_array_almost_equal(res_1, res_2)
72+
73+
class TestSDML(MetricTestCase):
74+
def test_sdml(self):
75+
76+
seed = np.random.RandomState(1234)
77+
sdml = SDML_Supervised(num_constraints=1500)
78+
sdml.fit(self.iris_points, self.iris_labels, random_state=seed)
79+
res_1 = sdml.transform()
80+
81+
seed = np.random.RandomState(1234)
82+
sdml = SDML_Supervised(num_constraints=1500)
83+
res_2 = sdml.fit_transform(self.iris_points, self.iris_labels, random_state=seed)
84+
85+
assert_array_almost_equal(res_1, res_2)
86+
87+
class TestNCA(MetricTestCase):
88+
def test_nca(self):
89+
90+
n = self.iris_points.shape[0]
91+
nca = NCA(max_iter=(100000//n), learning_rate=0.01)
92+
nca.fit(self.iris_points, self.iris_labels)
93+
res_1 = nca.transform()
94+
95+
nca = NCA(max_iter=(100000//n), learning_rate=0.01)
96+
res_2 = nca.fit_transform(self.iris_points, self.iris_labels)
97+
98+
assert_array_almost_equal(res_1, res_2)
99+
100+
class TestLFDA(MetricTestCase):
101+
def test_lfda(self):
102+
103+
lfda = LFDA(k=2, dim=2)
104+
lfda.fit(self.iris_points, self.iris_labels)
105+
res_1 = lfda.transform()
106+
107+
lfda = LFDA(k=2, dim=2)
108+
res_2 = lfda.fit_transform(self.iris_points, self.iris_labels)
109+
110+
res_1 = round(res_1[0][0], 3)
111+
res_2 = round(res_2[0][0], 3)
112+
res = (res_1 == res_2 or res_1 == -res_2)
113+
114+
self.assertTrue(res)
115+
116+
class TestRCA(MetricTestCase):
117+
def test_rca(self):
118+
119+
seed = np.random.RandomState(1234)
120+
rca = RCA_Supervised(dim=2, num_chunks=30, chunk_size=2)
121+
rca.fit(self.iris_points, self.iris_labels, random_state=seed)
122+
res_1 = rca.transform()
123+
124+
seed = np.random.RandomState(1234)
125+
rca = RCA_Supervised(dim=2, num_chunks=30, chunk_size=2)
126+
res_2 = rca.fit_transform(self.iris_points, self.iris_labels, random_state=seed)
127+
128+
assert_array_almost_equal(res_1, res_2)
129+
130+
131+
if __name__ == '__main__':
132+
unittest.main()

0 commit comments

Comments
 (0)