Skip to content

Commit 562f33b

Browse files
author
William de Vazelhes
committed
Add checks for bounds argument
1 parent f9511a0 commit 562f33b

File tree

2 files changed

+49
-8
lines changed

2 files changed

+49
-8
lines changed

metric_learn/itml.py

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -68,9 +68,13 @@ def _fit(self, pairs, y, bounds=None):
6868
X = np.vstack({tuple(row) for row in pairs.reshape(-1, pairs.shape[2])})
6969
self.bounds_ = np.percentile(pairwise_distances(X), (5, 95))
7070
else:
71-
assert len(bounds) == 2
71+
bounds = check_array(bounds, allow_nd=False, ensure_min_samples=0,
72+
ensure_2d=False)
73+
bounds = bounds.ravel()
74+
if bounds.size != 2:
75+
raise ValueError("`bounds` should be an array-like of two elements.")
7276
self.bounds_ = bounds
73-
self.bounds_[self.bounds_==0] = 1e-9
77+
self.bounds_[self.bounds_ == 0] = 1e-9
7478
# init metric
7579
if self.A0 is None:
7680
A = np.identity(pairs.shape[2])
@@ -133,7 +137,7 @@ class ITML(_BaseITML, _PairsClassifierMixin):
133137
134138
Attributes
135139
----------
136-
bounds_ : array-like, shape=(2,)
140+
bounds_ : `numpy.ndarray`, shape=(2,)
137141
Bounds on similarity, aside slack variables, s.t.
138142
``d(a, b) < bounds_[0]`` for all given pairs of similar points ``a``
139143
and ``b``, and ``d(c, d) > bounds_[1]`` for all given pairs of
@@ -170,7 +174,7 @@ def fit(self, pairs, y, bounds=None, calibration_params=None):
170174
preprocessor.
171175
y: array-like, of shape (n_constraints,)
172176
Labels of constraints. Should be -1 for dissimilar pair, 1 for similar.
173-
bounds : `list` of two numbers
177+
bounds : array-like of two numbers
174178
Bounds on similarity, aside slack variables, s.t.
175179
``d(a, b) < bounds_[0]`` for all given pairs of similar points ``a``
176180
and ``b``, and ``d(c, d) > bounds_[1]`` for all given pairs of
@@ -191,7 +195,7 @@ def fit(self, pairs, y, bounds=None, calibration_params=None):
191195
calibration_params = (calibration_params if calibration_params is not
192196
None else dict())
193197
self._validate_calibration_params(**calibration_params)
194-
self._fit(pairs, y)
198+
self._fit(pairs, y, bounds=bounds)
195199
self.calibrate_threshold(pairs, y, **calibration_params)
196200
return self
197201

@@ -201,7 +205,7 @@ class ITML_Supervised(_BaseITML, TransformerMixin):
201205
202206
Attributes
203207
----------
204-
bounds_ : array-like, shape=(2,)
208+
bounds_ : `numpy.ndarray`, shape=(2,)
205209
Bounds on similarity, aside slack variables, s.t.
206210
``d(a, b) < bounds_[0]`` for all given pairs of similar points ``a``
207211
and ``b``, and ``d(c, d) > bounds_[1]`` for all given pairs of
@@ -274,7 +278,7 @@ def fit(self, X, y, random_state=np.random, bounds=None):
274278
random_state : numpy.random.RandomState, optional
275279
If provided, controls random number generation.
276280
277-
bounds : `list` of two numbers
281+
bounds : array-like of two numbers
278282
Bounds on similarity, aside slack variables, s.t.
279283
``d(a, b) < bounds_[0]`` for all given pairs of similar points ``a``
280284
and ``b``, and ``d(c, d) > bounds_[1]`` for all given pairs of

test/metric_learn_test.py

Lines changed: 38 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
HAS_SKGGM = True
1919
from metric_learn import (LMNN, NCA, LFDA, Covariance, MLKR, MMC,
2020
LSML_Supervised, ITML_Supervised, SDML_Supervised,
21-
RCA_Supervised, MMC_Supervised, SDML)
21+
RCA_Supervised, MMC_Supervised, SDML, ITML)
2222
# Import this specially for testing.
2323
from metric_learn.constraints import wrap_pairs
2424
from metric_learn.lmnn import python_LMNN, _sum_outer_products
@@ -109,6 +109,43 @@ def test_deprecation_bounds(self):
109109
assert_warns_message(DeprecationWarning, msg, itml_supervised.fit, X, y)
110110

111111

112+
@pytest.mark.parametrize('bounds', [None, (20., 100.), [20., 100.],
113+
np.array([20., 100.]),
114+
np.array([[20., 100.]]),
115+
np.array([[20], [100]])])
116+
def test_bounds_parameters_valid(bounds):
117+
"""Asserts that we can provide any array-like of two elements as bounds,
118+
and that the attribute bound_ is a numpy array"""
119+
120+
pairs = np.array([[[-10., 0.], [10., 0.]], [[0., 50.], [0., -60]]])
121+
y_pairs = [1, -1]
122+
itml = ITML()
123+
itml.fit(pairs, y_pairs, bounds=bounds)
124+
125+
X = np.array([[0, 0], [0, 1], [2, 0], [2, 1]])
126+
y = np.array([1, 0, 1, 0])
127+
itml_supervised = ITML_Supervised()
128+
itml_supervised.fit(X, y, bounds=bounds)
129+
130+
131+
@pytest.mark.parametrize('bounds', ['weird', ['weird1', 'weird2'],
132+
np.array([1, 2, 3])])
133+
def test_bounds_parameters_invalid(bounds):
134+
"""Assert that if a non array-like is put for bounds, or an array-like
135+
of length different than 2, an error is returned"""
136+
pairs = np.array([[[-10., 0.], [10., 0.]], [[0., 50.], [0., -60]]])
137+
y_pairs = [1, -1]
138+
itml = ITML()
139+
with pytest.raises(Exception):
140+
itml.fit(pairs, y_pairs, bounds=bounds)
141+
142+
X = np.array([[0, 0], [0, 1], [2, 0], [2, 1]])
143+
y = np.array([1, 0, 1, 0])
144+
itml_supervised = ITML_Supervised()
145+
with pytest.raises(Exception):
146+
itml_supervised.fit(X, y, bounds=bounds)
147+
148+
112149
class TestLMNN(MetricTestCase):
113150
def test_iris(self):
114151
# Test both impls, if available.

0 commit comments

Comments
 (0)