
Commit b741a9e

Author: William de Vazelhes (committed)
FIX: corrections according to reviews #95 (review) and #95 (review)
- replace similarity by metric
- replace constrained dataset by pairs/quadruplets
- simplify score on quadruplets expression
- replace ``X_constrained`` in tests by pairs/quadruplets/tuples
1 parent a70d1a8 commit b741a9e

File tree: 2 files changed (+40 / -41 lines)


metric_learn/base_metric.py

Lines changed: 6 additions & 7 deletions
@@ -59,15 +59,15 @@ def transform(self, X=None):
 class _PairsClassifierMixin:
 
   def predict(self, pairs):
-    """Predicts the learned similarity between input pairs.
+    """Predicts the learned metric between input pairs.
 
     Returns the learned metric value between samples in every pair. It should
     ideally be low for similar samples and high for dissimilar samples.
 
     Parameters
     ----------
     pairs : array-like, shape=(n_constraints, 2, n_features)
-      A constrained dataset of paired samples.
+      Input pairs.
 
     Returns
     -------
@@ -110,7 +110,7 @@ def score(self, pairs, y):
 class _QuadrupletsClassifierMixin:
 
   def predict(self, quadruplets):
-    """Predicts differences between sample similarities in input quadruplets.
+    """Predicts differences between sample distances in input quadruplets.
 
     For each quadruplet of samples, computes the difference between the learned
     metric of the first pair minus the learned metric of the second pair.
@@ -122,7 +122,7 @@ def predict(self, quadruplets):
 
     Returns
     -------
-    prediction : np.ndarray of floats, shape=(n_constraints,)
+    prediction : `numpy.ndarray` of floats, shape=(n_constraints,)
       Metric differences.
     """
     similar_diffs = quadruplets[:, 0, :] - quadruplets[:, 1, :]
@@ -136,7 +136,7 @@ def decision_function(self, quadruplets):
     return self.predict(quadruplets)
 
   def score(self, quadruplets, y=None):
-    """Computes score on an input constrained dataset
+    """Computes score on input quadruplets
 
     Returns the accuracy score of the following classification task: a record
     is correctly classified if the predicted similarity between the first two
@@ -154,5 +154,4 @@ def score(self, quadruplets, y=None):
     score : float
       The quadruplets score.
     """
-    predicted_sign = self.decision_function(quadruplets) < 0
-    return np.sum(predicted_sign) / predicted_sign.shape[0]
+    return - np.mean(np.sign(self.decision_function(quadruplets)))
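For reference, here is a minimal standalone sketch (not part of the diff) of what the old and new score expressions compute. Per the docstring above, decision_function returns, for each quadruplet, the learned metric of the first pair minus that of the second pair, so a negative value is a correct prediction; the toy decision values below are hypothetical:

import numpy as np

# Hypothetical decision values for four quadruplets; three are negative,
# i.e. the first pair is predicted closer than the second pair.
decision = np.array([-0.5, -1.2, 0.3, -0.1])

# New expression from this commit: negated mean of the signs.
score_new = -np.mean(np.sign(decision))                        # 0.5, lies in [-1, 1]

# Previous expression: fraction of strictly negative decisions.
predicted_sign = decision < 0
score_old = np.sum(predicted_sign) / predicted_sign.shape[0]   # 0.75, lies in [0, 1]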

test/test_weakly_supervised.py

Lines changed: 34 additions & 34 deletions
@@ -26,27 +26,27 @@ def build_data():
 
 
 def build_pairs():
-  # test that you can do cross validation on a ConstrainedDataset with
+  # test that you can do cross validation on tuples of points with
   # a WeaklySupervisedMetricLearner
   X, pairs = build_data()
-  X_constrained, y = wrap_pairs(X, pairs)
-  X_constrained, y = shuffle(X_constrained, y)
-  (X_constrained_train, X_constrained_test, y_train,
-   y_test) = train_test_split(X_constrained, y)
-  return (X_constrained, y, X_constrained_train, X_constrained_test,
+  pairs, y = wrap_pairs(X, pairs)
+  pairs, y = shuffle(pairs, y)
+  (pairs_train, pairs_test, y_train,
+   y_test) = train_test_split(pairs, y)
+  return (pairs, y, pairs_train, pairs_test,
           y_train, y_test)
 
 
 def build_quadruplets():
-  # test that you can do cross validation on a ConstrainedDataset with
+  # test that you can do cross validation on a tuples of points with
   # a WeaklySupervisedMetricLearner
   X, pairs = build_data()
   c = np.column_stack(pairs)
-  X_constrained = X[c]
-  X_constrained = shuffle(X_constrained)
+  quadruplets = X[c]
+  quadruplets = shuffle(quadruplets)
   y = y_train = y_test = None
-  X_constrained_train, X_constrained_test = train_test_split(X_constrained)
-  return (X_constrained, y, X_constrained_train, X_constrained_test,
+  quadruplets_train, quadruplets_test = train_test_split(quadruplets)
+  return (quadruplets, y, quadruplets_train, quadruplets_test,
           y_train, y_test)
 
 
@@ -66,35 +66,35 @@ def build_quadruplets():
 @pytest.mark.parametrize('estimator, build_dataset', list_estimators,
                          ids=ids_estimators)
 def test_cross_validation(estimator, build_dataset):
-  (X_constrained, y, X_constrained_train, X_constrained_test,
+  (tuples, y, tuples_train, tuples_test,
    y_train, y_test) = build_dataset()
   estimator = clone(estimator)
   set_random_state(estimator)
 
-  assert np.isfinite(cross_val_score(estimator, X_constrained, y)).all()
+  assert np.isfinite(cross_val_score(estimator, tuples, y)).all()
 
 
-def check_score(estimator, X_constrained, y):
-  score = estimator.score(X_constrained, y)
+def check_score(estimator, tuples, y):
+  score = estimator.score(tuples, y)
   assert np.isfinite(score)
 
 
-def check_predict(estimator, X_constrained):
-  y_predicted = estimator.predict(X_constrained)
-  assert len(y_predicted), len(X_constrained)
+def check_predict(estimator, tuples):
+  y_predicted = estimator.predict(tuples)
+  assert len(y_predicted), len(tuples)
 
 
 @pytest.mark.parametrize('estimator, build_dataset', list_estimators,
                          ids=ids_estimators)
 def test_simple_estimator(estimator, build_dataset):
-  (X_constrained, y, X_constrained_train, X_constrained_test,
+  (tuples, y, tuples_train, tuples_test,
    y_train, y_test) = build_dataset()
   estimator = clone(estimator)
   set_random_state(estimator)
 
-  estimator.fit(X_constrained_train, y_train)
-  check_score(estimator, X_constrained_test, y_test)
-  check_predict(estimator, X_constrained_test)
+  estimator.fit(tuples_train, y_train)
+  check_score(estimator, tuples_test, y_test)
+  check_predict(estimator, tuples_test)
 
 
 @pytest.mark.parametrize('estimator', [est[0] for est in list_estimators],
@@ -122,50 +122,50 @@ def test_no_fit_attributes_set_in_init(estimator):
 def test_estimators_fit_returns_self(estimator, build_dataset):
   """Check if self is returned when calling fit"""
   # From scikit-learn
-  (X_constrained, y, X_constrained_train, X_constrained_test,
+  (tuples, y, tuples_train, tuples_test,
    y_train, y_test) = build_dataset()
   estimator = clone(estimator)
-  assert estimator.fit(X_constrained, y) is estimator
+  assert estimator.fit(tuples, y) is estimator
 
 
 @pytest.mark.parametrize('estimator, build_dataset', list_estimators,
                          ids=ids_estimators)
 def test_pipeline_consistency(estimator, build_dataset):
   # From scikit learn
   # check that make_pipeline(est) gives same score as est
-  (X_constrained, y, X_constrained_train, X_constrained_test,
+  (tuples, y, tuples_train, tuples_test,
    y_train, y_test) = build_dataset()
   estimator = clone(estimator)
   pipeline = make_pipeline(estimator)
-  estimator.fit(X_constrained, y)
-  pipeline.fit(X_constrained, y)
+  estimator.fit(tuples, y)
+  pipeline.fit(tuples, y)
 
   funcs = ["score", "fit_transform"]
 
   for func_name in funcs:
     func = getattr(estimator, func_name, None)
     if func is not None:
       func_pipeline = getattr(pipeline, func_name)
-      result = func(X_constrained, y)
-      result_pipe = func_pipeline(X_constrained, y)
+      result = func(tuples, y)
+      result_pipe = func_pipeline(tuples, y)
       assert_allclose_dense_sparse(result, result_pipe)
 
 
 @pytest.mark.parametrize('estimator, build_dataset', list_estimators,
                          ids=ids_estimators)
 def test_dict_unchanged(estimator, build_dataset):
   # From scikit-learn
-  (X_constrained, y, X_constrained_train, X_constrained_test,
+  (tuples, y, tuples_train, tuples_test,
    y_train, y_test) = build_dataset()
   estimator = clone(estimator)
   if hasattr(estimator, "n_components"):
     estimator.n_components = 1
-  estimator.fit(X_constrained, y)
+  estimator.fit(tuples, y)
   for method in ["predict", "transform", "decision_function",
                  "predict_proba"]:
     if hasattr(estimator, method):
       dict_before = estimator.__dict__.copy()
-      getattr(estimator, method)(X_constrained)
+      getattr(estimator, method)(tuples)
       assert estimator.__dict__ == dict_before, \
         ("Estimator changes __dict__ during %s"
          % method)
@@ -176,14 +176,14 @@ def test_dict_unchanged(estimator, build_dataset):
 def test_dont_overwrite_parameters(estimator, build_dataset):
   # From scikit-learn
   # check that fit method only changes or sets private attributes
-  (X_constrained, y, X_constrained_train, X_constrained_test,
+  (tuples, y, tuples_train, tuples_test,
    y_train, y_test) = build_dataset()
   estimator = clone(estimator)
   if hasattr(estimator, "n_components"):
     estimator.n_components = 1
   dict_before_fit = estimator.__dict__.copy()
 
-  estimator.fit(X_constrained, y)
+  estimator.fit(tuples, y)
   dict_after_fit = estimator.__dict__
 
   public_keys_after_fit = [key for key in dict_after_fit.keys()
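As a rough illustration of the tuples format these tests now build (a sketch only, mirroring the X[c] indexing above; the shapes follow the predict docstrings, while the +1/-1 label convention for pairs is an assumption here):

import numpy as np

# Toy data: 10 points in 3 dimensions.
X = np.random.RandomState(0).randn(10, 3)

# Pairs materialized as points, shape (n_constraints, 2, n_features).
pair_indices = np.array([[0, 1], [2, 3], [4, 5]])
pairs = X[pair_indices]            # shape (3, 2, 3)
y_pairs = np.array([1, -1, 1])     # assumed convention: +1 similar, -1 dissimilar

# Quadruplets, shape (n_constraints, 4, n_features): the first two points of each
# quadruplet should end up closer than the last two under the learned metric.
quad_indices = np.array([[0, 1, 2, 3], [4, 5, 6, 7]])
quadruplets = X[quad_indices]      # shape (2, 4, 3)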
