Adds quadratic approximation #4847
Conversation
The quadratic approximation is an extremely fast way to obtain an estimate of the posterior. It only works if the posterior is unimodal and approximately symmetric. The implementation finds the quadratic approximation and then samples from it to return an arviz.InferenceData object. Finding the exact (approximate) posterior and then sampling from it might seem counterintuitive, but it is done to stay compatible with the rest of the codebase and with ArviZ functionality. The exact (approximate) posterior is also returned as a scipy.stats.multivariate_normal distribution.
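The idea can be sketched in a few lines of plain NumPy. This is a hypothetical, self-contained illustration of the quadratic (Laplace) approximation for a one-dimensional posterior, not the PyMC3 implementation: locate the posterior mode (the MAP), measure the curvature of the negative log posterior there, and treat the posterior as a Gaussian with variance equal to the inverse curvature. All names here (`neg_log_post`, `quadratic_approx`) are made up for the sketch.

```python
import numpy as np

def neg_log_post(theta, data, prior_sd=10.0):
    # Gaussian likelihood (known sd = 1) with a wide Gaussian prior on theta.
    return 0.5 * np.sum((data - theta) ** 2) + 0.5 * (theta / prior_sd) ** 2

def quadratic_approx(data):
    # Crude MAP estimate via grid search (a real implementation optimizes).
    grid = np.linspace(-5.0, 5.0, 10001)
    nlp = np.array([neg_log_post(t, data) for t in grid])
    mode = grid[np.argmin(nlp)]
    # Central finite difference for the curvature at the mode.
    h = 1e-4
    curv = (
        neg_log_post(mode + h, data)
        - 2.0 * neg_log_post(mode, data)
        + neg_log_post(mode - h, data)
    ) / h**2
    # Mean and variance of the Gaussian approximation.
    return mode, 1.0 / curv

rng = np.random.default_rng(0)
data = rng.normal(1.0, 1.0, size=100)
mean, var = quadratic_approx(data)
samples = rng.normal(mean, np.sqrt(var), size=1000)  # draws from the approximation
```

In the PR itself, the mode comes from `find_MAP`, the curvature from the Hessian, and the draws are packaged into an `arviz.InferenceData`.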
pymc3/quadratic_approximation.py
Outdated
    ----------
    vars: list
        List of variables to approximate the posterior for.
    n_chains: int
I don't think we need this kwarg and should fix it to 1.
Unfortunately, Arviz will complain if there aren't at least 2 chains.
Can you share an example? I think you'll run into trouble with multivariate distributions, where ArviZ will interpret the leading dimension as the chains, but I would solve that as follows:
data = {'a': np.ones(10), 'b': np.ones((100, 3))}
az.convert_to_inference_data(data) # warning, misinterprets `b` as having 100 chains
az.convert_to_inference_data({k: v[np.newaxis, ...] for k, v in data.items()}) # Good, explicitly sets number of chains to 1.
But maybe I'm misunderstanding the issue!
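The chain-axis convention being discussed can be shown with NumPy alone. This is a minimal sketch, assuming only the `(chain, draw, *shape)` layout that ArviZ-style converters expect: a raw `(100, 3)` array looks like 100 chains of a length-3 vector, so prepending an explicit length-1 axis pins the data to a single chain.

```python
import numpy as np

# Same example dict as in the comment above.
data = {"a": np.ones(10), "b": np.ones((100, 3))}

# Prepend a chain axis of length 1 to every variable.
with_chain_axis = {k: v[np.newaxis, ...] for k, v in data.items()}

# First axis is now the (single) chain; the original draws follow it.
shapes = {k: v.shape for k, v in with_chain_axis.items()}
```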
I'm against including chains since it gives the impression that you might want to run diagnostics on the quality of the returned samples.
pymc3/quadratic_approximation.py
Outdated
    Returns
    -------
    (arviz.InferenceData, scipy.stats.multivariate_normal):
Also list these items individually; you can look at other docstrings for examples of multiple return elements.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I couldn't find any examples. I tried a new format. Please provide a reference if it's not correct :)
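For reference, the numpydoc convention lists each return value on its own `name : type` line with an indented description. A minimal sketch (the function body is a placeholder, and the types merely echo this PR's return values):

```python
def quadratic_approximation_sketch():
    """Placeholder illustrating a numpydoc-style Returns section.

    Returns
    -------
    idata : arviz.InferenceData
        Samples drawn from the quadratic approximation to the posterior.
    posterior : scipy.stats.multivariate_normal
        The approximate posterior distribution itself.
    """
    return None, None
```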
This looks pretty good. Needs a line in the release notes. I wonder if we should make this available through Really though this probably fits into the
Codecov Report
@@ Coverage Diff @@
## main #4847 +/- ##
==========================================
- Coverage 72.32% 66.90% -5.43%
==========================================
Files 85 86 +1
Lines 13884 13905 +21
==========================================
- Hits 10042 9303 -739
- Misses 3842 4602 +760
I suggest adding a test with multivariate variables and/or variables with more than one dimension. Not sure how well that mean concatenation will work in those cases.
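The concatenation concern can be sketched without PyMC3 at all. Assuming a hypothetical MAP dict with mixed shapes (all names here are illustrative), a robust flattening has to ravel each value into one mean vector and keep the shapes around so draws can be unpacked per variable again:

```python
import numpy as np

# Hypothetical MAP estimate with a scalar, a vector, and a matrix variable.
map_estimate = {
    "mu": np.array(0.5),
    "beta": np.array([1.0, -1.0, 2.0]),
    "L": np.ones((2, 2)),
}

# Flatten everything into one mean vector for the multivariate normal ...
flat_mean = np.concatenate(
    [np.atleast_1d(v).ravel() for v in map_estimate.values()]
)

# ... and remember the shapes so samples can be reshaped per variable.
shapes = {k: np.shape(v) for k, v in map_estimate.items()}
```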
I don't think that would be the best. But we could perhaps add a general I am planning to add a "grid approximation" function that works a lot like this quadratic approximation in spirit (deterministic evaluation plus sampling to return an InferenceData), and would place it there as well.
        InferenceData with samples from the approximate posterior, multivariate normal posterior approximation
    """
    map = find_MAP(vars=vars)
Are there any arguments to find_MAP that we might want to allow the user to pass? I.e. map = find_MAP(vars=vars, **map_kwargs)? Same for the Hessian. If not, that's fine (even better).
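The pass-through being proposed is just keyword-argument forwarding. A sketch with stand-ins (this `find_MAP` is a dummy, not the real PyMC3 function, and the optimizer names are illustrative):

```python
# Dummy stand-in for pymc3.find_MAP, only to show the forwarding pattern.
def find_MAP(vars=None, method="BFGS"):
    return {"vars": vars, "method": method}

def quadratic_approximation(vars=None, **map_kwargs):
    # Unknown keyword arguments are forwarded untouched to the optimizer.
    return find_MAP(vars=vars, **map_kwargs)

result = quadratic_approximation(vars=["x"], method="L-BFGS-B")
```

The trade-off raised below is that such pass-throughs leak the inner function's interface into the outer one.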
Hmm. I think I'd say no. It's a leaky abstraction. If someone raises an issue where they need to pass args to find_MAP, we can think about how best to do it at that time. That's my 2 cents.
That's why I made
Co-authored-by: Thomas Wiecki <thomas.wiecki@gmail.com>
I've addressed the PR comments as best I could, ran linting, and added a line to the release notes. Please take another look @twiecki. I will add an example notebook once merged.
Agreed it doesn't fit well in sample. A humble suggestion might be to make an inference module and place the various types of posterior inference there, e.g. MCMC, VI, this, grid approximation, etc.
I like the
(sorry, submitted a review late, but excited to see this!)
@ColCarroll I removed the chains. PTAL :)
I've spent the entire day trying to reproduce the test errors and have failed. I simply cannot create a working development environment. First I spent countless hours fixing C import
Then I hit some bug deep in the Fortran code of LBFGS, which is not the error in CI. At this point I've given up. The code works, as you can see here https://colab.research.google.com/drive/1DTe7QchyW-wpbUmulzY27lBTRFQXzhm0?usp=sharing (at least in PyMC3 3.11.2)
Closing due to inactivity, feel free to reopen.