Description
Minibatch creates a stochastic view of the underlying dataset using some pretty funky Aesara magic.
In a straightforward implementation it would do something like:
import pymc as pm
import numpy as np

data = pm.Normal.dist(size=(100, 2)).eval()

with pm.Model() as m:
    data = pm.Data("data", data)
    mb_slice_start, mb_slice_end = pm.Uniform.dist(0, data.shape[0], size=(2,)).sort()
    minibatch_data = data[mb_slice_start: mb_slice_end]
    x = pm.Normal("x", 0, 1, observed=minibatch_data)
And then PyMC also figures out that the logp of x has to be scaled by the size of the unobserved values in each iteration.
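The scaling can be illustrated outside of PyMC with plain NumPy. This is only a minimal sketch, not PyMC's implementation: the helpers normal_logp and scaled_minibatch_logp are hypothetical names, and the factor (total rows / batch rows) is the usual minibatch correction, assumed here rather than taken from the PyMC source.

```python
import numpy as np


def normal_logp(values, mu=0.0, sigma=1.0):
    """Summed log-density of independent Normal(mu, sigma) observations."""
    z = (values - mu) / sigma
    return float(np.sum(-0.5 * z**2 - np.log(sigma) - 0.5 * np.log(2 * np.pi)))


def scaled_minibatch_logp(data, start, end):
    """Rescale the logp of data[start:end] to the size of the full dataset."""
    batch = data[start:end]
    scale = data.shape[0] / batch.shape[0]  # assumed factor: total rows / batch rows
    return scale * normal_logp(batch)


rng = np.random.default_rng(0)
data = rng.normal(size=(100, 2))

full_logp = normal_logp(data)
estimate = scaled_minibatch_logp(data, 20, 60)  # noisy estimate of full_logp
exact = scaled_minibatch_logp(data, 0, 100)     # whole dataset, so scale == 1
```

Averaging the scaled logp over a set of equally sized batches that together cover the data recovers the full-data logp, which is why the rescaling keeps the stochastic estimate on the right scale.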
One pressing issue is that it relies on the deprecated MRG sampler #4523, which also has no support for JAX or Numba. It isn't as simple a matter as switching to the default RandomStream, which returns RandomVariables, because we treat the shared RNG variables used by RandomVariables in a special manner inside aesaraf.compile_pymc (the function that all PyMC functions go through before calling aesara.function). Specifically, we always set the values of the distinct shared RNG variables to something new, based on what random_seed is passed to compile_pymc.
This is a problem when you need to synchronize multiple mini-batches of data, say when you have a linear regression model with a minibatch of x and y. The way it currently works is that you set the same random_seed for each (the default is a hard-coded 42), and then rely on this initial value being the same so that the endpoints of the slice stay synchronized. But if we were to use the standard RandomStream and pass the graph to compile_pymc, the two RNGs associated with the two slices would be overwritten with different values, regardless of the fact that they started with the same value.
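The fragility of the same-seed trick can be reproduced with plain NumPy generators. This is a hedged sketch, not PyMC or Aesara code: slice_endpoints is a hypothetical helper, integer endpoints are drawn instead of a continuous Uniform so they are valid indices, and re-creating the generators with different seeds is only a stand-in for what compile_pymc does to shared RNG variables.

```python
import numpy as np


def slice_endpoints(rng, n_rows):
    """Draw two sorted integer endpoints in [0, n_rows]."""
    start, end = np.sort(rng.integers(0, n_rows + 1, size=2))
    return int(start), int(end)


n_rows = 100

# Two separate generators that merely start from the same seed stay in step...
rng_x, rng_y = np.random.default_rng(42), np.random.default_rng(42)
same_seed = [
    slice_endpoints(rng_x, n_rows) == slice_endpoints(rng_y, n_rows)
    for _ in range(5)
]

# ...but only until something re-seeds them independently (a stand-in for
# compile_pymc overwriting every shared RNG variable it finds).
rng_x, rng_y = np.random.default_rng(1), np.random.default_rng(2)
draws_x = [slice_endpoints(rng_x, n_rows) for _ in range(5)]
draws_y = [slice_endpoints(rng_y, n_rows) for _ in range(5)]
```

Nothing ties rng_x to rng_y except their starting state, so once that state is replaced the slices of x and y no longer refer to the same rows.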
Anyway, this trick might also not have been enough, since there was a need to introduce align_minibatches in #2760:
Lines 463 to 474 in faebc60
The big issue here is that we are not representing that the slice endpoints are the same symbolically. The current graph looks something like:
with pm.Model() as m:
    x = pm.Data("x", x)
    y = pm.Data("y", y)

    rng = aesara.shared(np.random.default_rng(42))
    mb_slice_start, mb_slice_end = pm.Uniform.dist(0, x.shape[0], size=(2,), rng=rng).sort()
    minibatch_x = x[mb_slice_start: mb_slice_end]

    # A second, unrelated RNG that merely starts from the same seed
    rng = aesara.shared(np.random.default_rng(42))
    mb_slice_start, mb_slice_end = pm.Uniform.dist(0, y.shape[0], size=(2,), rng=rng).sort()
    minibatch_y = y[mb_slice_start: mb_slice_end]

    obs = pm.Normal("obs", minibatch_x, 1, observed=minibatch_y)
The two rngs were set to the same initial value, but other than that they are completely different as far as Aesara cares. The correct graph would be:
with pm.Model() as m:
    x = pm.Data("x", x)
    y = pm.Data("y", y)

    rng = aesara.shared(np.random.default_rng(42))
    mb_slice_start, mb_slice_end = pm.Uniform.dist(0, x.shape[0], size=(2,), rng=rng).sort()
    # Both minibatches index with the same symbolic endpoints
    minibatch_x = x[mb_slice_start: mb_slice_end]
    minibatch_y = y[mb_slice_start: mb_slice_end]

    obs = pm.Normal("obs", minibatch_x, 1, observed=minibatch_y)
That is, using the same slicing RV for both mini-batches. Regardless of the seed value that is set by compile_pymc, the two minibatches will always be compatible.
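The same idea in plain NumPy, as a rough sketch rather than anything PyMC provides: shared_minibatch is a hypothetical helper that makes a single draw of endpoints and applies it to every array, so the result does not depend on how the generator was seeded. Integer endpoints are an assumption made here so that the slice indices are valid.

```python
import numpy as np


def shared_minibatch(rng, *arrays):
    """Slice every array with one common pair of random endpoints."""
    n_rows = arrays[0].shape[0]
    if any(a.shape[0] != n_rows for a in arrays):
        raise ValueError("all arrays must have the same leading dimension")
    start, end = np.sort(rng.integers(0, n_rows + 1, size=2))
    return tuple(a[start:end] for a in arrays)


x = np.arange(100.0)
y = 2.0 * x + 1.0  # y is tied to x row by row

for seed in (1, 2, 3):
    rng = np.random.default_rng(seed)
    mb_x, mb_y = shared_minibatch(rng, x, y)
    # The rows line up whatever the seed, because only one draw was made.
    assert np.array_equal(mb_y, 2.0 * mb_x + 1.0)
```

Because the alignment comes from reusing one draw rather than from matching seeds, re-seeding the generator changes which rows are selected but never breaks the pairing between x and y.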
The incremental proposal
Refactor Minibatch so that it can accept multiple variables at once, which will share the same random slices. Then start using RandomVariables instead of the MRG stuff, and get rid of the seed at the minibatch level:
with pm.Model() as m:
    x, y = pm.Minibatch({x: (10, ...), y: (10)})
Maybe even sprinkle some dims if you want
The radical proposal
Offer the same API, but implement the minibatch view with the more straightforward code from the pseudo-examples in this issue.
No need to keep a global list of all the RNGs:
Line 304 in faebc60
The whole Op view magic shouldn't be needed AFAICT. The only thing is that right now PyMC complains if you pass something to observed other than a shared variable or constant data. The concern here is that the observations should not depend on other model RVs, but that's easy to check with aesara.graph.ancestors()
Line 339 in faebc60
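The dependency check described above amounts to a graph traversal. The sketch below does not use Aesara at all: it models a graph as a plain dict mapping each hypothetical node name to its inputs, purely to show the kind of test that walking the ancestors of the observed variable would enable. The node names and the local ancestors helper are illustrative assumptions.

```python
def ancestors(graph, node):
    """Return every node reachable from `node` by following its inputs."""
    seen, stack = set(), list(graph.get(node, ()))
    while stack:
        current = stack.pop()
        if current not in seen:
            seen.add(current)
            stack.extend(graph.get(current, ()))
    return seen


# Hypothetical toy graph: node name -> names of its inputs.
graph = {
    "minibatch_data": ["data", "mb_slice"],
    "mb_slice": ["slice_rng"],
    "bad_observed": ["data", "mu"],
}
model_rvs = {"mu", "sigma"}  # free RVs registered in the (toy) model

# Observed data is acceptable only if none of its ancestors is a model RV.
ok = not (ancestors(graph, "minibatch_data") & model_rvs)
bad = not (ancestors(graph, "bad_observed") & model_rvs)
```

Here the minibatch view passes because it depends only on the data and the slicing RNG, while an observed variable built from mu would be rejected.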
The other issue is that we don't allow orphan RVs in the logp graph, but we can easily create a subtype, like MinibatchUniformRV, that is allowed, just like SimulatorRVs are allowed.
This should make the Minibatch code much more readable and maintainable, as well as compatible with the non-C backends (I have no idea if either supports View).