Skip to content

Commit e47a39a

Browse files
adding RCA method
1 parent 11945d1 commit e47a39a

File tree

4 files changed

+93
-1
lines changed

4 files changed

+93
-1
lines changed

README.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ Metric Learning algorithms in Python.
1313
- Least Squares Metric Learning (LSML)
1414
- Neighborhood Components Analysis (NCA)
1515
- Local Fisher Discriminant Analysis (LFDA)
16+
- Relative Components Analysis (RCA)
1617

1718
**Dependencies**
1819

metric_learn/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,3 +4,4 @@
44
from sdml import SDML
55
from nca import NCA
66
from lfda import LFDA
7+
from rca import RCA

metric_learn/rca.py

Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,73 @@
1+
import numpy as np
2+
from base_metric import BaseMetricLearner
3+
4+
5+
class RCA(BaseMetricLearner):
6+
'''Relevant Components Analysis (RCA)
7+
'Learning distance functions using equivalence relations', ICML 2003
8+
'''
9+
10+
def __init__(self, dim=None):
11+
'''
12+
dim : embedding dimension (default: original dimension of data)
13+
'''
14+
self.dim = dim
15+
16+
def transformer(self):
17+
return self._transformer
18+
19+
def _process_inputs(self, X, Y):
20+
X = np.asanyarray(X)
21+
self.X = X
22+
n, d = X.shape
23+
24+
if self.dim is None:
25+
self.dim = d
26+
elif not 0 < self.dim <= d:
27+
raise ValueError('Invalid embedding dimension, must be in [1,%d]' % d)
28+
29+
Y = np.asanyarray(Y)
30+
num_chunks = Y.max() + 1
31+
32+
return X, Y, num_chunks, d
33+
34+
def fit(self, data, chunks):
35+
'''
36+
data : (n,d) array-like, input data
37+
chunks : (n,) array-like
38+
chunks[i] == -1 -> point i doesn't belong to any chunklet
39+
chunks[i] == j -> point i belongs to chunklet j
40+
'''
41+
data, chunks, num_chunks, d = self._process_inputs(data, chunks)
42+
43+
# mean center
44+
data -= data.mean(axis=0)
45+
46+
# mean center each chunklet separately
47+
chunk_mask = chunks != -1
48+
chunk_data = data[chunk_mask]
49+
chunk_labels = chunks[chunk_mask]
50+
for c in xrange(num_chunks):
51+
mask = chunk_labels == c
52+
chunk_data[mask] -= chunk_data[mask].mean(axis=0)
53+
54+
# "inner" covariance of chunk deviations
55+
inner_cov = np.cov(chunk_data, rowvar=0, bias=1)
56+
57+
# Fisher Linear Discriminant projection
58+
if self.dim < d:
59+
total_cov = np.cov(data[chunk_mask], rowvar=0)
60+
tmp = np.linalg.lstsq(total_cov, inner_cov)[0]
61+
vals, vecs = np.linalg.eig(tmp)
62+
inds = np.argsort(vals)[:self.dim]
63+
A = vecs[:,inds]
64+
inner_cov = A.T.dot(inner_cov).dot(A)
65+
self._transformer = _inv_sqrtm(inner_cov).dot(A.T)
66+
else:
67+
self._transformer = _inv_sqrtm(inner_cov).T
68+
69+
70+
def _inv_sqrtm(x):
71+
'''Computes x^(-1/2)'''
72+
vals, vecs = np.linalg.eigh(x)
73+
return (vecs / np.sqrt(vals)).dot(vecs.T)

test/metric_learn_test.py

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
from sklearn.datasets import load_iris
66
from numpy.testing import assert_array_almost_equal
77

8-
from metric_learn import LSML, ITML, LMNN, SDML, NCA, LFDA
8+
from metric_learn import LSML, ITML, LMNN, SDML, NCA, LFDA, RCA
99
# Import this specially for testing.
1010
from metric_learn.lmnn import python_LMNN
1111

@@ -103,5 +103,22 @@ def test_iris(self):
103103
csep = class_separation(lfda.transform(), self.iris_labels)
104104
self.assertLess(csep, 0.15)
105105

106+
107+
class TestRCA(MetricTestCase):
108+
def test_iris(self):
109+
rca = RCA(dim=2)
110+
chunks = self.iris_labels.copy()
111+
a, = np.where(chunks==0)
112+
b, = np.where(chunks==1)
113+
c, = np.where(chunks==2)
114+
chunks[:] = -1
115+
chunks[a[:20]] = np.repeat(np.arange(10), 2)
116+
chunks[b[:20]] = np.repeat(np.arange(10, 20), 2)
117+
chunks[c[:20]] = np.repeat(np.arange(20, 30), 2)
118+
rca.fit(self.iris_points, chunks)
119+
csep = class_separation(rca.transform(), self.iris_labels)
120+
self.assertLess(csep, 0.25)
121+
122+
106123
if __name__ == '__main__':
107124
unittest.main()

0 commit comments

Comments
 (0)