Skip to content

Commit 964f28d

Browse files
authored
Threshold must be a real number (#322)
* Add venv to gitignore * Check if threshold is a real value * Simplified threshold type-check * Follow linter rules * Fix last linter error * Add test to check correct behaviour. Sacrified simplicity for the bool case. * Update test. Stick to custom message. It's bool permissive * Explicit boolean permissive case in test * Changed isinstance for custom ValueError message * TypeError for most input. ValueError for String case.
1 parent bdfdb24 commit 964f28d

File tree

2 files changed

+27
-2
lines changed

2 files changed

+27
-2
lines changed

metric_learn/base_metric.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -569,8 +569,14 @@ def set_threshold(self, threshold):
569569
The pairs classifier with the new threshold set.
570570
"""
571571
check_is_fitted(self, 'preprocessor_')
572-
573-
self.threshold_ = threshold
572+
try:
573+
self.threshold_ = float(threshold)
574+
except TypeError:
575+
raise ValueError('Parameter threshold must be a real number. '
576+
'Got {} instead.'.format(type(threshold)))
577+
except ValueError:
578+
raise ValueError('Parameter threshold must be a real number. '
579+
'Got {} instead.'.format(type(threshold)))
574580
return self
575581

576582
def calibrate_threshold(self, pairs_valid, y_valid, strategy='accuracy',

test/test_pairs_classifiers.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -180,6 +180,25 @@ def test_set_threshold():
180180
assert identity_pairs_classifier.threshold_ == 0.5
181181

182182

183+
@pytest.mark.parametrize('value', ["ABC", None, [1, 2, 3], {'key': None},
184+
(1, 2), set(),
185+
np.array([[[0.], [1.]], [[1.], [3.]]])])
186+
def test_set_wrong_type_threshold(value):
187+
"""
188+
Test that `set_threshold` indeed sets the threshold
189+
and cannot accept nothing but float or integers, but
190+
being permissive with boolean True=1.0 and False=0.0
191+
"""
192+
model = IdentityPairsClassifier()
193+
model.fit(np.array([[[0.], [1.]]]), np.array([1]))
194+
msg = ('Parameter threshold must be a real number. '
195+
'Got {} instead.'.format(type(value)))
196+
197+
with pytest.raises(ValueError) as e: # String
198+
model.set_threshold(value)
199+
assert str(e.value).startswith(msg)
200+
201+
183202
def test_f_beta_1_is_f_1():
184203
# test that putting beta to 1 indeed finds the best threshold to optimize
185204
# the f1_score

0 commit comments

Comments
 (0)