@@ -68,9 +68,13 @@ def _fit(self, pairs, y, bounds=None):
68
68
X = np .vstack ({tuple (row ) for row in pairs .reshape (- 1 , pairs .shape [2 ])})
69
69
self .bounds_ = np .percentile (pairwise_distances (X ), (5 , 95 ))
70
70
else :
71
- assert len (bounds ) == 2
71
+ bounds = check_array (bounds , allow_nd = False , ensure_min_samples = 0 ,
72
+ ensure_2d = False )
73
+ bounds = bounds .ravel ()
74
+ if bounds .size != 2 :
75
+ raise ValueError ("`bounds` should be an array-like of two elements." )
72
76
self .bounds_ = bounds
73
- self .bounds_ [self .bounds_ == 0 ] = 1e-9
77
+ self .bounds_ [self .bounds_ == 0 ] = 1e-9
74
78
# init metric
75
79
if self .A0 is None :
76
80
A = np .identity (pairs .shape [2 ])
@@ -133,7 +137,7 @@ class ITML(_BaseITML, _PairsClassifierMixin):
133
137
134
138
Attributes
135
139
----------
136
- bounds_ : array-like , shape=(2,)
140
+ bounds_ : `numpy.ndarray` , shape=(2,)
137
141
Bounds on similarity, aside slack variables, s.t.
138
142
``d(a, b) < bounds_[0]`` for all given pairs of similar points ``a``
139
143
and ``b``, and ``d(c, d) > bounds_[1]`` for all given pairs of
@@ -170,7 +174,7 @@ def fit(self, pairs, y, bounds=None, calibration_params=None):
170
174
preprocessor.
171
175
y: array-like, of shape (n_constraints,)
172
176
Labels of constraints. Should be -1 for dissimilar pair, 1 for similar.
173
- bounds : `list` of two numbers
177
+ bounds : array-like of two numbers
174
178
Bounds on similarity, aside slack variables, s.t.
175
179
``d(a, b) < bounds_[0]`` for all given pairs of similar points ``a``
176
180
and ``b``, and ``d(c, d) > bounds_[1]`` for all given pairs of
@@ -191,7 +195,7 @@ def fit(self, pairs, y, bounds=None, calibration_params=None):
191
195
calibration_params = (calibration_params if calibration_params is not
192
196
None else dict ())
193
197
self ._validate_calibration_params (** calibration_params )
194
- self ._fit (pairs , y )
198
+ self ._fit (pairs , y , bounds = bounds )
195
199
self .calibrate_threshold (pairs , y , ** calibration_params )
196
200
return self
197
201
@@ -201,7 +205,7 @@ class ITML_Supervised(_BaseITML, TransformerMixin):
201
205
202
206
Attributes
203
207
----------
204
- bounds_ : array-like , shape=(2,)
208
+ bounds_ : `numpy.ndarray` , shape=(2,)
205
209
Bounds on similarity, aside slack variables, s.t.
206
210
``d(a, b) < bounds_[0]`` for all given pairs of similar points ``a``
207
211
and ``b``, and ``d(c, d) > bounds_[1]`` for all given pairs of
@@ -274,7 +278,7 @@ def fit(self, X, y, random_state=np.random, bounds=None):
274
278
random_state : numpy.random.RandomState, optional
275
279
If provided, controls random number generation.
276
280
277
- bounds : `list` of two numbers
281
+ bounds : array-like of two numbers
278
282
Bounds on similarity, aside slack variables, s.t.
279
283
``d(a, b) < bounds_[0]`` for all given pairs of similar points ``a``
280
284
and ``b``, and ``d(c, d) > bounds_[1]`` for all given pairs of
0 commit comments