Skip to content

Commit 585b5d2

Browse files
author
William de Vazelhes
committed
ENH: Add check_tuples
1 parent 810d191 commit 585b5d2

9 files changed

+136
-22
lines changed

metric_learn/_util.py

Lines changed: 39 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,4 +9,42 @@ def vector_norm(X):
99
return np.apply_along_axis(np.linalg.norm, 1, X)
1010
else:
1111
def vector_norm(X):
12-
return np.linalg.norm(X, axis=1)
12+
return np.linalg.norm(X, axis=1)
13+
14+
15+
def check_tuples(tuples):
  """Check that the input is a valid 3D array representing a dataset of tuples.

  Equivalent of `check_array` in scikit-learn.

  Only the number of dimensions is inspected (through ``tuples.shape``), so
  the input must expose a ``shape`` attribute. Inputs with three or more
  dimensions are returned unchanged.

  Parameters
  ----------
  tuples : object
    The tuples to check.

  Returns
  -------
  tuples_valid : object
    The validated input.

  Raises
  ------
  ValueError
    If the input is a scalar (0D), a 1D array or a 2D array.
  """
  n_dims = len(tuples.shape)
  # If input is scalar raise error
  if n_dims == 0:
    raise ValueError(
        "Expected 3D array, got scalar instead. Cannot apply this function on "
        "scalars.")
  # If input is 1D raise error
  if n_dims == 1:
    # `.format` must be applied to the message string, not to the exception
    # instance (which has no `format` method and would raise AttributeError).
    raise ValueError(
        "Expected 3D array, got 1D array instead:\ntuples={}.\n"
        "Reshape your data using tuples.reshape(1, -1, 1) if it contains a "
        "single tuple and the points in the tuple have a single "
        "feature.".format(tuples))
  # If input is 2D raise error
  if n_dims == 2:
    raise ValueError(
        "Expected 3D array, got 2D array instead:\ntuples={}.\n"
        "Reshape your data either using tuples.reshape(-1, {}, 1) if "
        "your data has a single feature or tuples.reshape(1, {}, -1) "
        "if it contains a single tuple.".format(tuples, tuples.shape[1],
                                                tuples.shape[0]))
  return tuples

metric_learn/base_metric.py

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
import numpy as np
66
from abc import ABCMeta, abstractmethod
77
import six
8+
from ._util import check_tuples
89

910

1011
class BaseMetricLearner(BaseEstimator):
@@ -86,7 +87,8 @@ def score_pairs(self, pairs):
8687
scores: `numpy.ndarray` of shape=(n_pairs,)
8788
The learned Mahalanobis distance for every pair.
8889
"""
89-
pairwise_diffs = self.transform(pairs[..., 1, :] - pairs[..., 0, :])
90+
pairs = check_tuples(pairs)
91+
pairwise_diffs = self.transform(pairs[:, 1, :] - pairs[:, 0, :])
9092
# (for MahalanobisMixin, the embedding is linear so we can just embed the
9193
# difference)
9294
return np.sqrt(np.sum(pairwise_diffs**2, axis=-1))
@@ -108,7 +110,7 @@ def transform(self, X):
108110
X_embedded : `numpy.ndarray`, shape=(n_samples, num_dims)
109111
The embedded data points.
110112
"""
111-
X_checked = check_array(X, accept_sparse=True, ensure_2d=False)
113+
X_checked = check_array(X, accept_sparse=True)
112114
return X_checked.dot(self.transformer_.T)
113115

114116
def metric(self):
@@ -159,9 +161,11 @@ def predict(self, pairs):
159161
y_predicted : `numpy.ndarray` of floats, shape=(n_constraints,)
160162
The predicted learned metric value between samples in every pair.
161163
"""
164+
pairs = check_tuples(pairs)
162165
return self.score_pairs(pairs)
163166

164167
def decision_function(self, pairs):
168+
pairs = check_tuples(pairs)
165169
return self.predict(pairs)
166170

167171
def score(self, pairs, y):
@@ -187,6 +191,7 @@ def score(self, pairs, y):
187191
score : float
188192
The ``roc_auc`` score.
189193
"""
194+
pairs = check_tuples(pairs)
190195
return roc_auc_score(y, self.decision_function(pairs))
191196

192197

