
Commit 9fd186c

int input checks and tests
1 parent c02e6e5 commit 9fd186c


2 files changed: +87 additions, -18 deletions


metric_learn/scml.py

Lines changed: 38 additions & 16 deletions
@@ -43,6 +43,20 @@ def _fit(self, triplets, basis=None, n_basis=None):
     dual averaging method.
     """

+    if not isinstance(self.max_iter, int):
+      raise ValueError("max_iter should be an integer, instead it is of type"
+                       " %s" % type(self.max_iter))
+    if not isinstance(self.output_iter, int):
+      raise ValueError("output_iter should be an integer, instead it is of "
+                       "type %s" % type(self.output_iter))
+    if not isinstance(self.batch_size, int):
+      raise ValueError("batch_size should be an integer, instead it is of type"
+                       " %s" % type(self.batch_size))
+
+    if self.output_iter > self.max_iter:
+      raise ValueError("The value of output_iter must be equal or smaller than"
+                       " max_iter.")
+
     # Currently prepare_inputs makes triplets contain points and not indices
     triplets = self._prepare_inputs(triplets, type_of_inputs='tuples')
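For illustration, a minimal sketch of how these new checks surface to a caller, assuming a metric-learn build that includes this change (the toy triplets below are made up and mirror the ones used in the tests further down):

import numpy as np
from metric_learn import SCML

triplets = np.array([[[0, 1], [2, 1], [0, 0]],
                     [[2, 1], [0, 1], [2, 0]],
                     [[0, 0], [2, 0], [0, 1]],
                     [[2, 0], [0, 0], [2, 1]]])

# max_iter=1.0 is a float, so _fit raises before any optimization starts
try:
    SCML(max_iter=1.0).fit(triplets)
except ValueError as e:
    print(e)  # max_iter should be an integer, instead it is of type <class 'float'>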

@@ -76,6 +90,23 @@ def _fit(self, triplets, basis=None, n_basis=None):
     rand_int = rng.randint(low=0, high=n_triplets,
                            size=(self.max_iter, self.batch_size))
     for iter in range(self.max_iter):
+
+      idx = rand_int[iter]
+
+      slack_val = 1 + np.matmul(dist_diff[idx, :], w.T)
+      slack_mask = np.squeeze(slack_val > 0, axis=1)
+
+      grad_w = np.sum(dist_diff[idx[slack_mask], :],
+                      axis=0, keepdims=True)/self.batch_size
+      avg_grad_w = (iter * avg_grad_w + grad_w) / (iter+1)
+
+      ada_grad_w = np.sqrt(np.square(ada_grad_w) + np.square(grad_w))
+
+      scale_f = -(iter+1) / self.gamma / (delta + ada_grad_w)
+
+      # proximal operator with negative trimming equivalent
+      w = scale_f * np.minimum(avg_grad_w + self.beta, 0)
+
       if (iter + 1) % self.output_iter == 0:
         # regularization part of obj function
         obj1 = np.sum(w)*self.beta
@@ -100,22 +131,6 @@ def _fit(self, triplets, basis=None, n_basis=None):
           best_obj = obj
           best_w = w

-      idx = rand_int[iter]
-
-      slack_val = 1 + np.matmul(dist_diff[idx, :], w.T)
-      slack_mask = np.squeeze(slack_val > 0, axis=1)
-
-      grad_w = np.sum(dist_diff[idx[slack_mask], :],
-                      axis=0, keepdims=True)/self.batch_size
-      avg_grad_w = (iter * avg_grad_w + grad_w) / (iter+1)
-
-      ada_grad_w = np.sqrt(np.square(ada_grad_w) + np.square(grad_w))
-
-      scale_f = -(iter+1) / self.gamma / (delta + ada_grad_w)
-
-      # proximal operator with negative trimming equivalent
-      w = scale_f * np.minimum(avg_grad_w + self.beta, 0)
-
     if self.verbose:
       print("max iteration reached.")
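For context, the block moved above the logging branch performs one step of the AdaGrad-weighted dual averaging update mentioned in the docstring. A self-contained sketch of a single step, with hypothetical toy shapes and hyperparameter values (variable names mirror the diff; the data is made up):

import numpy as np

rng = np.random.RandomState(42)
n_triplets, n_basis = 20, 5
dist_diff = rng.randn(n_triplets, n_basis)   # per-triplet distance differences on each basis
w = np.zeros((1, n_basis))                   # current sparse weight vector
avg_grad_w = np.zeros((1, n_basis))          # running average of gradients
ada_grad_w = np.zeros((1, n_basis))          # running root of summed squared gradients
beta, gamma, delta = 1e-5, 5e-3, 0.1         # hypothetical hyperparameters
batch_size = 10
it = 0                                       # first iteration

idx = rng.randint(0, n_triplets, size=batch_size)

# hinge slack: only triplets with positive slack contribute to the gradient
slack_val = 1 + np.matmul(dist_diff[idx, :], w.T)
slack_mask = np.squeeze(slack_val > 0, axis=1)

grad_w = np.sum(dist_diff[idx[slack_mask], :], axis=0, keepdims=True) / batch_size
avg_grad_w = (it * avg_grad_w + grad_w) / (it + 1)
ada_grad_w = np.sqrt(np.square(ada_grad_w) + np.square(grad_w))
scale_f = -(it + 1) / gamma / (delta + ada_grad_w)

# proximal step; negative scale times a non-positive term keeps w non-negative
w = scale_f * np.minimum(avg_grad_w + beta, 0)
print(w.shape)  # (1, n_basis)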

@@ -506,6 +521,13 @@ def fit(self, X, y):

     basis, n_basis = self._initialize_basis_supervised(X, y)

+    if not isinstance(self.k_genuine, int):
+      raise ValueError("k_genuine should be an integer, instead it is of type"
+                       " %s" % type(self.k_genuine))
+    if not isinstance(self.k_impostor, int):
+      raise ValueError("k_impostor should be an integer, instead it is of "
+                       "type %s" % type(self.k_impostor))
+
     constraints = Constraints(y)
     triplets = constraints.generate_knntriplets(X, self.k_genuine,
                                                 self.k_impostor)
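The same pattern applies to the supervised estimator, where the checks fire before the k-NN triplet generation. A minimal sketch under the same assumption, reusing the toy data from the tests below:

import numpy as np
from metric_learn import SCML_Supervised

X = np.array([[0, 0], [1, 1], [3, 3], [4, 4]])
y = np.array([1, 1, 0, 0])

# k_genuine=1.0 is a float, so fit raises before generate_knntriplets is called
try:
    SCML_Supervised(k_genuine=1.0).fit(X, y)
except ValueError as e:
    print(e)  # k_genuine should be an integer, instead it is of type <class 'float'>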

test/metric_learn_test.py

Lines changed: 49 additions & 2 deletions
@@ -203,10 +203,11 @@ def test_array_basis(self, estimator, data):
   def test_verbose(self, estimator, data, capsys):
     # assert there is proper output when verbose = True
     model = estimator(preprocessor=np.array([[0, 0], [1, 1], [2, 2], [3, 3]]),
-                      max_iter=1, output_iter=1, batch_size=1, verbose=True)
+                      max_iter=1, output_iter=1, batch_size=1, basis='triplet_diffs',
+                      random_state=42, verbose=True)
     model.fit(*data)
     out, _ = capsys.readouterr()
-    expected_out = ('[%s] iter 1\t obj 1.000000\t num_imp 8\n'
+    expected_out = ('[%s] iter 1\t obj 0.569946\t num_imp 2\n'
                     'max iteration reached.\n' % estimator.__name__)
     assert out == expected_out

@@ -276,6 +277,52 @@ def test_lda(self, n_samples, n_features, n_classes):
     assert n_basis == expected_n_basis
     assert basis.shape == expected_shape

+  @pytest.mark.parametrize('name', ['max_iter', 'output_iter', 'batch_size',
+                                    'n_basis'])
+  def test_int_inputs(self, name):
+    value = 1.0
+    d = {name: value}
+    scml = SCML(**d)
+    triplets = np.array([[[0, 1], [2, 1], [0, 0]],
+                         [[2, 1], [0, 1], [2, 0]],
+                         [[0, 0], [2, 0], [0, 1]],
+                         [[2, 0], [0, 0], [2, 1]]])
+
+    msg = name
+    msg += (" should be an integer, instead it is of type"
+            " %s" % type(value))
+    with pytest.raises(ValueError) as raised_error:
+      scml.fit(triplets)
+    assert msg == raised_error.value.args[0]
+
+  @pytest.mark.parametrize('name', ['max_iter', 'output_iter', 'batch_size',
+                                    'k_genuine', 'k_impostor', 'n_basis'])
+  def test_int_inputs_supervised(self, name):
+    value = 1.0
+    d = {name: value}
+    scml = SCML_Supervised(**d)
+    X = np.array([[0, 0], [1, 1], [3, 3], [4, 4]])
+    y = np.array([1, 1, 0, 0])
+    msg = name
+    msg += (" should be an integer, instead it is of type"
+            " %s" % type(value))
+    with pytest.raises(ValueError) as raised_error:
+      scml.fit(X, y)
+    assert msg == raised_error.value.args[0]
+
+  def test_large_output_iter(self):
+    scml = SCML(max_iter=1, output_iter=2)
+    triplets = np.array([[[0, 1], [2, 1], [0, 0]],
+                         [[2, 1], [0, 1], [2, 0]],
+                         [[0, 0], [2, 0], [0, 1]],
+                         [[2, 0], [0, 0], [2, 1]]])
+    msg = ("The value of output_iter must be equal or smaller than"
+           " max_iter.")
+
+    with pytest.raises(ValueError) as raised_error:
+      scml.fit(triplets)
+    assert msg == raised_error.value.args[0]
+

 class TestLSML(MetricTestCase):
   def test_iris(self):
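One way to run just the newly added tests from a repository checkout, assuming pytest is installed (the -k expression is only one possible selector):

import pytest

# select the new tests by name and show verbose per-test output
pytest.main(["test/metric_learn_test.py",
             "-k", "test_int_inputs or test_large_output_iter",
             "-v"])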
