Skip to content

Commit 7700529

Browse files
RobinVogelbellet
authored andcommitted
Break chunks generation in RCA when not enough possible chunks, fixes issue #200 (#254)
* fixes issue 200 * maj * add max_chunks in error message * tests the building of chunks in constraints.py * corrected faulty generation * still small mistake at generation * encapsulate tests, modified message * Testing chunk generation in constraints
1 parent b871028 commit 7700529

File tree

2 files changed

+69
-3
lines changed

2 files changed

+69
-3
lines changed

metric_learn/constraints.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,12 @@ def chunks(self, num_chunks=100, chunk_size=2, random_state=None):
7979
chunks = -np.ones_like(self.known_label_idx, dtype=int)
8080
uniq, lookup = np.unique(self.known_labels, return_inverse=True)
8181
all_inds = [set(np.where(lookup == c)[0]) for c in xrange(len(uniq))]
82+
max_chunks = int(np.sum([len(s) // chunk_size for s in all_inds]))
83+
if max_chunks < num_chunks:
84+
raise ValueError(('Not enough possible chunks of %d elements in each'
85+
' class to form expected %d chunks - maximum number'
86+
' of chunks is %d'
87+
) % (chunk_size, num_chunks, max_chunks))
8288
idx = 0
8389
while idx < num_chunks and all_inds:
8490
if len(all_inds) == 1:
@@ -93,9 +99,6 @@ def chunks(self, num_chunks=100, chunk_size=2, random_state=None):
9399
inds.difference_update(ii)
94100
chunks[ii] = idx
95101
idx += 1
96-
if idx < num_chunks:
97-
raise ValueError('Unable to make %d chunks of %d examples each' %
98-
(num_chunks, chunk_size))
99102
return chunks
100103

101104

test/test_constraints.py

Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,63 @@
1+
import unittest
2+
import pytest
3+
import numpy as np
4+
from sklearn.utils import shuffle
5+
from metric_learn.constraints import Constraints
6+
7+
SEED = 42
8+
9+
10+
def gen_labels_for_chunks(num_chunks, chunk_size,
11+
n_classes=10, n_unknown_labels=5):
12+
"""Generates num_chunks*chunk_size labels that split in num_chunks chunks,
13+
that are homogeneous in the label."""
14+
assert min(num_chunks, chunk_size) > 0
15+
classes = shuffle(np.arange(n_classes), random_state=SEED)
16+
n_per_class = chunk_size * (num_chunks // n_classes)
17+
n_maj_class = chunk_size * num_chunks - n_per_class * (n_classes - 1)
18+
19+
first_labels = classes[0] * np.ones(n_maj_class, dtype=int)
20+
remaining_labels = np.concatenate([k * np.ones(n_per_class, dtype=int)
21+
for k in classes[1:]])
22+
unknown_labels = -1 * np.ones(n_unknown_labels, dtype=int)
23+
24+
labels = np.concatenate([first_labels, remaining_labels, unknown_labels])
25+
return shuffle(labels, random_state=SEED)
26+
27+
28+
@pytest.mark.parametrize("num_chunks, chunk_size", [(5, 10), (10, 50)])
29+
def test_exact_num_points_for_chunks(num_chunks, chunk_size):
30+
"""Checks that the chunk generation works well with just enough points."""
31+
labels = gen_labels_for_chunks(num_chunks, chunk_size)
32+
33+
constraints = Constraints(labels)
34+
chunks = constraints.chunks(num_chunks=num_chunks, chunk_size=chunk_size,
35+
random_state=SEED)
36+
37+
chunk_no, size_each_chunk = np.unique(chunks, return_counts=True)
38+
39+
np.testing.assert_array_equal(size_each_chunk, chunk_size)
40+
assert chunk_no.shape[0] == num_chunks
41+
42+
43+
@pytest.mark.parametrize("num_chunks, chunk_size", [(5, 10), (10, 50)])
44+
def test_chunk_case_one_miss_point(num_chunks, chunk_size):
45+
"""Checks that the chunk generation breaks when one point is missing."""
46+
labels = gen_labels_for_chunks(num_chunks, chunk_size)
47+
48+
assert len(labels) >= 1
49+
constraints = Constraints(labels[1:])
50+
with pytest.raises(ValueError) as e:
51+
constraints.chunks(num_chunks=num_chunks, chunk_size=chunk_size,
52+
random_state=SEED)
53+
54+
expected_message = (('Not enough possible chunks of %d elements in each'
55+
' class to form expected %d chunks - maximum number'
56+
' of chunks is %d'
57+
) % (chunk_size, num_chunks, num_chunks - 1))
58+
59+
assert str(e.value) == expected_message
60+
61+
62+
if __name__ == '__main__':
63+
unittest.main()

0 commit comments

Comments
 (0)