@@ -208,6 +213,7 @@ def predict(self, quadruplets):
208213
prediction : `numpy.ndarray` of floats, shape=(n_constraints,)
209214
Predictions of the ordering of pairs, for each quadruplet.
210215
"""
216+
quadruplets = check_tuples(quadruplets)
211217
return np.sign(self.decision_function(quadruplets))
212218

213219
def decision_function(self, quadruplets):
@@ -226,8 +232,9 @@ def decision_function(self, quadruplets):
226232
decision_function : `numpy.ndarray` of floats, shape=(n_constraints,)
227233
Metric differences.
228234
"""
229-
return (self.score_pairs(quadruplets[..., :2, :]) -
230-
self.score_pairs(quadruplets[..., 2:, :]))
235+
quadruplets = check_tuples(quadruplets)
236+
return (self.score_pairs(quadruplets[:, :2, :]) -
237+
self.score_pairs(quadruplets[:, 2:, :]))
231238

232239
def score(self, quadruplets, y=None):
233240
"""Computes score on input quadruplets
@@ -248,4 +255,5 @@ def score(self, quadruplets, y=None):
248255
score : float
249256
The quadruplets score.
250257
"""
258+
quadruplets = check_tuples(quadruplets)
251259
return - np.mean(self.predict(quadruplets))

metric_learn/itml.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
from sklearn.base import TransformerMixin
2222
from .base_metric import _PairsClassifierMixin, MahalanobisMixin
2323
from .constraints import Constraints, wrap_pairs
24-
from ._util import vector_norm
24+
from ._util import vector_norm, check_tuples
2525

2626

2727
class _BaseITML(MahalanobisMixin):
@@ -52,8 +52,11 @@ def __init__(self, gamma=1., max_iter=1000, convergence_threshold=1e-3,
5252
self.verbose = verbose
5353

5454
def _process_pairs(self, pairs, y, bounds):
55+
# for now we check_X_y and check_tuples but we should only
56+
# check_tuples_y in the future
5557
pairs, y = check_X_y(pairs, y, accept_sparse=False,
5658
ensure_2d=False, allow_nd=True)
59+
pairs = check_tuples(pairs)
5760

5861
# check to make sure that no two constrained vectors are identical
5962
pos_pairs, neg_pairs = pairs[y == 1], pairs[y == -1]

metric_learn/lsml.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
from six.moves import xrange
1414
from sklearn.base import TransformerMixin
1515
from sklearn.utils.validation import check_array, check_X_y
16+
from ._util import check_tuples
1617

1718
from .base_metric import _QuadrupletsClassifierMixin, MahalanobisMixin
1819
from .constraints import Constraints
@@ -37,8 +38,11 @@ def __init__(self, tol=1e-3, max_iter=1000, prior=None, verbose=False):
3738
self.verbose = verbose
3839

3940
def _prepare_quadruplets(self, quadruplets, weights):
40-
pairs = check_array(quadruplets, accept_sparse=False,
41-
ensure_2d=False, allow_nd=True)
41+
# for now we check_array and check_tuples but we should only
42+
# check_tuples in the future (with enhanced check_tuples)
43+
quadruplets = check_array(quadruplets, accept_sparse=False,
44+
ensure_2d=False, allow_nd=True)
45+
quadruplets = check_tuples(quadruplets)
4246

4347
# check to make sure that no two constrained vectors are identical
4448
self.vab_ = quadruplets[:, 0, :] - quadruplets[:, 1, :]
@@ -51,7 +55,8 @@ def _prepare_quadruplets(self, quadruplets, weights):
5155
self.w_ = weights
5256
self.w_ /= self.w_.sum() # weights must sum to 1
5357
if self.prior is None:
54-
X = np.vstack({tuple(row) for row in pairs.reshape(-1, pairs.shape[2])})
58+
X = np.vstack({tuple(row) for row in
59+
quadruplets.reshape(-1, quadruplets.shape[2])})
5560
self.prior_inv_ = np.atleast_2d(np.cov(X, rowvar=False))
5661
self.M_ = np.linalg.inv(self.prior_inv_)
5762
else:

metric_learn/mmc.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@
2424

2525
from .base_metric import _PairsClassifierMixin, MahalanobisMixin
2626
from .constraints import Constraints, wrap_pairs
27-
from ._util import vector_norm
27+
from ._util import vector_norm, check_tuples
2828

2929

3030
class _BaseMMC(MahalanobisMixin):
@@ -65,8 +65,11 @@ def _fit(self, pairs, y):
6565
return self._fit_full(pairs, y)
6666

6767
def _process_pairs(self, pairs, y):
68+
# for now we check_X_y and check_tuples but we should only
69+
# check_tuples_y in the future
6870
pairs, y = check_X_y(pairs, y, accept_sparse=False,
69-
ensure_2d=False, allow_nd=True)
71+
ensure_2d=False, allow_nd=True)
72+
pairs = check_tuples(pairs)
7073

