Skip to content

Commit 580d38d

Browse files
wdevazelhes authored and bellet committed
[MRG] fix quadruplets decision_function (#217)
* fix quadruplets decision_function
* Address #217 (comment)
* fix: I put the column at the wrong side, now it does some subsampling
* Fix number of samples
* let's try again with 30 samples
* Use less chunks
1 parent 8c3cb3e commit 580d38d

File tree

3 files changed

+68
-1
lines changed

3 files changed

+68
-1
lines changed

metric_learn/base_metric.py

Lines changed: 3 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -618,6 +618,9 @@ def decision_function(self, quadruplets):
618618
decision_function : `numpy.ndarray` of floats, shape=(n_constraints,)
619619
Metric differences.
620620
"""
621+
quadruplets = check_input(quadruplets, type_of_inputs='tuples',
622+
preprocessor=self.preprocessor_,
623+
estimator=self, tuple_size=self._tuple_size)
621624
return (self.score_pairs(quadruplets[:, 2:]) -
622625
self.score_pairs(quadruplets[:, :2]))
623626

test/test_sklearn_compat.py

Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -105,6 +105,70 @@ def stable_init(self, n_components=None, pca_comps=None,
105105

106106
# ---------------------- Test scikit-learn compatibility ----------------------
107107

108+
def generate_array_like(input_data, labels=None):
    """Build array-like variants of a numpy dataset, for testing purposes.

    Parameters
    ----------
    input_data : numpy.ndarray
        Dataset to convert. Only its ``tolist()`` method and ``ndim``
        attribute are used.
    labels : array-like or None, optional
        Labels associated with ``input_data``. When ``None``, no label
        variant is generated.

    Returns
    -------
    input_data_changed : list
        ``input_data`` itself, followed by the same data as a nested list,
        as a tuple of the nested list's rows, as tuples nested two levels
        deep (``ndim >= 2``) and three levels deep (``ndim >= 3``), and as
        a ``pandas.DataFrame`` (``ndim == 2`` only).
    labels_changed : list
        ``[labels, list(labels), tuple(labels)]``, or ``[None]`` when
        ``labels`` is ``None``.
    """
    list_data = input_data.tolist()
    input_data_changed = [input_data, list_data, tuple(list_data)]
    if input_data.ndim >= 2:
        # Outer two levels become tuples; anything deeper stays a list.
        input_data_changed.append(tuple(tuple(x) for x in list_data))
    if input_data.ndim >= 3:
        # Outer three levels become tuples.
        input_data_changed.append(
            tuple(tuple(tuple(x) for x in y) for y in list_data))
    if input_data.ndim == 2:
        # importorskip skips the calling test when pandas is not installed.
        pd = pytest.importorskip('pandas')
        input_data_changed.append(pd.DataFrame(input_data))
    if labels is not None:
        labels_changed = [labels, list(labels), tuple(labels)]
    else:
        labels_changed = [labels]
    return input_data_changed, labels_changed
126+
127+
128+
@pytest.mark.integration
@pytest.mark.parametrize('with_preprocessor', [True, False])
@pytest.mark.parametrize('estimator, build_dataset', metric_learners,
                         ids=ids_metric_learners)
def test_array_like_inputs(estimator, build_dataset, with_preprocessor):
    """Test that metric-learners can have as input (of all functions that
    are applied on data) any array-like object.

    Every input variant produced by ``generate_array_like`` is passed to
    ``fit`` and, when the estimator defines them, to ``predict``,
    ``predict_proba``, ``decision_function`` and ``score``. Variants of
    plain points are then passed to ``transform`` and variants of pairs to
    ``score_pairs``.
    """
    # NOTE(review): block nesting below was reconstructed from a diff view
    # that dropped indentation -- confirm against the repository.
    input_data, labels, preprocessor, X = build_dataset(with_preprocessor)

    # we subsample the data for the test to be more efficient
    input_data, _, labels, _ = train_test_split(input_data, labels,
                                                train_size=20)
    X = X[:10]

    estimator = clone(estimator)
    estimator.set_params(preprocessor=preprocessor)
    set_random_state(estimator)
    input_variants, label_variants = generate_array_like(input_data, labels)
    for input_variant in input_variants:
        for label_variant in label_variants:
            estimator.fit(*remove_y_quadruplets(estimator, input_variant,
                                                label_variant))
            if hasattr(estimator, "predict"):
                estimator.predict(input_variant)
            if hasattr(estimator, "predict_proba"):
                # Anticipation in case some time we have that, or if people
                # want to contribute new algorithms: it will be checked
                # automatically.
                estimator.predict_proba(input_variant)
            if hasattr(estimator, "decision_function"):
                estimator.decision_function(input_variant)
            if hasattr(estimator, "score"):
                # Distinct name so the enclosing loop variable is not
                # shadowed while scoring against every label variant.
                for score_label_variant in label_variants:
                    estimator.score(*remove_y_quadruplets(
                        estimator, input_variant, score_label_variant))

    X_variants, _ = generate_array_like(X)
    for X_variant in X_variants:
        estimator.transform(X_variant)

    pairs = np.array([[X[0], X[1]], [X[0], X[2]]])
    pairs_variants, _ = generate_array_like(pairs)
    for pairs_variant in pairs_variants:
        estimator.score_pairs(pairs_variant)
171+
108172

109173
@pytest.mark.parametrize('with_preprocessor', [True, False])
110174
@pytest.mark.parametrize('estimator, build_dataset', pairs_learners,

test/test_utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -118,7 +118,7 @@ def build_quadruplets(with_preprocessor=False):
118118
(ITML_Supervised(max_iter=5), build_classification),
119119
(LSML_Supervised(), build_classification),
120120
(MMC_Supervised(max_iter=5), build_classification),
121-
(RCA_Supervised(num_chunks=10), build_classification),
121+
(RCA_Supervised(num_chunks=5), build_classification),
122122
(SDML_Supervised(prior='identity', balance_param=1e-5),
123123
build_classification)]
124124
ids_classifiers = list(map(lambda x: x.__class__.__name__,

0 commit comments

Comments
 (0)