diff --git a/pymc3/data.py b/pymc3/data.py index fcfe19fa7b..c01784edcc 100644 --- a/pymc3/data.py +++ b/pymc3/data.py @@ -2,7 +2,7 @@ import io import os import pkgutil - +import collections import numpy as np import pymc3 as pm import theano.tensor as tt @@ -11,7 +11,8 @@ __all__ = [ 'get_data', 'GeneratorAdapter', - 'Minibatch' + 'Minibatch', + 'align_minibatches' ] @@ -221,6 +222,9 @@ class Minibatch(tt.TensorVariable): >>> x = Minibatch(moredata, [2, None, 20, Ellipsis, 10]) >>> assert x.eval().shape == (2, 20, 20, 40, 10) """ + + RNG = collections.defaultdict(list) + @theano.configparser.change_flags(compute_test_value='raise') def __init__(self, data, batch_size=128, dtype=None, broadcastable=None, name='Minibatch', random_seed=42, update_shared_f=None, in_memory_size=None): @@ -244,17 +248,21 @@ def __init__(self, data, batch_size=128, dtype=None, broadcastable=None, name='M inputs=[self.minibatch], outputs=[self]) self.tag.test_value = copy(self.minibatch.tag.test_value) - @staticmethod - def rslice(total, size, seed): + def rslice(self, total, size, seed): if size is None: return slice(None) elif isinstance(size, int): - return (pm.tt_rng(seed) + rng = pm.tt_rng(seed) + Minibatch.RNG[id(self)].append(rng) + return (rng .uniform(size=(size, ), low=0.0, high=pm.floatX(total) - 1e-16) .astype('int64')) else: raise TypeError('Unrecognized size type, %r' % size) + def __del__(self): + del Minibatch.RNG[id(self)] + @staticmethod def make_static_slices(user_size): if user_size is None: @@ -278,12 +286,11 @@ def make_static_slices(user_size): else: raise TypeError('Unrecognized size type, %r' % user_size) - @classmethod - def make_random_slices(cls, in_memory_shape, batch_size, default_random_seed): + def make_random_slices(self, in_memory_shape, batch_size, default_random_seed): if batch_size is None: return [Ellipsis] elif isinstance(batch_size, int): - slc = [cls.rslice(in_memory_shape[0], batch_size, default_random_seed)] + slc = [self.rslice(in_memory_shape[0], batch_size, default_random_seed)] elif isinstance(batch_size, (list, tuple)): def check(t): if t is Ellipsis or t is None: @@ -334,10 +341,10 @@ def check(t): else: shp_end = np.asarray([]) shp_begin = shape[:len(begin)] - slc_begin = [cls.rslice(shp_begin[i], t[0], t[1]) + slc_begin = [self.rslice(shp_begin[i], t[0], t[1]) if t is not None else tt.arange(shp_begin[i]) for i, t in enumerate(begin)] - slc_end = [cls.rslice(shp_end[i], t[0], t[1]) + slc_end = [self.rslice(shp_end[i], t[0], t[1]) if t is not None else tt.arange(shp_end[i]) for i, t in enumerate(end)] slc = slc_begin + mid + slc_end @@ -359,3 +366,16 @@ def clone(self): ret.name = self.name ret.tag = copy(self.tag) return ret + + +def align_minibatches(batches=None): + if batches is None: + for rngs in Minibatch.RNG.values(): + for rng in rngs: + rng.seed() + else: + for b in batches: + if not isinstance(b, Minibatch): + raise TypeError('{b} is not a Minibatch') + for rng in Minibatch.RNG[id(b)]: + rng.seed() diff --git a/pymc3/tests/test_minibatches.py b/pymc3/tests/test_minibatches.py index 96d916cc4c..01f4d13361 100644 --- a/pymc3/tests/test_minibatches.py +++ b/pymc3/tests/test_minibatches.py @@ -313,3 +313,21 @@ def test_cloning_available(self): res1 = theano.clone(res, {gop: shared}) f = theano.function([], res1) assert f() == np.array([100]) + + def test_align(self): + m = pm.Minibatch(np.arange(1000), 1, random_seed=1) + n = pm.Minibatch(np.arange(1000), 1, random_seed=1) + f = theano.function([], [m, n]) + n.eval() # not aligned + a, b = zip(*(f() for _ in range(1000))) + assert a != b + pm.align_minibatches() + a, b = zip(*(f() for _ in range(1000))) + assert a == b + n.eval() # not aligned + pm.align_minibatches([m]) + a, b = zip(*(f() for _ in range(1000))) + assert a != b + pm.align_minibatches([m, n]) + a, b = zip(*(f() for _ in range(1000))) + assert a == b