|
10 | 10 | from metric_learn._util import (check_input, make_context, preprocess_tuples,
|
11 | 11 | make_name, preprocess_points,
|
12 | 12 | check_collapsed_pairs, validate_vector,
|
13 |
| - _check_sdp_from_eigen) |
| 13 | + _check_sdp_from_eigen, |
| 14 | + check_y_valid_values_for_pairs) |
14 | 15 | from metric_learn import (ITML, LSML, MMC, RCA, SDML, Covariance, LFDA,
|
15 | 16 | LMNN, MLKR, NCA, ITML_Supervised, LSML_Supervised,
|
16 | 17 | MMC_Supervised, RCA_Supervised, SDML_Supervised,
|
@@ -1067,3 +1068,73 @@ def test_check_sdp_from_eigen_positive_err_messages():
|
1067 | 1068 | _check_sdp_from_eigen(w, 1.)
|
1068 | 1069 | _check_sdp_from_eigen(w, 0.)
|
1069 | 1070 | _check_sdp_from_eigen(w, None)
|
| 1071 | + |
| 1072 | + |
| 1073 | +@pytest.mark.unit |
| 1074 | +@pytest.mark.parametrize('wrong_labels', |
| 1075 | + [[0.5, 0.6, 0.7, 0.8, 0.9], |
| 1076 | + np.random.RandomState(42).randn(5), |
| 1077 | + np.random.RandomState(42).choice([0, 1], size=5)]) |
| 1078 | +def test_check_y_valid_values_for_pairs(wrong_labels): |
| 1079 | + expected_msg = ("When training on pairs, the labels (y) should contain " |
| 1080 | + "only values in [-1, 1]. Found an incorrect value.") |
| 1081 | + with pytest.raises(ValueError) as raised_error: |
| 1082 | + check_y_valid_values_for_pairs(wrong_labels) |
| 1083 | + assert str(raised_error.value) == expected_msg |
| 1084 | + |
| 1085 | + |
| 1086 | +@pytest.mark.integration |
| 1087 | +@pytest.mark.parametrize('wrong_labels', |
| 1088 | + [[0.5, 0.6, 0.7, 0.8, 0.9], |
| 1089 | + np.random.RandomState(42).randn(5), |
| 1090 | + np.random.RandomState(42).choice([0, 1], size=5)]) |
| 1091 | +def test_check_input_invalid_tuples_without_preprocessor(wrong_labels): |
| 1092 | + pairs = np.random.RandomState(42).randn(5, 2, 3) |
| 1093 | + expected_msg = ("When training on pairs, the labels (y) should contain " |
| 1094 | + "only values in [-1, 1]. Found an incorrect value.") |
| 1095 | + with pytest.raises(ValueError) as raised_error: |
| 1096 | + check_input(pairs, wrong_labels, preprocessor=None, |
| 1097 | + type_of_inputs='tuples') |
| 1098 | + assert str(raised_error.value) == expected_msg |
| 1099 | + |
| 1100 | + |
| 1101 | +@pytest.mark.integration |
| 1102 | +@pytest.mark.parametrize('wrong_labels', |
| 1103 | + [[0.5, 0.6, 0.7, 0.8, 0.9], |
| 1104 | + np.random.RandomState(42).randn(5), |
| 1105 | + np.random.RandomState(42).choice([0, 1], size=5)]) |
| 1106 | +def test_check_input_invalid_tuples_with_preprocessor(wrong_labels): |
| 1107 | + n_samples, n_features, n_pairs = 10, 4, 5 |
| 1108 | + rng = np.random.RandomState(42) |
| 1109 | + pairs = rng.randint(10, size=(n_pairs, 2)) |
| 1110 | + preprocessor = rng.randn(n_samples, n_features) |
| 1111 | + expected_msg = ("When training on pairs, the labels (y) should contain " |
| 1112 | + "only values in [-1, 1]. Found an incorrect value.") |
| 1113 | + with pytest.raises(ValueError) as raised_error: |
| 1114 | + check_input(pairs, wrong_labels, preprocessor=ArrayIndexer(preprocessor), |
| 1115 | + type_of_inputs='tuples') |
| 1116 | + assert str(raised_error.value) == expected_msg |
| 1117 | + |
| 1118 | + |
| 1119 | +@pytest.mark.integration |
| 1120 | +@pytest.mark.parametrize('with_preprocessor', [True, False]) |
| 1121 | +@pytest.mark.parametrize('estimator, build_dataset', pairs_learners, |
| 1122 | + ids=ids_pairs_learners) |
| 1123 | +def test_check_input_pairs_learners_invalid_y(estimator, build_dataset, |
| 1124 | + with_preprocessor): |
| 1125 | + """checks that the only allowed labels for learning pairs are +1 and -1""" |
| 1126 | + input_data, labels, _, X = build_dataset() |
| 1127 | + wrong_labels_list = [labels + 0.5, |
| 1128 | + np.random.RandomState(42).randn(len(labels)), |
| 1129 | + np.random.RandomState(42).choice([0, 1], |
| 1130 | + size=len(labels))] |
| 1131 | + model = clone(estimator) |
| 1132 | + set_random_state(model) |
| 1133 | + |
| 1134 | + expected_msg = ("When training on pairs, the labels (y) should contain " |
| 1135 | + "only values in [-1, 1]. Found an incorrect value.") |
| 1136 | + |
| 1137 | + for wrong_labels in wrong_labels_list: |
| 1138 | + with pytest.raises(ValueError) as raised_error: |
| 1139 | + model.fit(input_data, wrong_labels) |
| 1140 | + assert str(raised_error.value) == expected_msg |
0 commit comments