@@ -69,9 +69,13 @@ def _fit(self, pairs, y, bounds=None):
69
69
X = np .vstack ({tuple (row ) for row in pairs .reshape (- 1 , pairs .shape [2 ])})
70
70
self .bounds_ = np .percentile (pairwise_distances (X ), (5 , 95 ))
71
71
else :
72
- assert len (bounds ) == 2
72
+ bounds = check_array (bounds , allow_nd = False , ensure_min_samples = 0 ,
73
+ ensure_2d = False )
74
+ bounds = bounds .ravel ()
75
+ if bounds .size != 2 :
76
+ raise ValueError ("`bounds` should be an array-like of two elements." )
73
77
self .bounds_ = bounds
74
- self .bounds_ [self .bounds_ == 0 ] = 1e-9
78
+ self .bounds_ [self .bounds_ == 0 ] = 1e-9
75
79
# init metric
76
80
if self .A0 is None :
77
81
A = np .identity (pairs .shape [2 ])
@@ -134,7 +138,7 @@ class ITML(_BaseITML, _PairsClassifierMixin):
134
138
135
139
Attributes
136
140
----------
137
- bounds_ : array-like , shape=(2,)
141
+ bounds_ : `numpy.ndarray` , shape=(2,)
138
142
Bounds on similarity, aside slack variables, s.t.
139
143
``d(a, b) < bounds_[0]`` for all given pairs of similar points ``a``
140
144
and ``b``, and ``d(c, d) > bounds_[1]`` for all given pairs of
@@ -171,7 +175,7 @@ def fit(self, pairs, y, bounds=None, calibration_params=None):
171
175
preprocessor.
172
176
y: array-like, of shape (n_constraints,)
173
177
Labels of constraints. Should be -1 for dissimilar pair, 1 for similar.
174
- bounds : `list` of two numbers
178
+ bounds : array-like of two numbers
175
179
Bounds on similarity, aside slack variables, s.t.
176
180
``d(a, b) < bounds_[0]`` for all given pairs of similar points ``a``
177
181
and ``b``, and ``d(c, d) > bounds_[1]`` for all given pairs of
@@ -192,7 +196,7 @@ def fit(self, pairs, y, bounds=None, calibration_params=None):
192
196
calibration_params = (calibration_params if calibration_params is not
193
197
None else dict ())
194
198
self ._validate_calibration_params (** calibration_params )
195
- self ._fit (pairs , y )
199
+ self ._fit (pairs , y , bounds = bounds )
196
200
self .calibrate_threshold (pairs , y , ** calibration_params )
197
201
return self
198
202
@@ -202,7 +206,7 @@ class ITML_Supervised(_BaseITML, TransformerMixin):
202
206
203
207
Attributes
204
208
----------
205
- bounds_ : array-like , shape=(2,)
209
+ bounds_ : `numpy.ndarray` , shape=(2,)
206
210
Bounds on similarity, aside slack variables, s.t.
207
211
``d(a, b) < bounds_[0]`` for all given pairs of similar points ``a``
208
212
and ``b``, and ``d(c, d) > bounds_[1]`` for all given pairs of
@@ -275,7 +279,7 @@ def fit(self, X, y, random_state=np.random, bounds=None):
275
279
random_state : numpy.random.RandomState, optional
276
280
If provided, controls random number generation.
277
281
278
- bounds : `list` of two numbers
282
+ bounds : array-like of two numbers
279
283
Bounds on similarity, aside slack variables, s.t.
280
284
``d(a, b) < bounds_[0]`` for all given pairs of similar points ``a``
281
285
and ``b``, and ``d(c, d) > bounds_[1]`` for all given pairs of
0 commit comments