Skip to content

Commit a402857

Browse files
Adding LFDA method
Iris results seem to match the figure in the original paper.
1 parent 46a0711 commit a402857

File tree

4 files changed

+113
-1
lines changed

4 files changed

+113
-1
lines changed

README.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ Metric Learning algorithms in Python.
1212
- Sparse Determinant Metric Learning (SDML)
1313
- Least Squares Metric Learning (LSML)
1414
- Neighborhood Components Analysis (NCA)
15+
- Local Fisher Discriminant Analysis (LFDA)
1516

1617
**Dependencies**
1718

metric_learn/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,3 +3,4 @@
33
from lsml import LSML
44
from sdml import SDML
55
from nca import NCA
6+
from lfda import LFDA

metric_learn/lfda.py

Lines changed: 103 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,103 @@
1+
from __future__ import division
2+
import numpy as np
3+
import scipy
4+
from sklearn.metrics import pairwise_distances
5+
from base_metric import BaseMetricLearner
6+
7+
8+
class LFDA(BaseMetricLearner):
9+
'''
10+
Local Fisher Discriminant Analysis for Supervised Dimensionality Reduction
11+
Sugiyama, ICML 2006
12+
'''
13+
def __init__(self, dim=None, k=7, metric='weighted'):
14+
'''
15+
dim : dimensionality of reduced space (defaults to dimension of X)
16+
k : nearest neighbor used in local scaling method (default: 7)
17+
metric : type of metric in the embedding space (default: 'weighted')
18+
'weighted' - weighted eigenvectors
19+
'orthonormalized' - orthonormalized
20+
'plain' - raw eigenvectors
21+
'''
22+
if metric not in ('weighted', 'orthonormalized', 'plain'):
23+
raise ValueError('Invalid metric: %r' % metric)
24+
self.dim = dim
25+
self.metric = metric
26+
self.k = k
27+
28+
def transformer(self):
29+
return self._tranformer
30+
31+
def _process_inputs(self, X, Y):
32+
X = np.asanyarray(X)
33+
self.X = X
34+
n, d = X.shape
35+
unique_classes, Y = np.unique(Y, return_inverse=True)
36+
num_classes = len(unique_classes)
37+
38+
if self.dim is None:
39+
self.dim = d
40+
elif not 0 < self.dim <= d:
41+
raise ValueError('Invalid embedding dimension, must be in [1,%d]' % d)
42+
43+
if not 0 < self.k < d:
44+
raise ValueError('Invalid k, must be in [0,%d]' % (d-1))
45+
46+
return X, Y, num_classes, n, d
47+
48+
def fit(self, X, Y):
49+
'''
50+
X: (n, d) array-like of samples
51+
Y: (n,) array-like of class labels
52+
'''
53+
X, Y, num_classes, n, d = self._process_inputs(X, Y)
54+
tSb = np.zeros((d,d))
55+
tSw = np.zeros((d,d))
56+
57+
for c in xrange(num_classes):
58+
Xc = X[Y==c]
59+
nc = Xc.shape[0]
60+
61+
# classwise affinity matrix
62+
dist = pairwise_distances(Xc, metric='l2', squared=True)
63+
# distances to k-th nearest neighbor
64+
k = min(self.k, nc-1)
65+
sigma = np.sqrt(np.partition(dist, k, axis=0)[:,k])
66+
67+
local_scale = np.outer(sigma, sigma)
68+
with np.errstate(divide='ignore', invalid='ignore'):
69+
A = np.exp(-dist/local_scale)
70+
A[local_scale==0] = 0
71+
72+
G = Xc.T.dot(A.sum(axis=0)[:,None] * Xc) - Xc.T.dot(A).dot(Xc)
73+
tSb += G/n + (1-nc/n)*Xc.T.dot(Xc) + _sum_outer(Xc)/n
74+
tSw += G/nc
75+
76+
tSb -= _sum_outer(X)/n - tSw
77+
78+
# symmetrize
79+
tSb += tSb.T
80+
tSb /= 2
81+
tSw += tSw.T
82+
tSw /= 2
83+
84+
if self.dim == d:
85+
vals, vecs = scipy.linalg.eigh(tSb, tSw)
86+
else:
87+
vals, vecs = scipy.sparse.linalg.eigsh(tSb, k=self.dim, M=tSw, which='LA')
88+
89+
order = np.argsort(-vals)[:self.dim]
90+
vals = vals[order]
91+
vecs = vecs[:,order]
92+
93+
if self.metric == 'weighted':
94+
vecs *= np.sqrt(vals)
95+
elif self.metric == 'orthonormalized':
96+
vecs, _ = np.linalg.qr(vecs)
97+
98+
self._tranformer = vecs.T
99+
100+
101+
def _sum_outer(x):
102+
s = x.sum(axis=0)
103+
return np.outer(s, s)

test/metric_learn_test.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
from sklearn.datasets import load_iris
66
from numpy.testing import assert_array_almost_equal
77

8-
from metric_learn import LSML, ITML, LMNN, SDML, NCA
8+
from metric_learn import LSML, ITML, LMNN, SDML, NCA, LFDA
99
# Import this specially for testing.
1010
from metric_learn.lmnn import python_LMNN
1111

@@ -96,5 +96,12 @@ def test_iris(self):
9696
assert_array_almost_equal(expected, nca.transformer(), decimal=3)
9797

9898

99+
class TestLFDA(MetricTestCase):
100+
def test_iris(self):
101+
lfda = LFDA(k=2, dim=2)
102+
lfda.fit(self.iris_points, self.iris_labels)
103+
csep = class_separation(lfda.transform(), self.iris_labels)
104+
self.assertLess(csep, 0.15)
105+
99106
if __name__ == '__main__':
100107
unittest.main()

0 commit comments

Comments
 (0)