Skip to content

Commit 8f57db2

Browse files
wdevazelhesbellet
authored and committed
Add checks for labels when having pairs (#197)
1 parent d945df1 commit 8f57db2

File tree

2 files changed

+84
-1
lines changed

2 files changed

+84
-1
lines changed

metric_learn/_util.py

Lines changed: 12 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -137,6 +137,11 @@ def check_input(input_data, y=None, preprocessor=None,
137137
input_data = check_input_tuples(input_data, context, preprocessor,
138138
args_for_sk_checks, tuple_size)
139139

140+
# if we have y and the input data are pairs, we need to ensure
141+
# the labels are in [-1, 1]:
142+
if y is not None and input_data.shape[1] == 2:
143+
check_y_valid_values_for_pairs(y)
144+
140145
else:
141146
raise ValueError("Unknown value {} for type_of_inputs. Valid values are "
142147
"'classic' or 'tuples'.".format(type_of_inputs))
@@ -297,6 +302,13 @@ def check_tuple_size(tuples, tuple_size, context):
297302
raise ValueError(msg_t)
298303

299304

305+
def check_y_valid_values_for_pairs(y):
  """Check that every pair label in ``y`` is either -1 or +1.

  The check compares ``np.abs(y)`` with an array of ones of the same shape,
  so a label is accepted only when its absolute value is exactly 1. The
  "[-1, 1]" in the error message therefore denotes the two allowed values,
  not the closed interval between them.

  Parameters
  ----------
  y : array-like
    Labels attached to the pairs.

  Raises
  ------
  ValueError
    If at least one label has an absolute value different from 1.
  """
  if not np.array_equal(np.abs(y), np.ones_like(y)):
    raise ValueError("When training on pairs, the labels (y) should contain "
                     "only values in [-1, 1]. Found an incorrect value.")
310+
311+
300312
class ArrayIndexer:
301313

302314
def __init__(self, X):

test/test_utils.py

Lines changed: 72 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -10,7 +10,8 @@
1010
from metric_learn._util import (check_input, make_context, preprocess_tuples,
1111
make_name, preprocess_points,
1212
check_collapsed_pairs, validate_vector,
13-
_check_sdp_from_eigen)
13+
_check_sdp_from_eigen,
14+
check_y_valid_values_for_pairs)
1415
from metric_learn import (ITML, LSML, MMC, RCA, SDML, Covariance, LFDA,
1516
LMNN, MLKR, NCA, ITML_Supervised, LSML_Supervised,
1617
MMC_Supervised, RCA_Supervised, SDML_Supervised,
@@ -1067,3 +1068,73 @@ def test_check_sdp_from_eigen_positive_err_messages():
10671068
_check_sdp_from_eigen(w, 1.)
10681069
_check_sdp_from_eigen(w, 0.)
10691070
_check_sdp_from_eigen(w, None)
1071+
1072+
1073+
@pytest.mark.unit
@pytest.mark.parametrize('wrong_labels',
                         [[0.5, 0.6, 0.7, 0.8, 0.9],
                          np.random.RandomState(42).randn(5),
                          np.random.RandomState(42).choice([0, 1], size=5)])
def test_check_y_valid_values_for_pairs(wrong_labels):
  """Checks that labels other than -1 / +1 raise the expected ValueError.

  Each parametrized case (fractional values, gaussian draws, 0/1 labels) is
  passed directly to `check_y_valid_values_for_pairs`, and the full error
  message is compared verbatim.
  """
  expected_msg = ("When training on pairs, the labels (y) should contain "
                  "only values in [-1, 1]. Found an incorrect value.")
  with pytest.raises(ValueError) as raised_error:
    check_y_valid_values_for_pairs(wrong_labels)
  assert str(raised_error.value) == expected_msg
1084+
1085+
1086+
@pytest.mark.integration
@pytest.mark.parametrize('wrong_labels',
                         [[0.5, 0.6, 0.7, 0.8, 0.9],
                          np.random.RandomState(42).randn(5),
                          np.random.RandomState(42).choice([0, 1], size=5)])
def test_check_input_invalid_tuples_without_preprocessor(wrong_labels):
  """Checks that `check_input` rejects invalid pair labels (no preprocessor).

  The pairs are given as a random array of shape (5, 2, 3) and `check_input`
  is called with ``preprocessor=None`` and ``type_of_inputs='tuples'``; the
  full error message is compared verbatim.
  """
  pairs = np.random.RandomState(42).randn(5, 2, 3)
  expected_msg = ("When training on pairs, the labels (y) should contain "
                  "only values in [-1, 1]. Found an incorrect value.")
  with pytest.raises(ValueError) as raised_error:
    check_input(pairs, wrong_labels, preprocessor=None,
                type_of_inputs='tuples')
  assert str(raised_error.value) == expected_msg
1099+
1100+
1101+
@pytest.mark.integration
@pytest.mark.parametrize('wrong_labels',
                         [[0.5, 0.6, 0.7, 0.8, 0.9],
                          np.random.RandomState(42).randn(5),
                          np.random.RandomState(42).choice([0, 1], size=5)])
def test_check_input_invalid_tuples_with_preprocessor(wrong_labels):
  """Checks that `check_input` rejects invalid pair labels (preprocessor).

  The pairs are given as a (5, 2) array of integers below 10, and a random
  (10, 4) array wrapped in `ArrayIndexer` is passed as the preprocessor; the
  full error message is compared verbatim.
  """
  n_samples, n_features, n_pairs = 10, 4, 5
  rng = np.random.RandomState(42)
  # presumably each integer is a row index into the preprocessor array
  pairs = rng.randint(10, size=(n_pairs, 2))
  preprocessor = rng.randn(n_samples, n_features)
  expected_msg = ("When training on pairs, the labels (y) should contain "
                  "only values in [-1, 1]. Found an incorrect value.")
  with pytest.raises(ValueError) as raised_error:
    check_input(pairs, wrong_labels, preprocessor=ArrayIndexer(preprocessor),
                type_of_inputs='tuples')
  assert str(raised_error.value) == expected_msg
1117+
1118+
1119+
@pytest.mark.integration
@pytest.mark.parametrize('with_preprocessor', [True, False])
@pytest.mark.parametrize('estimator, build_dataset', pairs_learners,
                         ids=ids_pairs_learners)
def test_check_input_pairs_learners_invalid_y(estimator, build_dataset,
                                              with_preprocessor):
  """Checks that the only allowed labels for learning pairs are +1 and -1.

  For each pairs learner, a clone of the estimator is fitted on the dataset
  with three kinds of invalid labels (shifted by 0.5, gaussian draws, 0/1
  labels); every fit must raise a ValueError with the exact expected message.
  """
  # NOTE(review): `with_preprocessor` is parametrized but never used below,
  # so both parametrizations run the same scenario -- confirm whether it
  # should be forwarded to `build_dataset`.
  input_data, labels, _, X = build_dataset()
  # `labels + 0.5` assumes `labels` supports elementwise addition
  # (e.g. a numpy array) -- TODO confirm against `build_dataset`.
  wrong_labels_list = [labels + 0.5,
                       np.random.RandomState(42).randn(len(labels)),
                       np.random.RandomState(42).choice([0, 1],
                                                        size=len(labels))]
  model = clone(estimator)
  set_random_state(model)

  expected_msg = ("When training on pairs, the labels (y) should contain "
                  "only values in [-1, 1]. Found an incorrect value.")

  for wrong_labels in wrong_labels_list:
    with pytest.raises(ValueError) as raised_error:
      model.fit(input_data, wrong_labels)
    assert str(raised_error.value) == expected_msg

0 commit comments

Comments
 (0)