Skip to content

Commit f48a55d

Browse files
authored
Revert changes in #270 due to revert decision in sklearn (#273)
1 parent 7a57b06 commit f48a55d

File tree

1 file changed

+6
-31
lines changed

1 file changed

+6
-31
lines changed

metric_learn/base_metric.py

Lines changed: 6 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,6 @@
1111
import six
1212
from ._util import ArrayIndexer, check_input, validate_vector
1313
import warnings
14-
import sys
1514

1615

1716
class BaseMetricLearner(six.with_metaclass(ABCMeta, BaseEstimator)):
@@ -241,22 +240,14 @@ def transform(self, X):
241240
X_embedded : `numpy.ndarray`, shape=(n_samples, n_components)
242241
The embedded data points.
243242
"""
244-
# TODO: remove when we stop supporting Python < 3.5
245-
if sys.version_info.major < 3 or sys.version_info.minor < 5:
246-
check_is_fitted(self, ['preprocessor_', 'components_'])
247-
else:
248-
check_is_fitted(self)
243+
check_is_fitted(self, ['preprocessor_', 'components_'])
249244
X_checked = check_input(X, type_of_inputs='classic', estimator=self,
250245
preprocessor=self.preprocessor_,
251246
accept_sparse=True)
252247
return X_checked.dot(self.components_.T)
253248

254249
def get_metric(self):
255-
# TODO: remove when we stop supporting Python < 3.5
256-
if sys.version_info.major < 3 or sys.version_info.minor < 5:
257-
check_is_fitted(self, 'components_')
258-
else:
259-
check_is_fitted(self)
250+
check_is_fitted(self, 'components_')
260251
components_T = self.components_.T.copy()
261252

262253
def metric_fun(u, v, squared=False):
@@ -309,11 +300,7 @@ def get_mahalanobis_matrix(self):
309300
M : `numpy.ndarray`, shape=(n_features, n_features)
310301
The copy of the learned Mahalanobis matrix.
311302
"""
312-
# TODO: remove when we stop supporting Python < 3.5
313-
if sys.version_info.major < 3 or sys.version_info.minor < 5:
314-
check_is_fitted(self, 'components_')
315-
else:
316-
check_is_fitted(self)
303+
check_is_fitted(self, 'components_')
317304
return self.components_.T.dot(self.components_)
318305

319306

@@ -376,11 +363,7 @@ def decision_function(self, pairs):
376363
y_predicted : `numpy.ndarray` of floats, shape=(n_constraints,)
377364
The predicted decision function value for each pair.
378365
"""
379-
# TODO: remove when we stop supporting Python < 3.5
380-
if sys.version_info.major < 3 or sys.version_info.minor < 5:
381-
check_is_fitted(self, 'preprocessor_')
382-
else:
383-
check_is_fitted(self)
366+
check_is_fitted(self, 'preprocessor_')
384367
pairs = check_input(pairs, type_of_inputs='tuples',
385368
preprocessor=self.preprocessor_,
386369
estimator=self, tuple_size=self._tuple_size)
@@ -623,11 +606,7 @@ def predict(self, quadruplets):
623606
prediction : `numpy.ndarray` of floats, shape=(n_constraints,)
624607
Predictions of the ordering of pairs, for each quadruplet.
625608
"""
626-
# TODO: remove when we stop supporting Python < 3.5
627-
if sys.version_info.major < 3 or sys.version_info.minor < 5:
628-
check_is_fitted(self, 'preprocessor_')
629-
else:
630-
check_is_fitted(self)
609+
check_is_fitted(self, 'preprocessor_')
631610
quadruplets = check_input(quadruplets, type_of_inputs='tuples',
632611
preprocessor=self.preprocessor_,
633612
estimator=self, tuple_size=self._tuple_size)
@@ -656,11 +635,7 @@ def decision_function(self, quadruplets):
656635
decision_function : `numpy.ndarray` of floats, shape=(n_constraints,)
657636
Metric differences.
658637
"""
659-
# TODO: remove when we stop supporting Python < 3.5
660-
if sys.version_info.major < 3 or sys.version_info.minor < 5:
661-
check_is_fitted(self, 'preprocessor_')
662-
else:
663-
check_is_fitted(self)
638+
check_is_fitted(self, 'preprocessor_')
664639
quadruplets = check_input(quadruplets, type_of_inputs='tuples',
665640
preprocessor=self.preprocessor_,
666641
estimator=self, tuple_size=self._tuple_size)

0 commit comments

Comments
 (0)