Skip to content

Commit de1b8c8

Browse files
authored
align minibatches (#2760)
* align minibatches * add simple test * align specific minibatches
1 parent 877bf5e commit de1b8c8

File tree

2 files changed

+48
-10
lines changed

2 files changed

+48
-10
lines changed

pymc3/data.py

Lines changed: 30 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
import io
33
import os
44
import pkgutil
5-
5+
import collections
66
import numpy as np
77
import pymc3 as pm
88
import theano.tensor as tt
@@ -11,7 +11,8 @@
1111
__all__ = [
1212
'get_data',
1313
'GeneratorAdapter',
14-
'Minibatch'
14+
'Minibatch',
15+
'align_minibatches'
1516
]
1617

1718

@@ -221,6 +222,9 @@ class Minibatch(tt.TensorVariable):
221222
>>> x = Minibatch(moredata, [2, None, 20, Ellipsis, 10])
222223
>>> assert x.eval().shape == (2, 20, 20, 40, 10)
223224
"""
225+
226+
RNG = collections.defaultdict(list)
227+
224228
@theano.configparser.change_flags(compute_test_value='raise')
225229
def __init__(self, data, batch_size=128, dtype=None, broadcastable=None, name='Minibatch',
226230
random_seed=42, update_shared_f=None, in_memory_size=None):
@@ -244,17 +248,21 @@ def __init__(self, data, batch_size=128, dtype=None, broadcastable=None, name='M
244248
inputs=[self.minibatch], outputs=[self])
245249
self.tag.test_value = copy(self.minibatch.tag.test_value)
246250

247-
@staticmethod
248-
def rslice(total, size, seed):
251+
def rslice(self, total, size, seed):
249252
if size is None:
250253
return slice(None)
251254
elif isinstance(size, int):
252-
return (pm.tt_rng(seed)
255+
rng = pm.tt_rng(seed)
256+
Minibatch.RNG[id(self)].append(rng)
257+
return (rng
253258
.uniform(size=(size, ), low=0.0, high=pm.floatX(total) - 1e-16)
254259
.astype('int64'))
255260
else:
256261
raise TypeError('Unrecognized size type, %r' % size)
257262

263+
def __del__(self):
264+
del Minibatch.RNG[id(self)]
265+
258266
@staticmethod
259267
def make_static_slices(user_size):
260268
if user_size is None:
@@ -278,12 +286,11 @@ def make_static_slices(user_size):
278286
else:
279287
raise TypeError('Unrecognized size type, %r' % user_size)
280288

281-
@classmethod
282-
def make_random_slices(cls, in_memory_shape, batch_size, default_random_seed):
289+
def make_random_slices(self, in_memory_shape, batch_size, default_random_seed):
283290
if batch_size is None:
284291
return [Ellipsis]
285292
elif isinstance(batch_size, int):
286-
slc = [cls.rslice(in_memory_shape[0], batch_size, default_random_seed)]
293+
slc = [self.rslice(in_memory_shape[0], batch_size, default_random_seed)]
287294
elif isinstance(batch_size, (list, tuple)):
288295
def check(t):
289296
if t is Ellipsis or t is None:
@@ -334,10 +341,10 @@ def check(t):
334341
else:
335342
shp_end = np.asarray([])
336343
shp_begin = shape[:len(begin)]
337-
slc_begin = [cls.rslice(shp_begin[i], t[0], t[1])
344+
slc_begin = [self.rslice(shp_begin[i], t[0], t[1])
338345
if t is not None else tt.arange(shp_begin[i])
339346
for i, t in enumerate(begin)]
340-
slc_end = [cls.rslice(shp_end[i], t[0], t[1])
347+
slc_end = [self.rslice(shp_end[i], t[0], t[1])
341348
if t is not None else tt.arange(shp_end[i])
342349
for i, t in enumerate(end)]
343350
slc = slc_begin + mid + slc_end
@@ -359,3 +366,16 @@ def clone(self):
359366
ret.name = self.name
360367
ret.tag = copy(self.tag)
361368
return ret
369+
370+
371+
def align_minibatches(batches=None):
372+
if batches is None:
373+
for rngs in Minibatch.RNG.values():
374+
for rng in rngs:
375+
rng.seed()
376+
else:
377+
for b in batches:
378+
if not isinstance(b, Minibatch):
379+
raise TypeError('{b} is not a Minibatch')
380+
for rng in Minibatch.RNG[id(b)]:
381+
rng.seed()

pymc3/tests/test_minibatches.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -313,3 +313,21 @@ def test_cloning_available(self):
313313
res1 = theano.clone(res, {gop: shared})
314314
f = theano.function([], res1)
315315
assert f() == np.array([100])
316+
317+
def test_align(self):
318+
m = pm.Minibatch(np.arange(1000), 1, random_seed=1)
319+
n = pm.Minibatch(np.arange(1000), 1, random_seed=1)
320+
f = theano.function([], [m, n])
321+
n.eval() # not aligned
322+
a, b = zip(*(f() for _ in range(1000)))
323+
assert a != b
324+
pm.align_minibatches()
325+
a, b = zip(*(f() for _ in range(1000)))
326+
assert a == b
327+
n.eval() # not aligned
328+
pm.align_minibatches([m])
329+
a, b = zip(*(f() for _ in range(1000)))
330+
assert a != b
331+
pm.align_minibatches([m, n])
332+
a, b = zip(*(f() for _ in range(1000)))
333+
assert a == b

0 commit comments

Comments
 (0)