diff --git a/pymc/stats/convergence.py b/pymc/stats/convergence.py index 848e681f8c..dcea7f56d7 100644 --- a/pymc/stats/convergence.py +++ b/pymc/stats/convergence.py @@ -154,17 +154,17 @@ def warn_treedepth(idata: arviz.InferenceData) -> List[SamplerWarning]: if sampler_stats is None: return [] - treedepth = sampler_stats.get("tree_depth", None) - if treedepth is None: + rmtd = sampler_stats.get("reached_max_treedepth", None) + if rmtd is None: return [] warnings = [] - for c in treedepth.chain: - if sum(treedepth.sel(chain=c)) / treedepth.sizes["draw"] > 0.05: + for c in rmtd.chain: + if sum(rmtd.sel(chain=c)) / rmtd.sizes["draw"] > 0.05: warnings.append( SamplerWarning( WarningType.TREEDEPTH, - f"Chain {c} reached the maximum tree depth." + f"Chain {int(c)} reached the maximum tree depth." " Increase `max_treedepth`, increase `target_accept` or reparameterize.", "warn", ) diff --git a/tests/stats/test_convergence.py b/tests/stats/test_convergence.py index 7dba129a37..b5e99a09ff 100644 --- a/tests/stats/test_convergence.py +++ b/tests/stats/test_convergence.py @@ -31,6 +31,17 @@ def test_warn_divergences(): assert "2 divergences after tuning" in warns[0].message +def test_warn_treedepth(): + idata = arviz.from_dict( + sample_stats={ + "reached_max_treedepth": np.array([[0, 0, 0], [0, 1, 0]]).astype(bool), + } + ) + warns = convergence.warn_treedepth(idata) + assert len(warns) == 1 + assert "Chain 1 reached the maximum tree depth" in warns[0].message + + def test_log_warning_stats(caplog): s1 = dict(warning="Temperature too low!") s2 = dict(warning="Temperature too high!")