From 0775f3e7d9b2ab790d07527e45313ee28bbe64a6 Mon Sep 17 00:00:00 2001 From: Michael Osthege Date: Sat, 11 Mar 2023 22:54:28 +0100 Subject: [PATCH 1/2] Fix `warn_treedepth` looking at the wrong stat Closes #6587 --- pymc/stats/convergence.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/pymc/stats/convergence.py b/pymc/stats/convergence.py index 848e681f8c..dcea7f56d7 100644 --- a/pymc/stats/convergence.py +++ b/pymc/stats/convergence.py @@ -154,17 +154,17 @@ def warn_treedepth(idata: arviz.InferenceData) -> List[SamplerWarning]: if sampler_stats is None: return [] - treedepth = sampler_stats.get("tree_depth", None) - if treedepth is None: + rmtd = sampler_stats.get("reached_max_treedepth", None) + if rmtd is None: return [] warnings = [] - for c in treedepth.chain: - if sum(treedepth.sel(chain=c)) / treedepth.sizes["draw"] > 0.05: + for c in rmtd.chain: + if sum(rmtd.sel(chain=c)) / rmtd.sizes["draw"] > 0.05: warnings.append( SamplerWarning( WarningType.TREEDEPTH, - f"Chain {c} reached the maximum tree depth." + f"Chain {int(c)} reached the maximum tree depth." " Increase `max_treedepth`, increase `target_accept` or reparameterize.", "warn", ) From 8297c5b53f19b83b88b64cb892243fe481a898c3 Mon Sep 17 00:00:00 2001 From: Michael Osthege Date: Sun, 12 Mar 2023 13:38:28 +0100 Subject: [PATCH 2/2] Add test for `warn_treedepth` function --- tests/stats/test_convergence.py | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/tests/stats/test_convergence.py b/tests/stats/test_convergence.py index 7dba129a37..b5e99a09ff 100644 --- a/tests/stats/test_convergence.py +++ b/tests/stats/test_convergence.py @@ -31,6 +31,17 @@ def test_warn_divergences(): assert "2 divergences after tuning" in warns[0].message +def test_warn_treedepth(): + idata = arviz.from_dict( + sample_stats={ + "reached_max_treedepth": np.array([[0, 0, 0], [0, 1, 0]]).astype(bool), + } + ) + warns = convergence.warn_treedepth(idata) + assert len(warns) == 1 + assert "Chain 1 reached the maximum tree depth" in warns[0].message + + def test_log_warning_stats(caplog): s1 = dict(warning="Temperature too low!") s2 = dict(warning="Temperature too high!")