@@ -241,47 +241,6 @@ def transform(self, X):
241
241
return X_checked .dot (self .transformer_ .T )
242
242
243
243
def get_metric (self ):
244
- """Returns a function that takes as input two 1D arrays and outputs the
245
- learned metric score on these two points.
246
-
247
- This function will be independent from the metric learner that learned it
248
- (it will not be modified if the initial metric learner is modified),
249
- and it can be directly plugged into the `metric` argument of
250
- scikit-learn's estimators.
251
-
252
- Returns
253
- -------
254
- metric_fun : function
255
- The function described above.
256
-
257
- Examples
258
- --------
259
- .. doctest::
260
-
261
- >>> from metric_learn import NCA
262
- >>> from sklearn.datasets import make_classification
263
- >>> from sklearn.neighbors import KNeighborsClassifier
264
- >>> nca = NCA()
265
- >>> X, y = make_classification()
266
- >>> nca.fit(X, y)
267
- >>> knn = KNeighborsClassifier(metric=nca.get_metric())
268
- >>> knn.fit(X, y) # doctest: +NORMALIZE_WHITESPACE
269
- KNeighborsClassifier(algorithm='auto', leaf_size=30,
270
- metric=<function MahalanobisMixin.get_metric.<locals>.metric_fun
271
- at 0x...>,
272
- metric_params=None, n_jobs=None, n_neighbors=5, p=2,
273
- weights='uniform')
274
-
275
- See Also
276
- --------
277
- score_pairs : a method that returns the metric score between several pairs
278
- of points. Unlike `get_metric`, this is a method of the metric learner
279
- and therefore can change if the metric learner changes. Besides, it can
280
- use the metric learner's preprocessor, and works on concatenated arrays.
281
-
282
- :ref:`mahalanobis_distances` : The section of the project documentation
283
- that describes Mahalanobis Distances.
284
- """
285
244
transformer_T = self .transformer_ .T .copy ()
286
245
287
246
def metric_fun (u , v ):
@@ -304,12 +263,14 @@ def metric_fun(u, v):
304
263
return euclidean (u .dot (transformer_T ), v .dot (transformer_T ))
305
264
return metric_fun
306
265
266
+ get_metric .__doc__ = BaseMetricLearner .get_metric .__doc__
267
+
307
268
def metric (self ):
308
269
# TODO: remove this method in version 0.6.0
309
270
warnings .warn (("`metric` is deprecated since version 0.5.0 and will be "
310
271
"removed in 0.6.0. Use `get_mahalanobis_matrix` instead." ),
311
272
DeprecationWarning )
312
- return self .transformer_ . T . dot ( self . transformer_ )
273
+ return self .get_mahalanobis_matrix ( )
313
274
314
275
def get_mahalanobis_matrix (self ):
315
276
"""Returns a copy of the Mahalanobis matrix learned by the metric learner.
@@ -319,7 +280,7 @@ def get_mahalanobis_matrix(self):
319
280
M : `numpy.ndarray`, shape=(n_components, n_features)
320
281
The copy of the learned Mahalanobis matrix.
321
282
"""
322
- return self .transformer_ .T .dot (self .transformer_ ). copy ()
283
+ return self .transformer_ .T .dot (self .transformer_ )
323
284
324
285
325
286
class _PairsClassifierMixin (BaseMetricLearner ):
0 commit comments