diff --git a/metric_learn/base_metric.py b/metric_learn/base_metric.py
index 427fcf86..ee73c793 100644
--- a/metric_learn/base_metric.py
+++ b/metric_learn/base_metric.py
@@ -93,6 +93,8 @@ def _prepare_inputs(self, X, y=None, type_of_inputs='classic',
       The checked input labels array.
     """
     self._check_preprocessor()
+
+    check_is_fitted(self, ['preprocessor_'])
     return check_input(X, y, type_of_inputs=type_of_inputs,
                        preprocessor=self.preprocessor_,
@@ -215,6 +217,7 @@ def score_pairs(self, pairs):
     :ref:`mahalanobis_distances` : The section of the project documentation
       that describes Mahalanobis Distances.
     """
+    check_is_fitted(self, ['preprocessor_'])
     pairs = check_input(pairs, type_of_inputs='tuples',
                         preprocessor=self.preprocessor_,
                         estimator=self, tuple_size=2)
@@ -336,8 +339,10 @@ def predict(self, pairs):
     y_predicted : `numpy.ndarray` of floats, shape=(n_constraints,)
       The predicted learned metric value between samples in every pair.
     """
+    check_is_fitted(self, 'preprocessor_')
+
     if "threshold_" not in vars(self):
-      msg = ("A threshold for this estimator has not been set,"
+      msg = ("A threshold for this estimator has not been set, "
              "call its set_threshold or calibrate_threshold method.")
       raise AttributeError(msg)
     return 2 * (- self.decision_function(pairs) <= self.threshold_) - 1
@@ -414,6 +419,8 @@ def set_threshold(self, threshold):
     self : `_PairsClassifier`
       The pairs classifier with the new threshold set.
     """
+    check_is_fitted(self, 'preprocessor_')
+
     self.threshold_ = threshold
     return self
@@ -476,6 +483,7 @@ def calibrate_threshold(self, pairs_valid, y_valid, strategy='accuracy',
     --------
     sklearn.calibration : scikit-learn's module for calibrating classifiers
     """
+    check_is_fitted(self, 'preprocessor_')
     self._validate_calibration_params(strategy, min_rate, beta)
diff --git a/test/test_pairs_classifiers.py b/test/test_pairs_classifiers.py
index 840cd151..6c71abcd 100644
--- a/test/test_pairs_classifiers.py
+++ b/test/test_pairs_classifiers.py
@@ -66,14 +66,31 @@ def test_predict_monotonous(estimator, build_dataset,
                          ids=ids_pairs_learners)
 def test_raise_not_fitted_error_if_not_fitted(estimator, build_dataset,
                                               with_preprocessor):
-  """Test that a NotFittedError is raised if someone tries to predict and
-  the metric learner has not been fitted."""
+  """Test that a NotFittedError is raised if someone tries to use
+  score_pairs, decision_function, get_metric, transform or
+  get_mahalanobis_matrix on input data and the metric learner
+  has not been fitted."""
   input_data, labels, preprocessor, _ = build_dataset(with_preprocessor)
   estimator = clone(estimator)
   estimator.set_params(preprocessor=preprocessor)
   set_random_state(estimator)
+  with pytest.raises(NotFittedError):
+    estimator.score_pairs(input_data)
   with pytest.raises(NotFittedError):
     estimator.decision_function(input_data)
+  with pytest.raises(NotFittedError):
+    estimator.get_metric()
+  with pytest.raises(NotFittedError):
+    estimator.transform(input_data)
+  with pytest.raises(NotFittedError):
+    estimator.get_mahalanobis_matrix()
+  with pytest.raises(NotFittedError):
+    estimator.calibrate_threshold(input_data, labels)
+
+  with pytest.raises(NotFittedError):
+    estimator.set_threshold(0.5)
+  with pytest.raises(NotFittedError):
+    estimator.predict(input_data)
 
 
 @pytest.mark.parametrize('calibration_params',
@@ -138,7 +155,8 @@ def fit(self, pairs, y):
 
 
 def test_unset_threshold():
-  # test that set_threshold indeed sets the threshold
+  """Tests that the "threshold is unset" error is raised when using predict
+  (performs binary classification on pairs) with an unset threshold."""
   identity_pairs_classifier = IdentityPairsClassifier()
   pairs = np.array([[[0.], [1.]], [[1.], [3.]], [[2.], [5.]], [[3.], [7.]]])
   y = np.array([1, 1, -1, -1])
@@ -146,7 +164,7 @@ def test_unset_threshold():
 
   with pytest.raises(AttributeError) as e:
     identity_pairs_classifier.predict(pairs)
-  expected_msg = ("A threshold for this estimator has not been set,"
+  expected_msg = ("A threshold for this estimator has not been set, "
                   "call its set_threshold or calibrate_threshold method.")
   assert str(e.value) == expected_msg
@@ -362,6 +380,7 @@ class MockBadPairsClassifier(MahalanobisMixin, _PairsClassifierMixin):
   """
   def fit(self, pairs, y, calibration_params=None):
+    self.preprocessor_ = 'not used'
     self.components_ = 'not used'
     self.calibrate_threshold(pairs, y, **(calibration_params if
                                           calibration_params is not None else
diff --git a/test/test_utils.py b/test/test_utils.py
index 3092e168..76be5817 100644
--- a/test/test_utils.py
+++ b/test/test_utils.py
@@ -749,6 +749,8 @@ def test_array_like_indexer_array_like_valid_classic(input_data, indices):
   """Checks that any array-like is valid in the 'preprocessor' argument,
   and in the indices, for a classic input"""
   class MockMetricLearner(MahalanobisMixin):
+    def fit(self):
+      pass
     pass
 
   mock_algo = MockMetricLearner(preprocessor=input_data)
@@ -763,6 +765,8 @@ def test_array_like_indexer_array_like_valid_tuples(input_data, indices):
   """Checks that any array-like is valid in the 'preprocessor' argument,
   and in the indices, for a classic input"""
   class MockMetricLearner(MahalanobisMixin):
+    def fit(self):
+      pass
     pass
 
   mock_algo = MockMetricLearner(preprocessor=input_data)
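For context on the guard the patch adds everywhere: `check_is_fitted` comes from
`sklearn.utils.validation` and raises `sklearn.exceptions.NotFittedError` when the
requested fitted attribute (by scikit-learn convention a name ending in "_") is
missing on the estimator. Below is a minimal, self-contained sketch of that
pattern; `ToyEstimator` and its attributes are hypothetical illustrations, not
metric-learn code.

    import numpy as np
    from sklearn.base import BaseEstimator
    from sklearn.exceptions import NotFittedError
    from sklearn.utils.validation import check_is_fitted


    class ToyEstimator(BaseEstimator):
      """Hypothetical estimator showing the check_is_fitted guard pattern."""

      def fit(self, X, y=None):
        # Fitted attributes end with an underscore by scikit-learn convention.
        self.components_ = np.eye(np.asarray(X).shape[1])
        return self

      def transform(self, X):
        # Fail early with NotFittedError instead of a confusing AttributeError
        # if fit() has not been called yet.
        check_is_fitted(self, ['components_'])
        return np.asarray(X).dot(self.components_.T)


    X = np.array([[0., 1.], [2., 3.]])
    try:
      ToyEstimator().transform(X)              # unfitted: raises NotFittedError
    except NotFittedError as err:
      print('NotFittedError:', err)
    print(ToyEstimator().fit(X).transform(X))  # fitted: works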