7174
# check to make sure that no two constrained vectors are identical
7275
pos_pairs, neg_pairs = pairs[y == 1], pairs[y == -1]

metric_learn/sdml.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717

1818
from .base_metric import MahalanobisMixin, _PairsClassifierMixin
1919
from .constraints import Constraints, wrap_pairs
20+
from ._util import check_tuples
2021

2122

2223
class _BaseSDML(MahalanobisMixin):
@@ -43,8 +44,12 @@ def __init__(self, balance_param=0.5, sparsity_param=0.01, use_cov=True,
4344
self.verbose = verbose
4445

4546
def _prepare_pairs(self, pairs, y):
47+
# for now we check_X_y and check_tuples but we should only
48+
# check_tuples_y in the future
4649
pairs, y = check_X_y(pairs, y, accept_sparse=False,
47-
ensure_2d=False, allow_nd=True)
50+
ensure_2d=False, allow_nd=True)
51+
pairs = check_tuples(pairs)
52+
4853
# set up prior M
4954
if self.use_cov:
5055
X = np.vstack({tuple(row) for row in pairs.reshape(-1, pairs.shape[2])})

test/test_mahalanobis_mixin.py

Lines changed: 20 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -128,14 +128,21 @@ def test_score_pairs_finite(estimator, build_dataset):
128128
ids=ids_estimators)
129129
def test_score_pairs_dim(estimator, build_dataset):
130130
# scoring of 3D arrays should return 1D array (several tuples),
131-
# and scoring of 2D arrays (one tuple) should return a scalar (0D array).
131+
# and scoring of 2D arrays (one tuple) should return an error (like
132+
# scikit-learn's error when scoring 1D arrays)
132133
inputs, labels = build_dataset()
133134
model = clone(estimator)
134135
model.fit(inputs, labels)
135136
X, _ = load_iris(return_X_y=True)
136137
tuples = np.array(list(product(X, X)))
137138
assert model.score_pairs(tuples).shape == (tuples.shape[0],)
138-
assert np.isscalar(model.score_pairs(tuples[1]))
139+
msg = ("Expected 3D array, got 2D array instead:\ntuples={}.\n"
140+
"Reshape your data either using tuples.reshape(-1, {}, 1) if "
141+
"your data has a single feature or tuples.reshape(1, {}, -1) "
142+
"if it contains a single tuple.".format(tuples, tuples.shape[1],
143+
tuples.shape[0]))
144+
with pytest.raises(ValueError, message=msg):
145+
model.score_pairs(tuples[1])
139146

140147

141148
def check_is_distance_matrix(pairwise):
@@ -174,13 +181,22 @@ def test_embed_dim(estimator, build_dataset):
174181
model.fit(inputs, labels)
175182
X, _ = load_iris(return_X_y=True)
176183
assert model.transform(X).shape == X.shape
177-
assert model.transform(X[0, :]).shape == (len(X[0]),)
184+
185+
# assert that ValueError is thrown if input shape is 1D
186+
err_msg = ("Expected 2D array, got 1D array instead:\narray={}.\n"
187+
"Reshape your data either using array.reshape(-1, 1) if "
188+
"your data has a single feature or array.reshape(1, -1) "
189+
"if it contains a single sample.".format(X))
190+
with pytest.raises(ValueError, message=err_msg):
191+
model.score_pairs(model.transform(X[0, :]))
178192
# we test that the shape is also OK when doing dimensionality reduction
179193
if type(model).__name__ in {'LFDA', 'MLKR', 'NCA', 'RCA'}:
180194
model.set_params(num_dims=2)
181195
model.fit(inputs, labels)
182196
assert model.transform(X).shape == (X.shape[0], 2)
183-
assert model.transform(X[0, :]).shape == (2,)
197+
# assert that ValueError is thrown if input shape is 1D
198+
with pytest.raises(ValueError, message=err_msg):
199+
model.transform(model.transform(X[0, :]))
184200

