|
| 1 | +import unittest |
| 2 | +import pytest |
| 3 | +import numpy as np |
| 4 | +from sklearn.utils import shuffle |
| 5 | +from metric_learn.constraints import Constraints |
| 6 | + |
# Fixed seed passed as ``random_state`` to every shuffle and chunk draw in
# this module so the generated labels and chunks are reproducible.
SEED = 42
| 8 | + |
| 9 | + |
def gen_labels_for_chunks(num_chunks, chunk_size,
                          n_classes=10, n_unknown_labels=5):
    """Generate shuffled labels with exactly enough points for the chunks.

    ``num_chunks * chunk_size`` known labels are produced. After a seeded
    shuffle of the class ids, every class except the first one receives
    ``chunk_size * (num_chunks // n_classes)`` points and the first class
    receives all the remaining ones, so each class size is a multiple of
    ``chunk_size``. ``n_unknown_labels`` entries equal to ``-1`` are then
    appended and the whole array is shuffled with the same seed.

    Parameters
    ----------
    num_chunks : int
        Number of chunks the known labels must be able to fill (> 0).
    chunk_size : int
        Number of points per chunk (> 0).
    n_classes : int, optional
        Number of distinct known classes (> 0).
    n_unknown_labels : int, optional
        Number of extra ``-1`` labels mixed into the output.

    Returns
    -------
    labels : numpy.ndarray of int
        Shuffled array of length
        ``num_chunks * chunk_size + n_unknown_labels``.
    """
    assert min(num_chunks, chunk_size, n_classes) > 0
    classes = shuffle(np.arange(n_classes), random_state=SEED)
    n_per_class = chunk_size * (num_chunks // n_classes)
    n_maj_class = chunk_size * num_chunks - n_per_class * (n_classes - 1)

    # One repeat count per class: the first (majority) class absorbs the
    # remainder. np.repeat also copes with a single class, where
    # concatenating the (empty) list of minority classes would raise.
    counts = np.full(n_classes, n_per_class, dtype=int)
    counts[0] = n_maj_class
    known_labels = np.repeat(classes, counts)
    unknown_labels = np.full(n_unknown_labels, -1, dtype=int)

    labels = np.concatenate([known_labels, unknown_labels])
    return shuffle(labels, random_state=SEED)
| 26 | + |
| 27 | + |
@pytest.mark.parametrize("num_chunks, chunk_size", [(5, 10), (10, 50)])
def test_exact_num_points_for_chunks(num_chunks, chunk_size):
    """Check that chunk generation succeeds with just enough points.

    The labels hold exactly ``num_chunks * chunk_size`` known points, so the
    result must contain ``num_chunks`` distinct chunks of ``chunk_size``
    points each.
    """
    labels = gen_labels_for_chunks(num_chunks, chunk_size)

    constraints = Constraints(labels)
    chunks = np.asarray(
        constraints.chunks(num_chunks=num_chunks, chunk_size=chunk_size,
                           random_state=SEED))

    # NOTE(review): metric-learn documents -1 as the marker for points that
    # belong to no chunk -- confirm against Constraints.chunks. Such entries
    # (e.g. the unknown-label points) must not be counted as a chunk; when
    # no negative value is returned this filter is a no-op.
    assigned = chunks[chunks >= 0]
    chunk_no, size_each_chunk = np.unique(assigned, return_counts=True)

    np.testing.assert_array_equal(size_each_chunk, chunk_size)
    assert chunk_no.shape[0] == num_chunks
| 41 | + |
| 42 | + |
@pytest.mark.parametrize("num_chunks, chunk_size", [(5, 10), (10, 50)])
def test_chunk_case_one_miss_point(num_chunks, chunk_size):
    """Check that chunk generation breaks when one known point is missing.

    With one labeled point removed, one class can no longer fill its last
    chunk, so at most ``num_chunks - 1`` chunks can be formed and
    ``Constraints.chunks`` is expected to raise a ``ValueError`` saying so.
    """
    labels = gen_labels_for_chunks(num_chunks, chunk_size)

    # Remove a *known* label: the array is shuffled, so blindly dropping
    # labels[0] could discard a -1 (unknown) entry and leave every chunk
    # still feasible. When labels[0] is known this equals labels[1:].
    known_positions = np.flatnonzero(labels >= 0)
    assert known_positions.size >= 1
    labels_missing_one = np.delete(labels, known_positions[0])

    constraints = Constraints(labels_missing_one)
    with pytest.raises(ValueError) as e:
        constraints.chunks(num_chunks=num_chunks, chunk_size=chunk_size,
                           random_state=SEED)

    expected_message = (('Not enough possible chunks of %d elements in each'
                         ' class to form expected %d chunks - maximum number'
                         ' of chunks is %d'
                         ) % (chunk_size, num_chunks, num_chunks - 1))

    assert str(e.value) == expected_message
| 60 | + |
| 61 | + |
if __name__ == '__main__':
    # The tests in this module are plain (parametrized) pytest functions, not
    # unittest.TestCase classes, so unittest.main() would collect nothing.
    # Run this file through pytest and propagate its exit status.
    raise SystemExit(pytest.main([__file__]))
0 commit comments