Skip to content

Commit cc88cbe

Browse files
Replace 'metric' parameter with a better name
Fixes #54.
1 parent 70c48fd commit cc88cbe

File tree

3 files changed

+12
-8
lines changed

3 files changed

+12
-8
lines changed

metric_learn/lfda.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ class LFDA(BaseMetricLearner):
2626
Local Fisher Discriminant Analysis for Supervised Dimensionality Reduction
2727
Sugiyama, ICML 2006
2828
'''
29-
def __init__(self, num_dims=None, k=None, metric='weighted'):
29+
def __init__(self, num_dims=None, k=None, embedding_type='weighted'):
3030
'''
3131
Initialize LFDA.
3232
@@ -39,16 +39,16 @@ def __init__(self, num_dims=None, k=None, metric='weighted'):
3939
Number of nearest neighbors used in local scaling method.
4040
Defaults to min(7, num_dims - 1).
4141
42-
metric : str, optional
42+
embedding_type : str, optional
4343
Type of metric in the embedding space (default: 'weighted')
4444
'weighted' - weighted eigenvectors
4545
'orthonormalized' - orthonormalized
4646
'plain' - raw eigenvectors
4747
'''
48-
if metric not in ('weighted', 'orthonormalized', 'plain'):
49-
raise ValueError('Invalid metric: %r' % metric)
48+
if embedding_type not in ('weighted', 'orthonormalized', 'plain'):
49+
raise ValueError('Invalid embedding_type: %r' % embedding_type)
5050
self.num_dims = num_dims
51-
self.metric = metric
51+
self.embedding_type = embedding_type
5252
self.k = k
5353

5454
def transformer(self):
@@ -122,9 +122,9 @@ def fit(self, X, y):
122122
vals = vals[order].real
123123
vecs = vecs[:,order]
124124

125-
if self.metric == 'weighted':
125+
if self.embedding_type == 'weighted':
126126
vecs *= np.sqrt(vals)
127-
elif self.metric == 'orthonormalized':
127+
elif self.embedding_type == 'orthonormalized':
128128
vecs, _ = np.linalg.qr(vecs)
129129

130130
self.transformer_ = vecs.T

test/metric_learn_test.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -112,6 +112,10 @@ def test_iris(self):
112112
csep = class_separation(lfda.transform(), self.iris_labels)
113113
self.assertLess(csep, 0.15)
114114

115+
# Sanity checks for learned matrices.
116+
self.assertEqual(lfda.metric().shape, (4, 4))
117+
self.assertEqual(lfda.transformer().shape, (2, 4))
118+
115119

116120
class TestRCA(MetricTestCase):
117121
def test_iris(self):

test/test_base_metric.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ def test_nca(self):
2020

2121
def test_lfda(self):
2222
self.assertEqual(str(metric_learn.LFDA()),
23-
"LFDA(k=None, metric='weighted', num_dims=None)")
23+
"LFDA(embedding_type='weighted', k=None, num_dims=None)")
2424

2525
def test_itml(self):
2626
self.assertEqual(str(metric_learn.ITML()), """

0 commit comments

Comments
 (0)