Skip to content

Commit 7819e7c

Browse files
belletterrytangyuan
authored andcommitted
Fix failing tests in last build (#270)
* add print to test * fix * fix again * fix again * remove attributes from check_is_fitted * if condition based on python version * add TODO everywhere
1 parent 7af910f commit 7819e7c

File tree

1 file changed

+31
-6
lines changed

1 file changed

+31
-6
lines changed

metric_learn/base_metric.py

Lines changed: 31 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
import six
1212
from ._util import ArrayIndexer, check_input, validate_vector
1313
import warnings
14+
import sys
1415

1516

1617
class BaseMetricLearner(six.with_metaclass(ABCMeta, BaseEstimator)):
@@ -240,14 +241,22 @@ def transform(self, X):
240241
X_embedded : `numpy.ndarray`, shape=(n_samples, n_components)
241242
The embedded data points.
242243
"""
243-
check_is_fitted(self, ['preprocessor_', 'components_'])
244+
# TODO: remove when we stop supporting Python < 3.5
245+
if sys.version_info.major < 3 or sys.version_info.minor < 5:
246+
check_is_fitted(self, ['preprocessor_', 'components_'])
247+
else:
248+
check_is_fitted(self)
244249
X_checked = check_input(X, type_of_inputs='classic', estimator=self,
245250
preprocessor=self.preprocessor_,
246251
accept_sparse=True)
247252
return X_checked.dot(self.components_.T)
248253

249254
def get_metric(self):
250-
check_is_fitted(self, 'components_')
255+
# TODO: remove when we stop supporting Python < 3.5
256+
if sys.version_info.major < 3 or sys.version_info.minor < 5:
257+
check_is_fitted(self, 'components_')
258+
else:
259+
check_is_fitted(self)
251260
components_T = self.components_.T.copy()
252261

253262
def metric_fun(u, v, squared=False):
@@ -300,7 +309,11 @@ def get_mahalanobis_matrix(self):
300309
M : `numpy.ndarray`, shape=(n_features, n_features)
301310
The copy of the learned Mahalanobis matrix.
302311
"""
303-
check_is_fitted(self, 'components_')
312+
# TODO: remove when we stop supporting Python < 3.5
313+
if sys.version_info.major < 3 or sys.version_info.minor < 5:
314+
check_is_fitted(self, 'components_')
315+
else:
316+
check_is_fitted(self)
304317
return self.components_.T.dot(self.components_)
305318

306319

@@ -363,7 +376,11 @@ def decision_function(self, pairs):
363376
y_predicted : `numpy.ndarray` of floats, shape=(n_constraints,)
364377
The predicted decision function value for each pair.
365378
"""
366-
check_is_fitted(self, 'preprocessor_')
379+
# TODO: remove when we stop supporting Python < 3.5
380+
if sys.version_info.major < 3 or sys.version_info.minor < 5:
381+
check_is_fitted(self, 'preprocessor_')
382+
else:
383+
check_is_fitted(self)
367384
pairs = check_input(pairs, type_of_inputs='tuples',
368385
preprocessor=self.preprocessor_,
369386
estimator=self, tuple_size=self._tuple_size)
@@ -606,7 +623,11 @@ def predict(self, quadruplets):
606623
prediction : `numpy.ndarray` of floats, shape=(n_constraints,)
607624
Predictions of the ordering of pairs, for each quadruplet.
608625
"""
609-
check_is_fitted(self, 'preprocessor_')
626+
# TODO: remove when we stop supporting Python < 3.5
627+
if sys.version_info.major < 3 or sys.version_info.minor < 5:
628+
check_is_fitted(self, 'preprocessor_')
629+
else:
630+
check_is_fitted(self)
610631
quadruplets = check_input(quadruplets, type_of_inputs='tuples',
611632
preprocessor=self.preprocessor_,
612633
estimator=self, tuple_size=self._tuple_size)
@@ -635,7 +656,11 @@ def decision_function(self, quadruplets):
635656
decision_function : `numpy.ndarray` of floats, shape=(n_constraints,)
636657
Metric differences.
637658
"""
638-
check_is_fitted(self, 'preprocessor_')
659+
# TODO: remove when we stop supporting Python < 3.5
660+
if sys.version_info.major < 3 or sys.version_info.minor < 5:
661+
check_is_fitted(self, 'preprocessor_')
662+
else:
663+
check_is_fitted(self)
639664
quadruplets = check_input(quadruplets, type_of_inputs='tuples',
640665
preprocessor=self.preprocessor_,
641666
estimator=self, tuple_size=self._tuple_size)

0 commit comments

Comments
 (0)