Use RandomVariables for Minibatch #6277

Closed
@ricardoV94

Description

Minibatch creates a stochastic view of the underlying dataset using some pretty funky Aesara magic.

In a straightforward implementation it would do something like:

import pymc as pm
import numpy as np

data = pm.Normal.dist(size=(100, 2)).eval()

with pm.Model() as m:
  data = pm.Data("data", data)

  # Draw two random slice endpoints (in practice these would have to be cast to integers)
  mb_slice_start, mb_slice_end = pm.Uniform.dist(0, data.shape[0], size=(2,)).sort()
  minibatch_data = data[mb_slice_start: mb_slice_end]

  x = pm.Normal("x", 0, 1, observed=minibatch_data)

And then PyMC also figures out that the logp of x has to be scaled to account for the values left out of the minibatch in each iteration.
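The scaling itself is conceptually just the ratio of the full dataset size to the minibatch size; a minimal numeric sketch (the names here are only illustrative):

# With N total rows and n rows in the current minibatch, the observed logp
# is multiplied by N / n so that it is an unbiased estimate of the full-data logp.
N = 100   # rows in the full dataset
n = 10    # rows in the current minibatch
logp_minibatch = -12.3          # stand-in for the summed logp of the minibatch
logp_scaled = (N / n) * logp_minibatch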

One pressing issue is that it relies on the deprecated MRG sampler #4523, which also has no support for JAX or Numba. It isn't simply a matter of switching to the default RandomStream, which returns RandomVariables, because we treat the shared RNG variables used by RandomVariables in a special manner inside aesaraf.compile_pymc (the function through which all PyMC functions go before calling aesara.function). Specifically, we always set the values of the distinct shared RNG variables to something new, based on the random_seed passed to compile_pymc.

This is a problem when you need to synchronize multiple mini-batches of data, say when you have a linear regression model with a minibatch of x and y. The way it currently works is that you set the same random_seed for each (the default is hard-coded to 42), and then rely on this initial value being the same so that the endpoints of the slice stay synchronized. But if we were to use the standard RandomStream and pass the graph to compile_pymc, the two RNGs associated with the two slices would be overwritten with different values, even though they started out with the same value.
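A numpy-only sketch of that failure mode (the reseeding values are arbitrary, just standing in for whatever compile_pymc would assign to each distinct shared RNG):

import numpy as np

# Two generators start with the same seed, so their draws line up ...
rng_x = np.random.default_rng(42)
rng_y = np.random.default_rng(42)
assert np.allclose(rng_x.uniform(0, 100, 2), rng_y.uniform(0, 100, 2))

# ... but if a compilation step reseeds each one independently
# (as compile_pymc does with distinct shared RNG variables),
# the slice endpoints drawn from them no longer match.
rng_x = np.random.default_rng(1)
rng_y = np.random.default_rng(2)
assert not np.allclose(rng_x.uniform(0, 100, 2), rng_y.uniform(0, 100, 2))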

Anyway, this trick might also not have been enough, since there was a need to introduce align_minibatches in #2760:

pymc/pymc/data.py

Lines 463 to 474 in faebc60

def align_minibatches(batches=None):
    if batches is None:
        for rngs in Minibatch.RNG.values():
            for rng in rngs:
                rng.seed()
    else:
        for b in batches:
            if not isinstance(b, Minibatch):
                raise TypeError(f"{b} is not a Minibatch")
            for rng in Minibatch.RNG[id(b)]:
                rng.seed()

https://github.com/pymc-devs/pymc/blob/2296350959e4035b4f1ee13fab88014d4f0fa545/pymc/tests/test_data.py#L754-L773
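A rough usage sketch of that workaround (assuming the pre-refactor Minibatch(data, batch_size=...) signature and that align_minibatches is importable from pymc.data, the file shown above):

import numpy as np
import pymc as pm
from pymc.data import align_minibatches

data = np.random.rand(100, 2)

# Two independent views over the same rows; each one carries its own RNGs.
mb_x = pm.Minibatch(data[:, 0], batch_size=10)
mb_y = pm.Minibatch(data[:, 1], batch_size=10)

# Manually reseed their RNGs so that the next draws pick the same rows ...
align_minibatches([mb_x, mb_y])
# ... or reseed every Minibatch RNG registered in the global dict.
align_minibatches()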

The big issue here is that we are not representing symbolically that the slice endpoints are the same. The current graph looks something like:

with pm.Model() as m:
  x = pm.Data("x", x)
  y = pm.Data("y", y)

  rng = aesara.shared(np.random.default_rng(42))
  mb_slice_start, mb_slice_end = pm.Uniform.dist(0, x.shape[0], size=(2,), rng=rng).sort()
  minibatch_x = x[mb_slice_start: mb_slice_end]

  rng = aesara.shared(np.random.default_rng(42))
  mb_slice_start, mb_slice_end = pm.Uniform.dist(0, y.shape[0], size=(2,), rng=rng).sort()
  minibatch_y = y[mb_slice_start: mb_slice_end]

  obs = pm.Normal("obs", minibatch_x, 1, observed=minibatch_y)

The two rngs were set to the same initial value, but other than that they are completely different as far as Aesara is concerned. The correct graph would be:

with pm.Model() as m:
  x = pm.Data("x", x)
  y = pm.Data("y", y)

  rng = aesara.shared(np.random.default_rng(42))
  mb_slice_start, mb_slice_end = pm.Uniform.dist(0, x.shape[0], size=(2,), rng=rng).sort()
  minibatch_x = x[mb_slice_start: mb_slice_end]
  minibatch_y = y[mb_slice_start: mb_slice_end]

  obs = pm.Normal("obs", minibatch_x, 1, observed=minibatch_y)

That is, use the same slicing RV for both minibatches. Regardless of the seed value that compile_pymc sets, the two minibatches will always stay compatible.
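A runnable sketch of why this works, in plain Aesara (the shared data, the int cast, and the assert are just stand-ins for what PyMC would actually build):

import aesara
import aesara.tensor as at
import numpy as np

x = aesara.shared(np.c_[np.arange(100.0), np.arange(100.0)], name="x")
y = aesara.shared(np.arange(100.0), name="y")

rng = aesara.shared(np.random.default_rng(42))
# One slicing RV feeds both views, so they are aligned by construction.
ends = at.random.uniform(0, x.shape[0], size=(2,), rng=rng).sort().astype("int64")
minibatch_x = x[ends[0]: ends[1]]
minibatch_y = y[ends[0]: ends[1]]

f = aesara.function([], [minibatch_x, minibatch_y])
xs, ys = f()
# Both views cover the same rows, whatever the RNG state happens to be.
assert np.array_equal(xs[:, 0], ys)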

The incremental proposal

Refactor Minibatch so that it can accept multiple variables at once, which will share the same random slices. Then start using RandomVariables instead of the MRG machinery, and get rid of seeding at the Minibatch level:

with pm.Model() as m:
  x, y = pm.Minibatch({x: (10, ...), y: (10)})

Maybe even sprinkle some dims if you want
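A purely hypothetical spelling, just to illustrate what dims could look like on top of the proposed API; none of this exists, and the keyword names are made up:

with pm.Model(coords={"obs_id": range(10), "feature": ["a", "b"]}) as m:
    mb_x, mb_y = pm.Minibatch(
        {x: (10, ...), y: (10,)},
        dims={x: ("obs_id", "feature"), y: ("obs_id",)},
    )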

The radical proposal

Offer the same API, but implement the minibatch view with the more straightforward code from the pseudo-examples in this issue.

No need to keep a global list of all the RNGs:

RNG = collections.defaultdict(list) # type: Dict[str, List[Any]]

The whole view-Op magic (the Apply shown below) shouldn't be needed AFAICT. The only thing is that right now PyMC complains if you pass anything other than a shared variable or constant data to observed. The concern there is that the observations should not depend on other model RVs, but that's easy to check with aesara.graph.ancestors():

Apply(aesara.compile.view_op, inputs=[self.minibatch], outputs=[self])
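A sketch of what that check could look like (the helper name and signature are made up for illustration):

from aesara.graph.basic import ancestors


def assert_observed_is_independent(observed, model_rvs):
    """Reject observed data whose graph depends on other model RVs."""
    observed_ancestors = set(ancestors([observed]))
    offending = [rv for rv in model_rvs if rv in observed_ancestors]
    if offending:
        raise ValueError(f"Observed data depends on model variables: {offending}")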

The other issue is that we don't allow orphan RVs in the logp graph, but we can easily create a subtype like MinibatchUniformRV that is allowed, just like SimulatorRVs are.
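A minimal sketch of such a subtype, assuming we simply subclass Aesara's UniformRV so that PyMC can special-case it by type:

from aesara.tensor.random.basic import UniformRV


class MinibatchUniformRV(UniformRV):
    """Same Op as UniformRV, but its type lets PyMC whitelist it as an
    allowed orphan RV in observed graphs, like SimulatorRV."""


minibatch_uniform = MinibatchUniformRV()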

This should make the Minibatch code much more readable and maintainable, as well as compatible with the non-C backends (I have no idea whether either of them supports the View Op).
