|
7 | 7 | from pytensor.tensor.var import Variable
|
8 | 8 | from scipy.interpolate import griddata
|
9 | 9 | from scipy.signal import savgol_filter
|
10 |
| -from scipy.stats import pearsonr |
| 10 | +from scipy.stats import pearsonr, norm |
11 | 11 |
|
12 | 12 |
|
13 | 13 | def _sample_posterior(all_trees, X, rng, size=None, excluded=None):
|
@@ -92,10 +92,11 @@ def plot_convergence(idata, var_name=None, kind="ecdf", figsize=None, ax=None):
|
92 | 92 |
|
93 | 93 | for idx, (essi, rhati) in enumerate(zip(ess, rhat)):
|
94 | 94 | kind_func(essi, ax=ax[0], plot_kwargs={"color": f"C{idx}"})
|
95 |
| - ax[0].axvline(ess_threshold, color="k", ls="--") |
| 95 | + ax[0].axvline(ess_threshold, color="0.7", ls="--") |
96 | 96 | kind_func(rhati, ax=ax[1], plot_kwargs={"color": f"C{idx}"})
|
97 |
| - ax[1].axvline(1.01, color="0.6", ls="--") |
98 |
| - ax[1].axvline(1.05, color="k", ls="--") |
| 97 | + # Assume Rhats are N(1, 0.005) iid. Then compute the 0.99 quantile |
| 98 | + # scaled by the sample size and use it as a threshold. |
| 99 | + ax[1].axvline(norm(1, 0.005).ppf(0.99 ** (1 / ess.size)), color="0.7", ls="--") |
99 | 100 |
|
100 | 101 | ax[0].set_xlabel("ESS")
|
101 | 102 | ax[1].set_xlabel("R-hat")
|
|
0 commit comments