2
2
import io
3
3
import os
4
4
import pkgutil
5
-
5
+ import collections
6
6
import numpy as np
7
7
import pymc3 as pm
8
8
import theano .tensor as tt
11
11
# Names exported by a star-import of this module.
__all__ = ['get_data', 'GeneratorAdapter', 'Minibatch', 'align_minibatches']
16
17
17
18
@@ -221,6 +222,9 @@ class Minibatch(tt.TensorVariable):
221
222
>>> x = Minibatch(moredata, [2, None, 20, Ellipsis, 10])
222
223
>>> assert x.eval().shape == (2, 20, 20, 40, 10)
223
224
"""
225
+
226
+ RNG = collections .defaultdict (list )
227
+
224
228
@theano .configparser .change_flags (compute_test_value = 'raise' )
225
229
def __init__ (self , data , batch_size = 128 , dtype = None , broadcastable = None , name = 'Minibatch' ,
226
230
random_seed = 42 , update_shared_f = None , in_memory_size = None ):
@@ -244,17 +248,21 @@ def __init__(self, data, batch_size=128, dtype=None, broadcastable=None, name='M
244
248
inputs = [self .minibatch ], outputs = [self ])
245
249
self .tag .test_value = copy (self .minibatch .tag .test_value )
246
250
247
- @staticmethod
248
- def rslice (total , size , seed ):
251
def rslice(self, total, size, seed):
    """Build the index expression for one axis of the minibatch.

    Parameters
    ----------
    total
        Upper bound handed to the uniform draw (after ``pm.floatX``
        conversion and subtracting ``1e-16``).
    size : int or None
        ``None`` selects the whole axis; an ``int`` requests that many
        randomly drawn indices.
    seed
        Forwarded unchanged to ``pm.tt_rng``.

    Returns
    -------
    ``slice(None)`` when ``size`` is ``None``; otherwise the result of a
    uniform draw of shape ``(size,)`` cast to ``'int64'``.

    Raises
    ------
    TypeError
        If ``size`` is neither ``None`` nor an ``int``.
    """
    if size is None:
        return slice(None)
    if not isinstance(size, int):
        raise TypeError('Unrecognized size type, %r' % size)
    generator = pm.tt_rng(seed)
    # Remember the generator under this instance's id so that
    # module-level helpers can reach it through Minibatch.RNG later.
    Minibatch.RNG[id(self)].append(generator)
    upper = pm.floatX(total) - 1e-16
    draws = generator.uniform(size=(size, ), low=0.0, high=upper)
    return draws.astype('int64')
257
262
263
def __del__(self):
    """Remove this instance's generators from ``Minibatch.RNG``.

    The registry is keyed by ``id(self)``, so the entry is dropped when
    the instance is finalized.
    """
    # An instance only gets a registry entry once ``rslice`` has been
    # called with an integer size. ``pop`` with a default tolerates the
    # missing-key case, where a plain ``del`` would raise KeyError
    # inside the finalizer.
    Minibatch.RNG.pop(id(self), None)
265
+
258
266
@staticmethod
259
267
def make_static_slices (user_size ):
260
268
if user_size is None :
@@ -278,12 +286,11 @@ def make_static_slices(user_size):
278
286
else :
279
287
raise TypeError ('Unrecognized size type, %r' % user_size )
280
288
281
- @classmethod
282
- def make_random_slices (cls , in_memory_shape , batch_size , default_random_seed ):
289
+ def make_random_slices (self , in_memory_shape , batch_size , default_random_seed ):
283
290
if batch_size is None :
284
291
return [Ellipsis ]
285
292
elif isinstance (batch_size , int ):
286
- slc = [cls .rslice (in_memory_shape [0 ], batch_size , default_random_seed )]
293
+ slc = [self .rslice (in_memory_shape [0 ], batch_size , default_random_seed )]
287
294
elif isinstance (batch_size , (list , tuple )):
288
295
def check (t ):
289
296
if t is Ellipsis or t is None :
@@ -334,10 +341,10 @@ def check(t):
334
341
else :
335
342
shp_end = np .asarray ([])
336
343
shp_begin = shape [:len (begin )]
337
- slc_begin = [cls .rslice (shp_begin [i ], t [0 ], t [1 ])
344
+ slc_begin = [self .rslice (shp_begin [i ], t [0 ], t [1 ])
338
345
if t is not None else tt .arange (shp_begin [i ])
339
346
for i , t in enumerate (begin )]
340
- slc_end = [cls .rslice (shp_end [i ], t [0 ], t [1 ])
347
+ slc_end = [self .rslice (shp_end [i ], t [0 ], t [1 ])
341
348
if t is not None else tt .arange (shp_end [i ])
342
349
for i , t in enumerate (end )]
343
350
slc = slc_begin + mid + slc_end
@@ -359,3 +366,16 @@ def clone(self):
359
366
ret .name = self .name
360
367
ret .tag = copy (self .tag )
361
368
return ret
369
+
370
+
371
def align_minibatches(batches=None):
    """Re-seed the random generators registered for minibatches.

    Parameters
    ----------
    batches : iterable of Minibatch, optional
        Minibatches whose generators should be re-seeded. When omitted
        (``None``), every generator recorded in ``Minibatch.RNG`` is
        re-seeded.

    Raises
    ------
    TypeError
        If an element of ``batches`` is not a ``Minibatch``. Elements
        that precede the offending one have already been re-seeded.

    Notes
    -----
    ``rng.seed()`` is called without arguments; presumably this resets
    each generator to its construction-time seed -- TODO confirm against
    the random-stream implementation behind ``pm.tt_rng``.
    """
    if batches is None:
        for rngs in Minibatch.RNG.values():
            for rng in rngs:
                rng.seed()
    else:
        for b in batches:
            if not isinstance(b, Minibatch):
                # Format the template explicitly: the bare literal would
                # report the text "{b}" instead of the offending object.
                raise TypeError('{b} is not a Minibatch'.format(b=b))
            for rng in Minibatch.RNG[id(b)]:
                rng.seed()
0 commit comments