Stochastic gradients in pytensor #1424
jessegrabowski asked this question in Q&A
Replies: 2 comments
-
CC @aseyboldt
-
Also @zaxtax had some thoughts on this during the VI hack
-
Description
There are many problems in machine learning that require differentiating through random variables. This specifically came up in the context of pymc-devs/pymc#7799, but it was also implicated in #555.
Right now, `pt.grad` refuses to go through a random variable. That's probably the correct behavior. But we could have another function, `pt.stochastic_grad`, that would do it and return a stochastic gradient. It would also take an `n_samples` argument, since what we're actually computing is a sample-based Monte Carlo estimate of the gradients with respect to parameters.

Most often, the RV would have to know a "reparameterization trick" to split its parameters from the source of randomness. The canonical example is the non-centered normal parameterization. Given a loss function $\mathcal L$ that depends on $x \sim N(\mu, \sigma)$, the proposed `pt.stochastic_grad` would compute the gradient of the expected loss given the RVs: $\nabla_\theta \mathbb{E}_x [\mathcal{L}(g(x, \theta))] = \mathbb{E}_x [\nabla_\theta \mathcal{L}(g(x, \theta))]$. The so-called reparameterization trick just does a non-centered parameterization, $x = \mu + \sigma z, \quad z \sim N(0,1)$, so that now the gradient contribution of $g(x, \theta)$ can be estimated from $n$ = `n_samples` draws:

$$\nabla_\theta \mathbb{E}_z [\mathcal{L}(g(\mu + \sigma z, \theta))] \approx \frac{1}{n} \sum_{i=1}^{n} \nabla_\theta \mathcal{L}(g(\mu + \sigma z_i, \theta)), \qquad z_i \sim N(0, 1)$$

And the (expected) sensitivity equations for the parameters of the normal are:

$$\frac{\partial}{\partial \mu} \mathbb{E}_z[\mathcal{L}] = \mathbb{E}_z\left[\frac{\partial \mathcal{L}}{\partial x}\right], \qquad \frac{\partial}{\partial \sigma} \mathbb{E}_z[\mathcal{L}] = \mathbb{E}_z\left[\frac{\partial \mathcal{L}}{\partial x}\, z\right]$$
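As a rough illustration of the normal case (plain Python, not PyTensor code and not the proposed `pt.stochastic_grad`), here is a minimal sketch of the reparameterized Monte Carlo estimate. With $x = \mu + \sigma z$, the chain rule gives $\partial x/\partial\mu = 1$ and $\partial x/\partial\sigma = z$, so averaging $\partial\mathcal L/\partial x$ and $(\partial\mathcal L/\partial x)\,z$ over `n_samples` draws of $z$ estimates the two gradients. The helper name `reparam_grad_normal` and the hand-supplied derivative `dloss_dx` (a stand-in for autodiff) are hypothetical.

```python
import random


def reparam_grad_normal(dloss_dx, mu, sigma, n_samples, rng=None):
    """Monte Carlo estimate of (d/dmu, d/dsigma) of E[L(x)], x ~ N(mu, sigma).

    Hypothetical helper. Uses the non-centered form x = mu + sigma * z with
    z ~ N(0, 1), so dx/dmu = 1 and dx/dsigma = z. `dloss_dx` is dL/dx,
    supplied by hand here in place of automatic differentiation.
    """
    if rng is None:
        rng = random.Random()
    grad_mu = 0.0
    grad_sigma = 0.0
    for _ in range(n_samples):
        z = rng.gauss(0.0, 1.0)   # parameter-free source of randomness
        x = mu + sigma * z        # deterministic in (mu, sigma) given z
        g = dloss_dx(x)
        grad_mu += g              # chain rule with dx/dmu = 1
        grad_sigma += g * z       # chain rule with dx/dsigma = z
    return grad_mu / n_samples, grad_sigma / n_samples


# Toy loss L(x) = x**2, so E[L] = mu**2 + sigma**2 and the exact
# gradients are (2 * mu, 2 * sigma) = (3.0, 1.0) for these values.
mu, sigma = 1.5, 0.5
est_mu, est_sigma = reparam_grad_normal(
    lambda x: 2 * x, mu, sigma, n_samples=200_000, rng=random.Random(0)
)
print(est_mu, est_sigma)  # both close to the exact values, up to MC noise
```

The toy loss is chosen only because its expected gradients have a closed form to compare against.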
It would be easy enough for a `normal_rv` to know this, and to supply these formulas when requested to by the hypothetical `pt.stochastic_grad`.

I guess other RVs also have reparameterizations (beta, dirichlet, gamma, ...?), but in some cases there are multiple options, and it's not clear which one is best to use in which cases. Some thought would have to be given to how to handle that.
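How an RV would supply its formulas, and what should happen when several reparameterizations exist, could be handled in many ways. One possible shape, sketched here purely as a hypothetical (the registry, `register`, and `get_reparam` are invented names, not PyTensor API), is a per-RV table of named reparameterizations that refuses to pick silently when more than one is registered. The normal gets two entries, the non-centered form and a Box-Muller transform of two uniforms, only to illustrate the "multiple options" case.

```python
import math
import random
from typing import Callable, Dict, Optional

# Hypothetical registry, not an existing PyTensor API: for each RV name, a set
# of named reparameterizations. Each maps (parameters, rng) to a draw that is a
# deterministic, differentiable function of the parameters given the noise.
Reparam = Callable[[float, float, random.Random], float]
REPARAMETERIZATIONS: Dict[str, Dict[str, Reparam]] = {}


def register(rv_name: str, method_name: str):
    """Decorator that records a reparameterization under (rv_name, method_name)."""
    def decorator(fn: Reparam) -> Reparam:
        REPARAMETERIZATIONS.setdefault(rv_name, {})[method_name] = fn
        return fn
    return decorator


@register("normal", "non_centered")
def normal_non_centered(mu: float, sigma: float, rng: random.Random) -> float:
    # x = mu + sigma * z with z ~ N(0, 1)
    return mu + sigma * rng.gauss(0.0, 1.0)


@register("normal", "box_muller")
def normal_box_muller(mu: float, sigma: float, rng: random.Random) -> float:
    # Same distribution, different noise: two uniforms pushed through Box-Muller.
    u1 = 1.0 - rng.random()  # in (0, 1], so log(u1) is finite
    u2 = rng.random()
    z = math.sqrt(-2.0 * math.log(u1)) * math.cos(2.0 * math.pi * u2)
    return mu + sigma * z


def get_reparam(rv_name: str, method: Optional[str] = None) -> Reparam:
    """Look up a reparameterization; refuse to guess when several exist."""
    options = REPARAMETERIZATIONS.get(rv_name)
    if not options:
        # A caller could catch this and fall back to a higher-variance estimator.
        raise NotImplementedError(f"no reparameterization registered for {rv_name!r}")
    if method is None:
        if len(options) > 1:
            raise ValueError(
                f"{rv_name!r} has several reparameterizations {sorted(options)}; pick one"
            )
        return next(iter(options.values()))
    return options[method]


draw = get_reparam("normal", "box_muller")(2.0, 1.0, random.Random(0))
```

Raising on ambiguity is just one choice; a default per RV, or a user-supplied option, would work equally well in this sketch.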
When a reparameterization doesn't exist, there are other, higher-variance options to compute the expected gradients (the REINFORCE gradients, for example). We could offer these as a fallback.
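To make the variance gap concrete, here is a minimal sketch comparing the REINFORCE (score-function) estimator with the reparameterization estimator for $\partial_\mu \mathbb{E}[\mathcal L(x)]$, $x \sim N(\mu, \sigma)$. REINFORCE only needs loss values and $\partial_\mu \log p(x) = (x - \mu)/\sigma^2$, which is why it works when no reparameterization exists. The toy loss $\mathcal L(x) = x^2$, the parameter values, and the function names are illustrative assumptions.

```python
import random
import statistics


def score_function_samples(loss, mu, sigma, n_samples, rng):
    """Per-sample REINFORCE terms for d/dmu E[loss(x)], x ~ N(mu, sigma).

    Uses loss(x) * d/dmu log p(x | mu, sigma), where the normal's score with
    respect to mu is (x - mu) / sigma**2. The loss is never differentiated.
    """
    out = []
    for _ in range(n_samples):
        x = rng.gauss(mu, sigma)
        out.append(loss(x) * (x - mu) / sigma**2)
    return out


def reparam_samples(dloss_dx, mu, sigma, n_samples, rng):
    """Per-sample reparameterization terms for the same gradient (dx/dmu = 1)."""
    return [dloss_dx(mu + sigma * rng.gauss(0.0, 1.0)) for _ in range(n_samples)]


mu, sigma, n = 1.5, 0.5, 100_000
sf = score_function_samples(lambda x: x**2, mu, sigma, n, random.Random(1))
rp = reparam_samples(lambda x: 2 * x, mu, sigma, n, random.Random(2))

# Both are unbiased for the exact gradient 2 * mu = 3.0 ...
print(statistics.fmean(sf), statistics.fmean(rp))
# ... but the per-sample score-function terms are far noisier.
print(statistics.variance(sf), statistics.variance(rp))
```

For this toy case the per-sample variance works out analytically to 55.5 for the score-function terms versus $4\sigma^2 = 1$ for the reparameterized ones, so many more samples are needed for the same accuracy.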
Basically, this issue is proposing this API, and inviting some discussion on whether we want this type of feature, and how to do it if so. The `pt.stochastic_grad` function would be novel as far as I know. Other packages require that you explicitly generate samples in your computation graph. For example, `torch` offers `normal(mu, sigma).rsample(n_draws)`, which generates samples using the reparameterization trick, so the standard `loss.backward()` works. Here the user can't "accidentally" trigger stochastic gradients (because you have to call `rsample` instead of `sample`).

I'm less familiar with how numpyro works, but I believe that something like `numpyro.sample("z", dist.Normal(mu, sigma))` automatically implies the reparameterization trick if it's available. They don't have a special idiom like `rsample` for when it will or won't be used.