diff --git a/metric_learn/base_metric.py b/metric_learn/base_metric.py index 6feccc72..427fcf86 100644 --- a/metric_learn/base_metric.py +++ b/metric_learn/base_metric.py @@ -240,12 +240,14 @@ def transform(self, X): X_embedded : `numpy.ndarray`, shape=(n_samples, n_components) The embedded data points. """ + check_is_fitted(self, ['preprocessor_', 'components_']) X_checked = check_input(X, type_of_inputs='classic', estimator=self, preprocessor=self.preprocessor_, accept_sparse=True) return X_checked.dot(self.components_.T) def get_metric(self): + check_is_fitted(self, 'components_') components_T = self.components_.T.copy() def metric_fun(u, v, squared=False): @@ -298,6 +300,7 @@ def get_mahalanobis_matrix(self): M : `numpy.ndarray`, shape=(n_features, n_features) The copy of the learned Mahalanobis matrix. """ + check_is_fitted(self, 'components_') return self.components_.T.dot(self.components_) @@ -333,7 +336,10 @@ def predict(self, pairs): y_predicted : `numpy.ndarray` of floats, shape=(n_constraints,) The predicted learned metric value between samples in every pair. """ - check_is_fitted(self, ['threshold_', 'components_']) + if "threshold_" not in vars(self): + msg = ("A threshold for this estimator has not been set," + "call its set_threshold or calibrate_threshold method.") + raise AttributeError(msg) return 2 * (- self.decision_function(pairs) <= self.threshold_) - 1 def decision_function(self, pairs): @@ -357,6 +363,7 @@ def decision_function(self, pairs): y_predicted : `numpy.ndarray` of floats, shape=(n_constraints,) The predicted decision function value for each pair. """ + check_is_fitted(self, 'preprocessor_') pairs = check_input(pairs, type_of_inputs='tuples', preprocessor=self.preprocessor_, estimator=self, tuple_size=self._tuple_size) @@ -599,7 +606,7 @@ def predict(self, quadruplets): prediction : `numpy.ndarray` of floats, shape=(n_constraints,) Predictions of the ordering of pairs, for each quadruplet. 
""" - check_is_fitted(self, 'components_') + check_is_fitted(self, 'preprocessor_') quadruplets = check_input(quadruplets, type_of_inputs='tuples', preprocessor=self.preprocessor_, estimator=self, tuple_size=self._tuple_size) @@ -628,6 +635,7 @@ def decision_function(self, quadruplets): decision_function : `numpy.ndarray` of floats, shape=(n_constraints,) Metric differences. """ + check_is_fitted(self, 'preprocessor_') quadruplets = check_input(quadruplets, type_of_inputs='tuples', preprocessor=self.preprocessor_, estimator=self, tuple_size=self._tuple_size) diff --git a/test/test_pairs_classifiers.py b/test/test_pairs_classifiers.py index affc70f6..840cd151 100644 --- a/test/test_pairs_classifiers.py +++ b/test/test_pairs_classifiers.py @@ -73,7 +73,7 @@ def test_raise_not_fitted_error_if_not_fitted(estimator, build_dataset, estimator.set_params(preprocessor=preprocessor) set_random_state(estimator) with pytest.raises(NotFittedError): - estimator.predict(input_data) + estimator.decision_function(input_data) @pytest.mark.parametrize('calibration_params', @@ -133,10 +133,25 @@ def fit(self, pairs, y): pairs, y = self._prepare_inputs(pairs, y, type_of_inputs='tuples') self.components_ = np.atleast_2d(np.identity(pairs.shape[2])) - self.threshold_ = 'I am not set.' + # self.threshold_ is not set. 
return self +def test_unset_threshold(): + # test that predict raises a helpful AttributeError when no threshold is set + identity_pairs_classifier = IdentityPairsClassifier() + pairs = np.array([[[0.], [1.]], [[1.], [3.]], [[2.], [5.]], [[3.], [7.]]]) + y = np.array([1, 1, -1, -1]) + identity_pairs_classifier.fit(pairs, y) + with pytest.raises(AttributeError) as e: + identity_pairs_classifier.predict(pairs) + + expected_msg = ("A threshold for this estimator has not been set," + "call its set_threshold or calibrate_threshold method.") + + assert str(e.value) == expected_msg + + def test_set_threshold(): # test that set_threshold indeed sets the threshold identity_pairs_classifier = IdentityPairsClassifier()