185201

186202
@pytest.mark.parametrize('estimator, build_dataset', list_estimators,

test/test_utils.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
import numpy as np
2+
import pytest
3+
from metric_learn._util import check_tuples
4+
5+
6+
def test_check_tuples():
  """Check that `check_tuples` accepts 3D input and rejects 0D/1D/2D input.

  The text of each raised ValueError is compared explicitly: the ``message``
  argument of ``pytest.raises`` does not check the exception text, so relying
  on it would leave the error messages unverified.
  """
  # A 3D array is valid and must not raise
  X = np.array([[[1, 2], [3, 4]], [[5, 6], [7, 8]]])
  check_tuples(X)

  # Scalar input
  X = np.array(5)
  msg = ("Expected 3D array, got scalar instead. Cannot apply this function "
         "on scalars.")
  with pytest.raises(ValueError) as raised_error:
    check_tuples(X)
  assert str(raised_error.value) == msg

  # 1D input; the expected text is built from X itself so that it does not
  # depend on how numpy prints arrays
  X = np.array([1, 2, 3])
  msg = ("Expected 3D array, got 1D array instead:\ntuples={}.\n"
         "Reshape your data using tuples.reshape(1, -1, 1) if it contains a "
         "single tuple and the points in the tuple have a single "
         "feature.".format(X))
  with pytest.raises(ValueError) as raised_error:
    check_tuples(X)
  assert str(raised_error.value) == msg

  # 2D input of shape (2, 3)
  X = np.array([[1, 2, 3], [2, 3, 5]])
  msg = ("Expected 3D array, got 2D array instead:\ntuples={}.\n"
         "Reshape your data either using tuples.reshape(-1, 3, 1) if "
         "your data has a single feature or tuples.reshape(1, 2, -1) "
         "if it contains a single tuple.".format(X))
  with pytest.raises(ValueError) as raised_error:
    check_tuples(X)
  assert str(raised_error.value) == msg

test/test_weakly_supervised.py

Lines changed: 13 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -169,17 +169,24 @@ def test_dict_unchanged(estimator, build_dataset):
169169
(tuples, y, tuples_train, tuples_test,
170170
y_train, y_test) = build_dataset()
171171
estimator = clone(estimator)
172-
if hasattr(estimator, "n_components"):
173-
estimator.n_components = 1
172+
if hasattr(estimator, "num_dims"):
173+
estimator.num_dims = 1
174174
estimator.fit(tuples, y)
175-
for method in ["predict", "transform", "decision_function",
176-
"predict_proba"]:
175+
for method in ["predict", "decision_function", "predict_proba"]:
177176
if hasattr(estimator, method):
178177
dict_before = estimator.__dict__.copy()
179178
getattr(estimator, method)(tuples)
180179
assert estimator.__dict__ == dict_before, \
181180
("Estimator changes __dict__ during %s"
182181
% method)
182+
for method in ["transform"]:
183+
if hasattr(estimator, method):
184+
dict_before = estimator.__dict__.copy()
185+
# we transform only 2D arrays (dataset of points)
186+
getattr(estimator, method)(tuples[:, 0, :])
187+
assert estimator.__dict__ == dict_before, \
188+
("Estimator changes __dict__ during %s"
189+
% method)
183190

184191

185192
@pytest.mark.parametrize('estimator, build_dataset', list_estimators,
@@ -190,8 +197,8 @@ def test_dont_overwrite_parameters(estimator, build_dataset):
190197
(tuples, y, tuples_train, tuples_test,
191198
y_train, y_test) = build_dataset()
192199
estimator = clone(estimator)
193-
if hasattr(estimator, "n_components"):
194-
estimator.n_components = 1
200+
if hasattr(estimator, "num_dims"):
201+
estimator.num_dims = 1
195202
dict_before_fit = estimator.__dict__.copy()
196203

197204
estimator.fit(tuples, y)

0 commit comments

Comments
 (0)