diff --git a/metric_learn/base_metric.py b/metric_learn/base_metric.py index 427fcf86..5367a01e 100644 --- a/metric_learn/base_metric.py +++ b/metric_learn/base_metric.py @@ -11,6 +11,7 @@ import six from ._util import ArrayIndexer, check_input, validate_vector import warnings +import sys class BaseMetricLearner(six.with_metaclass(ABCMeta, BaseEstimator)): @@ -240,14 +241,22 @@ def transform(self, X): X_embedded : `numpy.ndarray`, shape=(n_samples, n_components) The embedded data points. """ - check_is_fitted(self, ['preprocessor_', 'components_']) + # TODO: remove when we stop supporting Python < 3.5 + if sys.version_info.major < 3 or sys.version_info.minor < 5: + check_is_fitted(self, ['preprocessor_', 'components_']) + else: + check_is_fitted(self) X_checked = check_input(X, type_of_inputs='classic', estimator=self, preprocessor=self.preprocessor_, accept_sparse=True) return X_checked.dot(self.components_.T) def get_metric(self): - check_is_fitted(self, 'components_') + # TODO: remove when we stop supporting Python < 3.5 + if sys.version_info.major < 3 or sys.version_info.minor < 5: + check_is_fitted(self, 'components_') + else: + check_is_fitted(self) components_T = self.components_.T.copy() def metric_fun(u, v, squared=False): @@ -300,7 +309,11 @@ def get_mahalanobis_matrix(self): M : `numpy.ndarray`, shape=(n_features, n_features) The copy of the learned Mahalanobis matrix. """ - check_is_fitted(self, 'components_') + # TODO: remove when we stop supporting Python < 3.5 + if sys.version_info.major < 3 or sys.version_info.minor < 5: + check_is_fitted(self, 'components_') + else: + check_is_fitted(self) return self.components_.T.dot(self.components_) @@ -363,7 +376,11 @@ def decision_function(self, pairs): y_predicted : `numpy.ndarray` of floats, shape=(n_constraints,) The predicted decision function value for each pair. """ - check_is_fitted(self, 'preprocessor_') + # TODO: remove when we stop supporting Python < 3.5 + if sys.version_info.major < 3 or sys.version_info.minor < 5: + check_is_fitted(self, 'preprocessor_') + else: + check_is_fitted(self) pairs = check_input(pairs, type_of_inputs='tuples', preprocessor=self.preprocessor_, estimator=self, tuple_size=self._tuple_size) @@ -606,7 +623,11 @@ def predict(self, quadruplets): prediction : `numpy.ndarray` of floats, shape=(n_constraints,) Predictions of the ordering of pairs, for each quadruplet. """ - check_is_fitted(self, 'preprocessor_') + # TODO: remove when we stop supporting Python < 3.5 + if sys.version_info.major < 3 or sys.version_info.minor < 5: + check_is_fitted(self, 'preprocessor_') + else: + check_is_fitted(self) quadruplets = check_input(quadruplets, type_of_inputs='tuples', preprocessor=self.preprocessor_, estimator=self, tuple_size=self._tuple_size) @@ -635,7 +656,11 @@ def decision_function(self, quadruplets): decision_function : `numpy.ndarray` of floats, shape=(n_constraints,) Metric differences. """ - check_is_fitted(self, 'preprocessor_') + # TODO: remove when we stop supporting Python < 3.5 + if sys.version_info.major < 3 or sys.version_info.minor < 5: + check_is_fitted(self, 'preprocessor_') + else: + check_is_fitted(self) quadruplets = check_input(quadruplets, type_of_inputs='tuples', preprocessor=self.preprocessor_, estimator=self, tuple_size=self._tuple_size)