diff --git a/metric_learn/base_metric.py b/metric_learn/base_metric.py index e7dbd608..9064c100 100644 --- a/metric_learn/base_metric.py +++ b/metric_learn/base_metric.py @@ -569,8 +569,14 @@ def set_threshold(self, threshold): The pairs classifier with the new threshold set. """ check_is_fitted(self, 'preprocessor_') - - self.threshold_ = threshold + try: + self.threshold_ = float(threshold) + except TypeError: + raise ValueError('Parameter threshold must be a real number. ' + 'Got {} instead.'.format(type(threshold))) + except ValueError: + raise ValueError('Parameter threshold must be a real number. ' + 'Got {} instead.'.format(type(threshold))) return self def calibrate_threshold(self, pairs_valid, y_valid, strategy='accuracy', diff --git a/test/test_pairs_classifiers.py b/test/test_pairs_classifiers.py index 714cbd08..6a725f23 100644 --- a/test/test_pairs_classifiers.py +++ b/test/test_pairs_classifiers.py @@ -180,6 +180,25 @@ def test_set_threshold(): assert identity_pairs_classifier.threshold_ == 0.5 +@pytest.mark.parametrize('value', ["ABC", None, [1, 2, 3], {'key': None}, + (1, 2), set(), + np.array([[[0.], [1.]], [[1.], [3.]]])]) +def test_set_wrong_type_threshold(value): + """ + Test that `set_threshold` indeed sets the threshold + and cannot accept nothing but float or integers, but + being permissive with boolean True=1.0 and False=0.0 + """ + model = IdentityPairsClassifier() + model.fit(np.array([[[0.], [1.]]]), np.array([1])) + msg = ('Parameter threshold must be a real number. ' + 'Got {} instead.'.format(type(value))) + + with pytest.raises(ValueError) as e: # String + model.set_threshold(value) + assert str(e.value).startswith(msg) + + def test_f_beta_1_is_f_1(): # test that putting beta to 1 indeed finds the best threshold to optimize # the f1_score