Skip to content

Commit a0b1775

Browse files
Get tune mask from "tune" or "*__tune" stats
This restores compatibility with MCMC runs from PyMC >= 5.7.0. Closes #102
1 parent 96bd3c3 commit a0b1775

File tree

3 files changed

+36
-10
lines changed

3 files changed

+36
-10
lines changed

mcbackend/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
except ModuleNotFoundError:
1313
pass
1414

15-
__version__ = "0.5.1"
15+
__version__ = "0.5.2"
1616
__all__ = [
1717
"NumPyBackend",
1818
"Backend",

mcbackend/core.py

Lines changed: 22 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,9 @@
2323
_log = logging.getLogger(__file__)
2424

2525

26+
__all__ = ("is_rigid", "chain_id", "Chain", "Run", "Backend")
27+
28+
2629
def is_rigid(nshape: Optional[Shape]):
2730
"""Determines wheather the shape is constant.
2831
@@ -133,6 +136,20 @@ def sample_stats(self) -> Dict[str, Variable]:
133136
return {var.name: var for var in self.rmeta.sample_stats}
134137

135138

139+
def get_tune_mask(chain: Chain, slc: slice = slice(None)) -> numpy.ndarray:
140+
"""Load the tuning mask from either a ``"tune"``, or a ``"*__tune"`` stat.
141+
142+
Raises
143+
------
144+
KeyError
145+
When no matching stat is found.
146+
"""
147+
for sname in chain.sample_stats:
148+
if sname.endswith("__tune") or sname == "tune":
149+
return chain.get_stats(sname, slc).astype(bool)
150+
raise KeyError("No tune stat found.")
151+
152+
136153
class Run:
137154
"""A handle on one MCMC run."""
138155

@@ -231,14 +248,15 @@ def to_inferencedata(self, *, equalize_chain_lengths: bool = True, **kwargs) ->
231248
slc = slice(0, min_clen)
232249

233250
# Obtain a mask by which draws can be split into warmup/posterior
234-
if "tune" in chain.sample_stats:
235-
tune = chain.get_stats("tune", slc).astype(bool)
236-
else:
251+
try:
252+
# Use the same slice to avoid shape issues in case the chain is still active
253+
tune = get_tune_mask(chain, slc)
254+
except KeyError:
237255
if c == 0:
238256
_log.warning(
239257
"No 'tune' stat found. Assuming all iterations are posterior draws."
240258
)
241-
tune = numpy.full((chain_lengths[chain.cid],), False)
259+
tune = numpy.full((slc.stop,), False)
242260

243261
# Split all variables draws into warmup/posterior
244262
for var in variables:

mcbackend/test_utils.py

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -277,11 +277,12 @@ def test__get_chains(self):
277277
assert len(chain) == 1
278278
pass
279279

280-
def test__to_inferencedata(self):
280+
@pytest.mark.parametrize("tstatname", ["tune", "sampler__tune", "nottune"])
281+
def test__to_inferencedata(self, tstatname, caplog):
281282
rmeta = make_runmeta(
282283
flexibility=False,
283284
sample_stats=[
284-
Variable("tune", "bool"),
285+
Variable(tstatname, "bool"),
285286
Variable("sampler_0__logp", "float32"),
286287
Variable("warning", "str"),
287288
],
@@ -294,15 +295,22 @@ def test__to_inferencedata(self):
294295
draws = [make_draw(rmeta.variables) for _ in range(n)]
295296
stats = [make_draw(rmeta.sample_stats) for _ in range(n)]
296297
for i, (d, s) in enumerate(zip(draws, stats)):
297-
s["tune"] = i < 4
298+
s[tstatname] = i < 4
298299
chain.append(d, s)
299300

300301
idata = run.to_inferencedata()
301302
assert isinstance(idata, arviz.InferenceData)
302303
assert idata.warmup_posterior.dims["chain"] == 1
303-
assert idata.warmup_posterior.dims["draw"] == 4
304304
assert idata.posterior.dims["chain"] == 1
305-
assert idata.posterior.dims["draw"] == 6
305+
if tstatname == "nottune":
306+
# Splitting into warmup/posterior requires a tune stat!
307+
assert any("No 'tune' stat" in r.message for r in caplog.records)
308+
assert idata.warmup_posterior.dims["draw"] == 0
309+
assert idata.posterior.dims["draw"] == 10
310+
else:
311+
assert idata.warmup_posterior.dims["draw"] == 4
312+
assert idata.posterior.dims["draw"] == 6
313+
306314
for var in rmeta.variables:
307315
assert var.name in set(idata.posterior.keys())
308316
for svar in rmeta.sample_stats:

0 commit comments

Comments
 (0)