Skip to content

Commit 7af910f

Browse files
RobinVogelterrytangyuan
authored andcommitted
More systematic checks that an estimator was fit before using its parameters (#267)
* maj * added fit checks * maj * Added checks that the function was fitted. * check the input before if model is fitted * made more sensible checks. * added a test for a threshold * added a test for the unset threshold
1 parent 710379e commit 7af910f

File tree

2 files changed

+27
-4
lines changed

2 files changed

+27
-4
lines changed

metric_learn/base_metric.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -240,12 +240,14 @@ def transform(self, X):
240240
X_embedded : `numpy.ndarray`, shape=(n_samples, n_components)
241241
The embedded data points.
242242
"""
243+
check_is_fitted(self, ['preprocessor_', 'components_'])
243244
X_checked = check_input(X, type_of_inputs='classic', estimator=self,
244245
preprocessor=self.preprocessor_,
245246
accept_sparse=True)
246247
return X_checked.dot(self.components_.T)
247248

248249
def get_metric(self):
250+
check_is_fitted(self, 'components_')
249251
components_T = self.components_.T.copy()
250252

251253
def metric_fun(u, v, squared=False):
@@ -298,6 +300,7 @@ def get_mahalanobis_matrix(self):
298300
M : `numpy.ndarray`, shape=(n_features, n_features)
299301
The copy of the learned Mahalanobis matrix.
300302
"""
303+
check_is_fitted(self, 'components_')
301304
return self.components_.T.dot(self.components_)
302305

303306

@@ -333,7 +336,10 @@ def predict(self, pairs):
333336
y_predicted : `numpy.ndarray` of floats, shape=(n_constraints,)
334337
The predicted learned metric value between samples in every pair.
335338
"""
336-
check_is_fitted(self, ['threshold_', 'components_'])
339+
if "threshold_" not in vars(self):
340+
msg = ("A threshold for this estimator has not been set,"
341+
"call its set_threshold or calibrate_threshold method.")
342+
raise AttributeError(msg)
337343
return 2 * (- self.decision_function(pairs) <= self.threshold_) - 1
338344

339345
def decision_function(self, pairs):
@@ -357,6 +363,7 @@ def decision_function(self, pairs):
357363
y_predicted : `numpy.ndarray` of floats, shape=(n_constraints,)
358364
The predicted decision function value for each pair.
359365
"""
366+
check_is_fitted(self, 'preprocessor_')
360367
pairs = check_input(pairs, type_of_inputs='tuples',
361368
preprocessor=self.preprocessor_,
362369
estimator=self, tuple_size=self._tuple_size)
@@ -599,7 +606,7 @@ def predict(self, quadruplets):
599606
prediction : `numpy.ndarray` of floats, shape=(n_constraints,)
600607
Predictions of the ordering of pairs, for each quadruplet.
601608
"""
602-
check_is_fitted(self, 'components_')
609+
check_is_fitted(self, 'preprocessor_')
603610
quadruplets = check_input(quadruplets, type_of_inputs='tuples',
604611
preprocessor=self.preprocessor_,
605612
estimator=self, tuple_size=self._tuple_size)
@@ -628,6 +635,7 @@ def decision_function(self, quadruplets):
628635
decision_function : `numpy.ndarray` of floats, shape=(n_constraints,)
629636
Metric differences.
630637
"""
638+
check_is_fitted(self, 'preprocessor_')
631639
quadruplets = check_input(quadruplets, type_of_inputs='tuples',
632640
preprocessor=self.preprocessor_,
633641
estimator=self, tuple_size=self._tuple_size)

test/test_pairs_classifiers.py

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,7 @@ def test_raise_not_fitted_error_if_not_fitted(estimator, build_dataset,
7373
estimator.set_params(preprocessor=preprocessor)
7474
set_random_state(estimator)
7575
with pytest.raises(NotFittedError):
76-
estimator.predict(input_data)
76+
estimator.decision_function(input_data)
7777

7878

7979
@pytest.mark.parametrize('calibration_params',
@@ -133,10 +133,25 @@ def fit(self, pairs, y):
133133
pairs, y = self._prepare_inputs(pairs, y,
134134
type_of_inputs='tuples')
135135
self.components_ = np.atleast_2d(np.identity(pairs.shape[2]))
136-
self.threshold_ = 'I am not set.'
136+
# self.threshold_ is not set.
137137
return self
138138

139139

140+
def test_unset_threshold():
141+
# test that set_threshold indeed sets the threshold
142+
identity_pairs_classifier = IdentityPairsClassifier()
143+
pairs = np.array([[[0.], [1.]], [[1.], [3.]], [[2.], [5.]], [[3.], [7.]]])
144+
y = np.array([1, 1, -1, -1])
145+
identity_pairs_classifier.fit(pairs, y)
146+
with pytest.raises(AttributeError) as e:
147+
identity_pairs_classifier.predict(pairs)
148+
149+
expected_msg = ("A threshold for this estimator has not been set,"
150+
"call its set_threshold or calibrate_threshold method.")
151+
152+
assert str(e.value) == expected_msg
153+
154+
140155
def test_set_threshold():
141156
# test that set_threshold indeed sets the threshold
142157
identity_pairs_classifier = IdentityPairsClassifier()

0 commit comments

Comments
 (0)