Skip to content

Commit c65147c

Browse files
Removes chains argument
1 parent e15c065 commit c65147c

File tree

2 files changed

+9
-9
lines changed

2 files changed

+9
-9
lines changed

pymc3/quadratic_approximation.py

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -23,27 +23,25 @@
2323
from pymc3.tuning import find_hessian, find_MAP
2424

2525

26-
def quadratic_approximation(vars, n_chains=2, n_samples=10_000):
26+
def quadratic_approximation(vars, n_samples=10_000):
2727
"""Finds the quadratic approximation to the posterior, also known as the Laplace approximation.
2828
2929
NOTE: The quadratic approximation only works well for unimodal and roughly symmetrical posteriors of continuous variables.
30-
The usual MCMC convergence and mixing statistics (e.g. R-hat, ESS) will NOT tell you anything about how well this approximation fits your actual (unknown) posterior, indeed they'll always be extremely nice since all "chains" are sampling from exactly the same distribution, the posterior quadratic approximation.
30+
The usual MCMC convergence and mixing statistics (e.g. R-hat, ESS) will NOT tell you anything about how well this approximation fits your actual (unknown) posterior, indeed they'll always be extremely nice since all samples are from exactly the same distribution, the posterior quadratic approximation.
3131
Use at your own risk.
3232
3333
See Chapter 4 of "Bayesian Data Analysis" 3rd edition for background.
3434
35-
Returns an arviz.InferenceData object for compatibility by sampling from the approximated quadratic posterior. Note these are NOT MCMC samples, so the notion of chains is meaningless, and is only included for downstream compatibility with Arviz.
35+
Returns an arviz.InferenceData object for compatibility by sampling from the approximated quadratic posterior. Note these are NOT MCMC samples.
3636
3737
Also returns the exact posterior approximation as a scipy.stats.multivariate_normal distribution.
3838
3939
Parameters
4040
----------
4141
vars: list
4242
List of variables to approximate the posterior for.
43-
n_chains: int
44-
How many chains to simulate.
4543
n_samples: int
46-
How many samples to sample from the approximate posterior for each chain.
44+
How many samples to sample from the approximate posterior.
4745
4846
Returns
4947
-------
@@ -57,7 +55,7 @@ def quadratic_approximation(vars, n_chains=2, n_samples=10_000):
5755
cov = np.linalg.inv(H)
5856
mean = np.concatenate([np.atleast_1d(map[v.name]) for v in vars])
5957
posterior = scipy.stats.multivariate_normal(mean=mean, cov=cov)
60-
draws = posterior.rvs((n_chains, n_samples))
58+
draws = posterior.rvs(n_samples)[np.newaxis, ...]
6159
samples = {}
6260
i = 0
6361
for v in vars:

pymc3/tests/test_quadratic_approximation.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,9 @@ class TestQuadraticApproximation(SeededTest):
1010
def setup_method(self):
1111
super().setup_method()
1212

13-
def test_recovers_analytical_quadratic_approximation_in_normal_with_unknown_mean_and_variance():
13+
def test_recovers_analytical_quadratic_approximation_in_normal_with_unknown_mean_and_variance(
14+
self,
15+
):
1416
y = np.array([2642, 3503, 4358])
1517
n = y.size
1618

@@ -27,7 +29,7 @@ def test_recovers_analytical_quadratic_approximation_in_normal_with_unknown_mean
2729
assert np.allclose(posterior.mean, bda_map)
2830
assert np.allclose(posterior.cov, bda_cov, atol=1e-4)
2931

30-
def test_hdi_contains_parameters_in_linear_regression():
32+
def test_hdi_contains_parameters_in_linear_regression(self):
3133
N = 100
3234
M = 2
3335
sigma = 0.2

0 commit comments

Comments
 (0)