Skip to content

Commit 890b6a8

Browse files
bhargavvader and perimosocordiae
authored and committed
[MRG] Added random_states (#35)
1 parent c74a058 commit 890b6a8

File tree

5 files changed

+31
-21
lines changed

5 files changed

+31
-21
lines changed

metric_learn/constraints.py

Lines changed: 19 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -19,8 +19,8 @@ def __init__(self, partial_labels):
1919
self.known_label_idx, = np.where(partial_labels >= 0)
2020
self.known_labels = partial_labels[self.known_label_idx]
2121

22-
def adjacency_matrix(self, num_constraints):
23-
a, b, c, d = self.positive_negative_pairs(num_constraints)
22+
def adjacency_matrix(self, num_constraints, random_state=np.random):
23+
a, b, c, d = self.positive_negative_pairs(num_constraints, random_state=random_state)
2424
row = np.concatenate((a, c))
2525
col = np.concatenate((b, d))
2626
data = np.ones_like(row, dtype=int)
@@ -29,48 +29,51 @@ def adjacency_matrix(self, num_constraints):
2929
# symmetrize
3030
return adj + adj.T
3131

32-
def positive_negative_pairs(self, num_constraints, same_length=False):
33-
a, b = self._pairs(num_constraints, same_label=True)
34-
c, d = self._pairs(num_constraints, same_label=False)
32+
def positive_negative_pairs(self, num_constraints, same_length=False, random_state=np.random):
33+
a, b = self._pairs(num_constraints, same_label=True, random_state=random_state)
34+
c, d = self._pairs(num_constraints, same_label=False, random_state=random_state)
3535
if same_length and len(a) != len(c):
3636
n = min(len(a), len(c))
3737
return a[:n], b[:n], c[:n], d[:n]
3838
return a, b, c, d
3939

40-
def _pairs(self, num_constraints, same_label=True, max_iter=10):
40+
def _pairs(self, num_constraints, same_label=True, max_iter=10, random_state=np.random):
4141
num_labels = len(self.known_labels)
4242
ab = set()
4343
it = 0
4444
while it < max_iter and len(ab) < num_constraints:
4545
nc = num_constraints - len(ab)
46-
for aidx in np.random.randint(num_labels, size=nc):
46+
for aidx in random_state.randint(num_labels, size=nc):
4747
if same_label:
4848
mask = self.known_labels[aidx] == self.known_labels
4949
mask[aidx] = False # avoid identity pairs
5050
else:
5151
mask = self.known_labels[aidx] != self.known_labels
5252
b_choices, = np.where(mask)
5353
if len(b_choices) > 0:
54-
ab.add((aidx, np.random.choice(b_choices)))
54+
ab.add((aidx, random_state.choice(b_choices)))
5555
it += 1
5656
if len(ab) < num_constraints:
5757
warnings.warn("Only generated %d %s constraints (requested %d)" % (
5858
len(ab), 'positive' if same_label else 'negative', num_constraints))
5959
ab = np.array(list(ab)[:num_constraints], dtype=int)
6060
return self.known_label_idx[ab.T]
6161

62-
def chunks(self, num_chunks=100, chunk_size=2):
62+
def chunks(self, num_chunks=100, chunk_size=2, random_state=np.random):
63+
"""
64+
random_state must be a numpy RandomState-like object (for example np.random.RandomState(seed), or the np.random module itself), not a bare integer seed
65+
"""
6366
chunks = -np.ones_like(self.known_label_idx, dtype=int)
6467
uniq, lookup = np.unique(self.known_labels, return_inverse=True)
6568
all_inds = [set(np.where(lookup==c)[0]) for c in xrange(len(uniq))]
6669
idx = 0
6770
while idx < num_chunks and all_inds:
68-
c = random.randint(0, len(all_inds)-1)
71+
c = random_state.randint(0, high=len(all_inds)-1)
6972
inds = all_inds[c]
7073
if len(inds) < chunk_size:
7174
del all_inds[c]
7275
continue
73-
ii = random.sample(inds, chunk_size)
76+
ii = random_state.choice(list(inds), chunk_size, replace=False)
7477
inds.difference_update(ii)
7578
chunks[ii] = idx
7679
idx += 1
@@ -80,10 +83,13 @@ def chunks(self, num_chunks=100, chunk_size=2):
8083
return chunks
8184

8285
@staticmethod
83-
def random_subset(all_labels, num_preserved=np.inf):
86+
def random_subset(all_labels, num_preserved=np.inf, random_state=np.random):
87+
"""
88+
random_state must be a numpy RandomState-like object (for example np.random.RandomState(seed), or the np.random module itself), not a bare integer seed
89+
"""
8490
n = len(all_labels)
8591
num_ignored = max(0, n - num_preserved)
86-
idx = np.random.randint(n, size=num_ignored)
92+
idx = random_state.randint(n, size=num_ignored)
8793
partial_labels = np.array(all_labels, copy=True)
8894
partial_labels[idx] = -1
8995
return Constraints(partial_labels)

metric_learn/itml.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -166,7 +166,7 @@ def __init__(self, gamma=1., max_iters=1000, convergence_threshold=1e-3,
166166
self.params.update(num_labeled=num_labeled, num_constraints=num_constraints,
167167
bounds=bounds, A0=A0)
168168

169-
def fit(self, X, labels):
169+
def fit(self, X, labels, random_state=np.random):
170170
"""Create constraints from labels and learn the ITML model.
171171
Needs num_constraints specified in constructor.
172172
@@ -175,12 +175,13 @@ def fit(self, X, labels):
175175
X : (n x d) data matrix
176176
each row corresponds to a single instance
177177
labels : (n) data labels
178+
random_state : numpy RandomState-like object (default: the np.random module), used to draw the random labeled subset; pass np.random.RandomState(seed) to fix that draw.
178179
"""
179180
num_constraints = self.params['num_constraints']
180181
if num_constraints is None:
181182
num_classes = np.unique(labels)
182183
num_constraints = 20*(len(num_classes))**2
183184

184-
c = Constraints.random_subset(labels, self.params['num_labeled'])
185+
c = Constraints.random_subset(labels, self.params['num_labeled'], random_state=random_state)
185186
return ITML.fit(self, X, c.positive_negative_pairs(num_constraints),
186187
bounds=self.params['bounds'], A0=self.params['A0'])

metric_learn/lsml.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -155,7 +155,7 @@ def __init__(self, tol=1e-3, max_iter=1000, prior=None, num_labeled=np.inf,
155155
self.params.update(prior=prior, num_labeled=num_labeled,
156156
num_constraints=num_constraints, weights=weights)
157157

158-
def fit(self, X, labels):
158+
def fit(self, X, labels, random_state=np.random):
159159
"""Create constraints from labels and learn the LSML model.
160160
Needs num_constraints specified in constructor.
161161
@@ -164,13 +164,14 @@ def fit(self, X, labels):
164164
X : (n x d) data matrix
165165
each row corresponds to a single instance
166166
labels : (n) data labels
167+
random_state : numpy RandomState-like object (default: the np.random module), used to draw the random labeled subset; pass np.random.RandomState(seed) to fix that draw.
167168
"""
168169
num_constraints = self.params['num_constraints']
169170
if num_constraints is None:
170171
num_classes = np.unique(labels)
171172
num_constraints = 20*(len(num_classes))**2
172173

173-
c = Constraints.random_subset(labels, self.params['num_labeled'])
174+
c = Constraints.random_subset(labels, self.params['num_labeled'], random_state=random_state)
174175
pairs = c.positive_negative_pairs(num_constraints, same_length=True)
175176
return LSML.fit(self, X, pairs, weights=self.params['weights'],
176177
prior=self.params['prior'])

metric_learn/rca.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -112,7 +112,7 @@ def __init__(self, dim=None, num_chunks=100, chunk_size=2):
112112
RCA.__init__(self, dim=dim)
113113
self.params.update(num_chunks=num_chunks, chunk_size=chunk_size)
114114

115-
def fit(self, X, labels):
115+
def fit(self, X, labels, random_state=np.random):
116116
"""Create constraints from labels and learn the RCA model.
117117
Uses the num_chunks and chunk_size specified in the constructor.
118118
@@ -121,7 +121,8 @@ def fit(self, X, labels):
121121
X : (n x d) data matrix
122122
each row corresponds to a single instance
123123
labels : (n) data labels
124+
random_state : numpy RandomState-like object (default: the np.random module), used to sample the chunks; pass np.random.RandomState(seed) to fix the sampling.
124125
"""
125126
chunks = Constraints(labels).chunks(num_chunks=self.params['num_chunks'],
126-
chunk_size=self.params['chunk_size'])
127+
chunk_size=self.params['chunk_size'], random_state=random_state)
127128
return RCA.fit(self, X, chunks)

metric_learn/sdml.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -90,19 +90,20 @@ def __init__(self, balance_param=0.5, sparsity_param=0.01, use_cov=True,
9090
'''
9191
self.params.update(num_labeled=num_labeled, num_constraints=num_constraints)
9292

93-
def fit(self, X, labels):
93+
def fit(self, X, labels, random_state=np.random):
9494
"""Create constraints from labels and learn the SDML model.
9595
9696
Parameters
9797
----------
9898
X: data matrix, (n x d)
9999
each row corresponds to a single instance
100100
labels: data labels, (n,) array-like
101+
random_state : numpy RandomState-like object (default: the np.random module), used to draw the random labeled subset; pass np.random.RandomState(seed) to fix that draw.
101102
"""
102103
num_constraints = self.params['num_constraints']
103104
if num_constraints is None:
104105
num_classes = np.unique(labels)
105106
num_constraints = 20*(len(num_classes))**2
106107

107-
c = Constraints.random_subset(labels, self.params['num_labeled'])
108+
c = Constraints.random_subset(labels, self.params['num_labeled'], random_state=random_state)
108109
return SDML.fit(self, X, c.adjacency_matrix(num_constraints))

0 commit comments

Comments
 (0)