Skip to content

Commit 0370198

Browse files
RobinVogelbellet
authored andcommitted
Repairs chunk generation for unknown labels, solves issue #260 (#263)
* chunks return a map of index to chunk * maj * maj * remove storing of known labels * typo * no self.num_points * tests for unlabeled, repairs chunk generation * maj * testing diff features * corrected test * diff warning * maj * added parameter bound test
1 parent 7819e7c commit 0370198

File tree

4 files changed

+74
-29
lines changed

4 files changed

+74
-29
lines changed

metric_learn/constraints.py

Lines changed: 12 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@
55
import numpy as np
66
import warnings
77
from six.moves import xrange
8-
from scipy.sparse import coo_matrix
98
from sklearn.utils import check_random_state
109

1110
__all__ = ['Constraints']
@@ -20,21 +19,7 @@ class Constraints(object):
2019
def __init__(self, partial_labels):
2120
'''partial_labels : int arraylike, -1 indicating unknown label'''
2221
partial_labels = np.asanyarray(partial_labels, dtype=int)
23-
self.num_points, = partial_labels.shape
24-
self.known_label_idx, = np.where(partial_labels >= 0)
25-
self.known_labels = partial_labels[self.known_label_idx]
26-
27-
def adjacency_matrix(self, num_constraints, random_state=None):
28-
random_state = check_random_state(random_state)
29-
a, b, c, d = self.positive_negative_pairs(num_constraints,
30-
random_state=random_state)
31-
row = np.concatenate((a, c))
32-
col = np.concatenate((b, d))
33-
data = np.ones_like(row, dtype=int)
34-
data[len(a):] = -1
35-
adj = coo_matrix((data, (row, col)), shape=(self.num_points,) * 2)
36-
# symmetrize
37-
return adj + adj.T
22+
self.partial_labels = partial_labels
3823

