@@ -19,8 +19,8 @@ def __init__(self, partial_labels):
19
19
self .known_label_idx , = np .where (partial_labels >= 0 )
20
20
self .known_labels = partial_labels [self .known_label_idx ]
21
21
22
- def adjacency_matrix (self , num_constraints ):
23
- a , b , c , d = self .positive_negative_pairs (num_constraints )
22
+ def adjacency_matrix (self , num_constraints , random_state = np . random ):
23
+ a , b , c , d = self .positive_negative_pairs (num_constraints , random_state = random_state )
24
24
row = np .concatenate ((a , c ))
25
25
col = np .concatenate ((b , d ))
26
26
data = np .ones_like (row , dtype = int )
@@ -29,48 +29,51 @@ def adjacency_matrix(self, num_constraints):
29
29
# symmetrize
30
30
return adj + adj .T
31
31
32
def positive_negative_pairs(self, num_constraints, same_length=False,
                            random_state=np.random):
    """Sample same-label (positive) and different-label (negative) pairs.

    Parameters
    ----------
    num_constraints : int
        Number of pairs requested for each of the two kinds.
    same_length : bool, optional
        When True and the two kinds produced a different number of pairs,
        truncate both to the shorter length.
    random_state : object, optional
        Forwarded unchanged to ``self._pairs`` for both draws. Defaults to
        the ``numpy.random`` module.

    Returns
    -------
    a, b, c, d
        ``(a, b)`` are the two halves of the positive pairs and ``(c, d)``
        the two halves of the negative pairs, as returned by
        ``self._pairs``.
    """
    a, b = self._pairs(num_constraints, same_label=True,
                       random_state=random_state)
    c, d = self._pairs(num_constraints, same_label=False,
                       random_state=random_state)
    if same_length and len(a) != len(c):
        # _pairs may return fewer pairs than requested, so the two kinds
        # can differ in length; keep only the common prefix.
        n = min(len(a), len(c))
        return a[:n], b[:n], c[:n], d[:n]
    return a, b, c, d
39
39
40
def _pairs(self, num_constraints, same_label=True, max_iter=10,
           random_state=np.random):
    """Draw up to ``num_constraints`` distinct index pairs.

    An anchor is drawn among the known labels, then a partner is drawn
    among the entries whose label equals (``same_label=True``) or differs
    from (``same_label=False``) the anchor's label. Pairs are collected in
    a set, so duplicates are discarded and further rounds are attempted
    until enough pairs exist or ``max_iter`` rounds have run.

    Parameters
    ----------
    num_constraints : int
        Number of distinct pairs requested.
    same_label : bool, optional
        True for same-label pairs, False for different-label pairs.
    max_iter : int, optional
        Maximum number of sampling rounds.
    random_state : object, optional
        Object providing ``randint`` and ``choice`` with the
        ``numpy.random`` signatures. Defaults to the ``numpy.random``
        module.

    Returns
    -------
    numpy.ndarray
        ``self.known_label_idx`` indexed by a ``(2, k)`` integer array,
        with ``k <= num_constraints``; it unpacks into the two halves of
        the pairs. ``k`` is 0 when no pair could be generated.

    Warns
    -----
    Emits a warning when fewer than ``num_constraints`` pairs were found.
    """
    num_labels = len(self.known_labels)
    ab = set()
    it = 0
    # With no known labels randint(0, ...) would raise, so skip sampling.
    while num_labels > 0 and it < max_iter and len(ab) < num_constraints:
        nc = num_constraints - len(ab)
        for aidx in random_state.randint(num_labels, size=nc):
            if same_label:
                mask = self.known_labels[aidx] == self.known_labels
                mask[aidx] = False  # avoid identity pairs
            else:
                mask = self.known_labels[aidx] != self.known_labels
            b_choices, = np.where(mask)
            if len(b_choices) > 0:
                ab.add((aidx, random_state.choice(b_choices)))
        it += 1
    if len(ab) < num_constraints:
        warnings.warn("Only generated %d %s constraints (requested %d)" % (
            len(ab), 'positive' if same_label else 'negative',
            num_constraints))
    # reshape keeps a two-column layout even when no pair was found, so the
    # transposed result always unpacks into two arrays.
    ab = np.array(list(ab)[:num_constraints], dtype=int).reshape(-1, 2)
    return self.known_label_idx[ab.T]
61
61
62
- def chunks (self , num_chunks = 100 , chunk_size = 2 ):
62
+ def chunks (self , num_chunks = 100 , chunk_size = 2 , random_state = np .random ):
63
+ """
64
+ the random state object to be passed must be a numpy random seed
65
+ """
63
66
chunks = - np .ones_like (self .known_label_idx , dtype = int )
64
67
uniq , lookup = np .unique (self .known_labels , return_inverse = True )
65
68
all_inds = [set (np .where (lookup == c )[0 ]) for c in xrange (len (uniq ))]
66
69
idx = 0
67
70
while idx < num_chunks and all_inds :
68
- c = random .randint (0 , len (all_inds )- 1 )
71
+ c = random_state .randint (0 , high = len (all_inds )- 1 )
69
72
inds = all_inds [c ]
70
73
if len (inds ) < chunk_size :
71
74
del all_inds [c ]
72
75
continue
73
- ii = random . sample ( inds , chunk_size )
76
+ ii = random_state . choice ( list ( inds ) , chunk_size , replace = False )
74
77
inds .difference_update (ii )
75
78
chunks [ii ] = idx
76
79
idx += 1
@@ -80,10 +83,13 @@ def chunks(self, num_chunks=100, chunk_size=2):
80
83
return chunks
81
84
82
85
@staticmethod
def random_subset(all_labels, num_preserved=np.inf, random_state=np.random):
    """Build a Constraints object that keeps a random subset of the labels.

    Parameters
    ----------
    all_labels : array-like
        The full label set; it is copied, not modified.
    num_preserved : int or numpy.inf, optional
        Number of labels to keep. The default keeps every label.
    random_state : object, optional
        Object providing ``choice`` with the ``numpy.random`` signature.
        Defaults to the ``numpy.random`` module.

    Returns
    -------
    Constraints
        Built from a copy of ``all_labels`` in which the hidden entries
        are set to -1.
    """
    n = len(all_labels)
    # Clamp to [0, n] so an out-of-range num_preserved cannot request more
    # distinct positions than exist.
    num_ignored = int(min(n, max(0, n - num_preserved)))
    # Sample without replacement: drawing with replacement can repeat an
    # index and would then hide fewer than num_ignored labels.
    idx = random_state.choice(n, size=num_ignored, replace=False)
    partial_labels = np.array(all_labels, copy=True)
    partial_labels[idx] = -1
    return Constraints(partial_labels)
0 commit comments