Skip to content

Commit 75d4ad2

Browse files
author
William de Vazelhes
committed
Merge branch 'feat/mahalanobis_class' of https://github.com/wdevazelhes/metric-learn into feat/mahalanobis_class
2 parents 779a93a + 4dd8990 commit 75d4ad2

15 files changed

+159
-33
lines changed

metric_learn/_util.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ def check_tuples(tuples):
2828
The validated input.
2929
"""
3030
# If input is scalar raise error
31-
if len(tuples.shape) == 0:
31+
if np.isscalar(tuples):
3232
raise ValueError(
3333
"Expected 3D array, got scalar instead. Cannot apply this function on "
3434
"scalars.")

metric_learn/base_metric.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ def score_pairs(self, pairs):
2929
"""
3030

3131

32-
class MetricTransformer():
32+
class MetricTransformer(object):
3333

3434
@abstractmethod
3535
def transform(self, X):
@@ -63,8 +63,8 @@ class MahalanobisMixin(six.with_metaclass(ABCMeta, BaseMetricLearner,
6363
6464
Attributes
6565
----------
66-
transformer_ : `np.ndarray`, shape=(num_dims, n_features)
67-
The learned linear transformation ``L``.
66+
transformer_ : `numpy.ndarray`, shape=(num_dims, n_features)
67+
The learned linear transformation ``L``.
6868
"""
6969

7070
def score_pairs(self, pairs):
@@ -117,7 +117,7 @@ def transform(self, X):
117117
def metric(self):
118118
return self.transformer_.T.dot(self.transformer_)
119119

120-
def _transformer_from_metric(self, metric):
120+
def transformer_from_metric(self, metric):
121121
"""Computes the transformation matrix from the Mahalanobis matrix.
122122
123123
Since by definition the metric `M` is positive semi-definite (PSD), it
@@ -257,4 +257,4 @@ def score(self, quadruplets, y=None):
257257
The quadruplets score.
258258
"""
259259
quadruplets = check_tuples(quadruplets)
260-
return - np.mean(self.predict(quadruplets))
260+
return -np.mean(self.predict(quadruplets))

metric_learn/covariance.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,15 @@
1717

1818

1919
class Covariance(MahalanobisMixin, TransformerMixin):
20+
"""Covariance metric (baseline method)
21+
22+
Attributes
23+
----------
24+
transformer_ : `numpy.ndarray`, shape=(num_dims, n_features)
25+
The linear transformation ``L`` deduced from the learned Mahalanobis
26+
metric (See :meth:`transformer_from_metric`.)
27+
"""
28+
2029
def __init__(self):
2130
pass
2231

@@ -32,5 +41,5 @@ def fit(self, X, y=None):
3241
else:
3342
self.M_ = np.linalg.inv(self.M_)
3443

35-
self.transformer_ = self._transformer_from_metric(check_array(self.M_))
44+
self.transformer_ = self.transformer_from_metric(check_array(self.M_))
3645
return self

metric_learn/itml.py

Lines changed: 19 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -132,11 +132,19 @@ def _fit(self, pairs, y, bounds=None):
132132
print('itml converged at iter: %d, conv = %f' % (it, conv))
133133
self.n_iter_ = it
134134

135-
self.transformer_ = self._transformer_from_metric(self.A_)
135+
self.transformer_ = self.transformer_from_metric(self.A_)
136136
return self
137137

138138

139139
class ITML(_BaseITML, _PairsClassifierMixin):
140+
"""Information Theoretic Metric Learning (ITML)
141+
142+
Attributes
143+
----------
144+
transformer_ : `numpy.ndarray`, shape=(num_dims, n_features)
145+
The linear transformation ``L`` deduced from the learned Mahalanobis
146+
metric (See :meth:`transformer_from_metric`.)
147+
"""
140148

141149
def fit(self, pairs, y, bounds=None):
142150
"""Learn the ITML model.
@@ -159,7 +167,15 @@ def fit(self, pairs, y, bounds=None):
159167

160168

161169
class ITML_Supervised(_BaseITML, TransformerMixin):
162-
"""Information Theoretic Metric Learning (ITML)"""
170+
"""Supervised version of Information Theoretic Metric Learning (ITML)
171+
172+
Attributes
173+
----------
174+
transformer_ : `numpy.ndarray`, shape=(num_dims, n_features)
175+
The linear transformation ``L`` deduced from the learned Mahalanobis
176+
metric (See `transformer_from_metric`.)
177+
"""
178+
163179
def __init__(self, gamma=1., max_iter=1000, convergence_threshold=1e-3,
164180
num_labeled=np.inf, num_constraints=None, bounds=None, A0=None,
165181
verbose=False):
@@ -192,6 +208,7 @@ def __init__(self, gamma=1., max_iter=1000, convergence_threshold=1e-3,
192208
def fit(self, X, y, random_state=np.random):
193209
"""Create constraints from labels and learn the ITML model.
194210
211+
195212
Parameters
196213
----------
197214
X : (n x d) matrix

metric_learn/lfda.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,13 @@ class LFDA(MahalanobisMixin, TransformerMixin):
2525
'''
2626
Local Fisher Discriminant Analysis for Supervised Dimensionality Reduction
2727
Sugiyama, ICML 2006
28+
29+
Attributes
30+
----------
31+
transformer_ : `numpy.ndarray`, shape=(num_dims, n_features)
32+
The learned linear transformation ``L``.
2833
'''
34+
2935
def __init__(self, num_dims=None, k=None, embedding_type='weighted'):
3036
'''
3137
Initialize LFDA.

metric_learn/lmnn.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -243,6 +243,13 @@ def _sum_outer_products(data, a_inds, b_inds, weights=None):
243243
from modshogun import RealFeatures, MulticlassLabels
244244

245245
class LMNN(_base_LMNN):
246+
"""Large Margin Nearest Neighbor (LMNN)
247+
248+
Attributes
249+
----------
250+
transformer_ : `numpy.ndarray`, shape=(num_dims, n_features)
251+
The learned linear transformation ``L``.
252+
"""
246253

247254
def fit(self, X, y):
248255
self.X_, y = check_X_y(X, y, dtype=float)

metric_learn/lsml.py

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -99,7 +99,7 @@ def _fit(self, quadruplets, weights=None):
9999
print("Didn't converge after", it, "iterations. Final loss:", s_best)
100100
self.n_iter_ = it
101101

102-
self.transformer_ = self._transformer_from_metric(self.M_)
102+
self.transformer_ = self.transformer_from_metric(self.M_)
103103
return self
104104

105105
def _comparison_loss(self, metric):
@@ -129,6 +129,14 @@ def _gradient(self, metric):
129129

130130

131131
class LSML(_BaseLSML, _QuadrupletsClassifierMixin):
132+
"""Least Squared-residual Metric Learning (LSML)
133+
134+
Attributes
135+
----------
136+
transformer_ : `numpy.ndarray`, shape=(num_dims, n_features)
137+
The linear transformation ``L`` deduced from the learned Mahalanobis
138+
metric (See :meth:`transformer_from_metric`.)
139+
"""
132140

133141
def fit(self, quadruplets, weights=None):
134142
"""Learn the LSML model.
@@ -152,6 +160,15 @@ def fit(self, quadruplets, weights=None):
152160

153161

154162
class LSML_Supervised(_BaseLSML, TransformerMixin):
163+
"""Supervised version of Least Squared-residual Metric Learning (LSML)
164+
165+
Attributes
166+
----------
167+
transformer_ : `numpy.ndarray`, shape=(num_dims, n_features)
168+
The linear transformation ``L`` deduced from the learned Mahalanobis
169+
metric (See :meth:`transformer_from_metric`.)
170+
"""
171+
155172
def __init__(self, tol=1e-3, max_iter=1000, prior=None, num_labeled=np.inf,
156173
num_constraints=None, weights=None, verbose=False):
157174
"""Initialize the learner.

metric_learn/mlkr.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,14 @@
2121

2222

2323
class MLKR(MahalanobisMixin, TransformerMixin):
24-
"""Metric Learning for Kernel Regression (MLKR)"""
24+
"""Metric Learning for Kernel Regression (MLKR)
25+
26+
Attributes
27+
----------
28+
transformer_ : `numpy.ndarray`, shape=(num_dims, n_features)
29+
The learned linear transformation ``L``.
30+
"""
31+
2532
def __init__(self, num_dims=None, A0=None, epsilon=0.01, alpha=0.0001,
2633
max_iter=1000):
2734
"""

metric_learn/mmc.py

Lines changed: 19 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -217,7 +217,7 @@ def _fit_full(self, pairs, y):
217217
self.A_[:] = A_old
218218
self.n_iter_ = cycle
219219

220-
self.transformer_ = self._transformer_from_metric(self.A_)
220+
self.transformer_ = self.transformer_from_metric(self.A_)
221221
return self
222222

223223
def _fit_diag(self, pairs, y):
@@ -277,7 +277,7 @@ def _fit_diag(self, pairs, y):
277277

278278
self.A_ = np.diag(w)
279279

280-
self.transformer_ = self._transformer_from_metric(self.A_)
280+
self.transformer_ = self.transformer_from_metric(self.A_)
281281
return self
282282

283283
def _fD(self, neg_pairs, A):
@@ -359,6 +359,14 @@ def _D_constraint(self, neg_pairs, w):
359359

360360

361361
class MMC(_BaseMMC, _PairsClassifierMixin):
362+
"""Mahalanobis Metric for Clustering (MMC)
363+
364+
Attributes
365+
----------
366+
transformer_ : `numpy.ndarray`, shape=(num_dims, n_features)
367+
The linear transformation ``L`` deduced from the learned Mahalanobis
368+
metric (See :meth:`transformer_from_metric`.)
369+
"""
362370

363371
def fit(self, pairs, y):
364372
"""Learn the MMC model.
@@ -379,7 +387,15 @@ def fit(self, pairs, y):
379387

380388

381389
class MMC_Supervised(_BaseMMC, TransformerMixin):
382-
"""Mahalanobis Metric for Clustering (MMC)"""
390+
"""Supervised version of Mahalanobis Metric for Clustering (MMC)
391+
392+
Attributes
393+
----------
394+
transformer_ : `numpy.ndarray`, shape=(num_dims, n_features)
395+
The linear transformation ``L`` deduced from the learned Mahalanobis
396+
metric (See :meth:`transformer_from_metric`.)
397+
"""
398+
383399
def __init__(self, max_iter=100, max_proj=10000, convergence_threshold=1e-6,
384400
num_labeled=np.inf, num_constraints=None,
385401
A0=None, diagonal=False, diagonal_c=1.0, verbose=False):

metric_learn/nca.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,14 @@
1515

1616

1717
class NCA(MahalanobisMixin, TransformerMixin):
18+
"""Neighborhood Components Analysis (NCA)
19+
20+
Attributes
21+
----------
22+
transformer_ : `numpy.ndarray`, shape=(num_dims, n_features)
23+
The learned linear transformation ``L``.
24+
"""
25+
1826
def __init__(self, num_dims=None, max_iter=100, learning_rate=0.01):
1927
self.num_dims = num_dims
2028
self.max_iter = max_iter

metric_learn/rca.py

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,14 @@ def _chunk_mean_centering(data, chunks):
3737

3838

3939
class RCA(MahalanobisMixin, TransformerMixin):
40-
"""Relevant Components Analysis (RCA)"""
40+
"""Relevant Components Analysis (RCA)
41+
42+
Attributes
43+
----------
44+
transformer_ : `numpy.ndarray`, shape=(num_dims, n_features)
45+
The learned linear transformation ``L``.
46+
"""
47+
4148
def __init__(self, num_dims=None, pca_comps=None):
4249
"""Initialize the learner.
4350
@@ -134,6 +141,14 @@ def _inv_sqrtm(x):
134141

135142

136143
class RCA_Supervised(RCA):
144+
"""Supervised version of Relevant Components Analysis (RCA)
145+
146+
Attributes
147+
----------
148+
transformer_ : `numpy.ndarray`, shape=(num_dims, n_features)
149+
The learned linear transformation ``L``.
150+
"""
151+
137152
def __init__(self, num_dims=None, pca_comps=None, num_chunks=100,
138153
chunk_size=2):
139154
"""Initialize the learner.

metric_learn/sdml.py

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -67,11 +67,19 @@ def _fit(self, pairs, y):
6767
emp_cov = emp_cov.T.dot(emp_cov)
6868
_, self.M_ = graph_lasso(emp_cov, self.sparsity_param, verbose=self.verbose)
6969

70-
self.transformer_ = self._transformer_from_metric(self.M_)
70+
self.transformer_ = self.transformer_from_metric(self.M_)
7171
return self
7272

7373

7474
class SDML(_BaseSDML, _PairsClassifierMixin):
75+
"""Sparse Distance Metric Learning (SDML)
76+
77+
Attributes
78+
----------
79+
transformer_ : `numpy.ndarray`, shape=(num_dims, n_features)
80+
The linear transformation ``L`` deduced from the learned Mahalanobis
81+
metric (See :meth:`transformer_from_metric`.)
82+
"""
7583

7684
def fit(self, pairs, y):
7785
"""Learn the SDML model.
@@ -92,6 +100,15 @@ def fit(self, pairs, y):
92100

93101

94102
class SDML_Supervised(_BaseSDML, TransformerMixin):
103+
"""Supervised version of Sparse Distance Metric Learning (SDML)
104+
105+
Attributes
106+
----------
107+
transformer_ : `numpy.ndarray`, shape=(num_dims, n_features)
108+
The linear transformation ``L`` deduced from the learned Mahalanobis
109+
metric (See :meth:`transformer_from_metric`.)
110+
"""
111+
95112
def __init__(self, balance_param=0.5, sparsity_param=0.01, use_cov=True,
96113
num_labeled=np.inf, num_constraints=None, verbose=False):
97114
"""

0 commit comments

Comments
 (0)