3924
def positive_negative_pairs(self, num_constraints, same_length=False,
4025
random_state=None):
@@ -50,17 +35,19 @@ def positive_negative_pairs(self, num_constraints, same_length=False,
5035

5136
def _pairs(self, num_constraints, same_label=True, max_iter=10,
5237
random_state=np.random):
53-
num_labels = len(self.known_labels)
38+
known_label_idx, = np.where(self.partial_labels >= 0)
39+
known_labels = self.partial_labels[known_label_idx]
40+
num_labels = len(known_labels)
5441
ab = set()
5542
it = 0
5643
while it < max_iter and len(ab) < num_constraints:
5744
nc = num_constraints - len(ab)
5845
for aidx in random_state.randint(num_labels, size=nc):
5946
if same_label:
60-
mask = self.known_labels[aidx] == self.known_labels
47+
mask = known_labels[aidx] == known_labels
6148
mask[aidx] = False # avoid identity pairs
6249
else:
63-
mask = self.known_labels[aidx] != self.known_labels
50+
mask = known_labels[aidx] != known_labels
6451
b_choices, = np.where(mask)
6552
if len(b_choices) > 0:
6653
ab.add((aidx, random_state.choice(b_choices)))
@@ -69,16 +56,18 @@ def _pairs(self, num_constraints, same_label=True, max_iter=10,
6956
warnings.warn("Only generated %d %s constraints (requested %d)" % (
7057
len(ab), 'positive' if same_label else 'negative', num_constraints))
7158
ab = np.array(list(ab)[:num_constraints], dtype=int)
72-
return self.known_label_idx[ab.T]
59+
return known_label_idx[ab.T]
7360

7461
def chunks(self, num_chunks=100, chunk_size=2, random_state=None):
7562
"""
7663
the random state object to be passed must be a numpy random seed
7764
"""
7865
random_state = check_random_state(random_state)
79-
chunks = -np.ones_like(self.known_label_idx, dtype=int)
80-
uniq, lookup = np.unique(self.known_labels, return_inverse=True)
81-
all_inds = [set(np.where(lookup == c)[0]) for c in xrange(len(uniq))]
66+
chunks = -np.ones_like(self.partial_labels, dtype=int)
67+
uniq, lookup = np.unique(self.partial_labels, return_inverse=True)
68+
unknown_uniq = np.where(uniq < 0)[0]
69+
all_inds = [set(np.where(lookup == c)[0]) for c in xrange(len(uniq))
70+
if c not in unknown_uniq]
8271
max_chunks = int(np.sum([len(s) // chunk_size for s in all_inds]))
8372
if max_chunks < num_chunks:
8473
raise ValueError(('Not enough possible chunks of %d elements in each'

metric_learn/rca.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -93,10 +93,12 @@ def __init__(self, n_components=None, num_dims='deprecated',
9393

9494
def _check_dimension(self, rank, X):
9595
d = X.shape[1]
96+
9697
if rank < d:
9798
warnings.warn('The inner covariance matrix is not invertible, '
9899
'so the transformation matrix may contain Nan values. '
99-
'You should reduce the dimensionality of your input,'
100+
'You should remove any linearly dependent features and/or '
101+
'reduce the dimensionality of your input, '
100102
'for instance using `sklearn.decomposition.PCA` as a '
101103
'preprocessing step.')
102104

@@ -241,4 +243,13 @@ def fit(self, X, y, random_state='deprecated'):
241243
chunks = Constraints(y).chunks(num_chunks=self.num_chunks,
242244
chunk_size=self.chunk_size,
243245
random_state=self.random_state)
246+
247+
if self.num_chunks * (self.chunk_size - 1) < X.shape[1]:
248+
warnings.warn('Due to the parameters of RCA_Supervised, '
249+
'the inner covariance matrix is not invertible, '
250+
'so the transformation matrix will contain Nan values. '
251+
'Increase the number or size of the chunks to correct '
252+
'this problem.'
253+
)
254+
244255
return RCA.fit(self, X, chunks)

test/metric_learn_test.py

Lines changed: 38 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1100,9 +1100,11 @@ def test_rank_deficient_returns_warning(self):
11001100
rca = RCA()
11011101
msg = ('The inner covariance matrix is not invertible, '
11021102
'so the transformation matrix may contain Nan values. '
1103-
'You should reduce the dimensionality of your input,'
1103+
'You should remove any linearly dependent features and/or '
1104+
'reduce the dimensionality of your input, '
11041105
'for instance using `sklearn.decomposition.PCA` as a '
11051106
'preprocessing step.')
1107+
11061108
with pytest.warns(None) as raised_warnings:
11071109
rca.fit(X, y)
11081110
assert any(str(w.message) == msg for w in raised_warnings)
@@ -1136,6 +1138,41 @@ def test_changed_behaviour_warning_random_state(self):
11361138
rca_supervised.fit(X, y)
11371139
assert any(msg == str(wrn.message) for wrn in raised_warning)
11381140

1141+
def test_unknown_labels(self):
1142+
n = 200
1143+
num_chunks = 50
1144+
X, y = make_classification(random_state=42, n_samples=2 * n,
1145+
n_features=6, n_informative=6, n_redundant=0)
1146+
y2 = np.concatenate((y[:n], -np.ones(n)))
1147+
1148+
rca = RCA_Supervised(num_chunks=num_chunks, random_state=42)
1149+
rca.fit(X[:n], y[:n])
1150+
1151+
rca2 = RCA_Supervised(num_chunks=num_chunks, random_state=42)
1152+
rca2.fit(X, y2)
1153+
1154+
assert not np.any(np.isnan(rca.components_))
1155+
assert not np.any(np.isnan(rca2.components_))
1156+
1157+
np.testing.assert_array_equal(rca.components_, rca2.components_)
1158+
1159+
def test_bad_parameters(self):
1160+
n = 200
1161+
num_chunks = 3
1162+
X, y = make_classification(random_state=42, n_samples=n,
1163+
n_features=6, n_informative=6, n_redundant=0)
1164+
1165+
rca = RCA_Supervised(num_chunks=num_chunks, random_state=42)
1166+
msg = ('Due to the parameters of RCA_Supervised, '
1167+
'the inner covariance matrix is not invertible, '
1168+
'so the transformation matrix will contain Nan values. '
1169+
'Increase the number or size of the chunks to correct '
1170+
'this problem.'
1171+
)
1172+
with pytest.warns(None) as raised_warning:
1173+
rca.fit(X, y)
1174+
assert any(str(w.message) == msg for w in raised_warning)
1175+
11391176

11401177
@pytest.mark.parametrize('num_dims', [None, 2])
11411178
def test_deprecation_num_dims_rca(num_dims):

test/test_constraints.py

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
import unittest
21
import pytest
32
import numpy as np
43
from sklearn.utils import shuffle
@@ -34,7 +33,8 @@ def test_exact_num_points_for_chunks(num_chunks, chunk_size):
3433
chunks = constraints.chunks(num_chunks=num_chunks, chunk_size=chunk_size,
3534
random_state=SEED)
3635

37-
chunk_no, size_each_chunk = np.unique(chunks, return_counts=True)
36+
chunk_no, size_each_chunk = np.unique(chunks[chunks >= 0],
37+
return_counts=True)
3838

3939
np.testing.assert_array_equal(size_each_chunk, chunk_size)
4040
assert chunk_no.shape[0] == num_chunks
@@ -59,5 +59,13 @@ def test_chunk_case_one_miss_point(num_chunks, chunk_size):
5959
assert str(e.value) == expected_message
6060

6161

62-
if __name__ == '__main__':
63-
unittest.main()
62+
@pytest.mark.parametrize("num_chunks, chunk_size", [(5, 10), (10, 50)])
63+
def test_unknown_labels_not_in_chunks(num_chunks, chunk_size):
64+
"""Checks that unknown labels are not assigned to any chunk."""
65+
labels = gen_labels_for_chunks(num_chunks, chunk_size)
66+
67+
constraints = Constraints(labels)
68+
chunks = constraints.chunks(num_chunks=num_chunks, chunk_size=chunk_size,
69+
random_state=SEED)
70+
71+
assert np.all(chunks[labels < 0] < 0)

0 commit comments

Comments
 (0)