Skip to content

Commit 7f949ed

Browse files
More robust tune stat fetching and type hints
The mcmc module relies on the `"tune"` stat to figure out the number of tune/draw iterations post sampling. These changes remove reliance on any weird squeeze-cat magic. Co-authored-by: Virgile Andreani <armavica@ulminfo.fr>
1 parent 111fae3 commit 7f949ed

File tree

3 files changed

+64
-39
lines changed

3 files changed

+64
-39
lines changed

pymc/backends/base.py

Lines changed: 59 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,7 @@ class IBaseTrace(ABC, Sized):
5858
varnames: List[str]
5959
"""Names of tracked variables."""
6060

61-
sampler_vars: List[Dict[str, type]]
61+
sampler_vars: List[Dict[str, Union[type, np.dtype]]]
6262
"""Sampler stats for each sampler."""
6363

6464
def __len__(self):
@@ -79,23 +79,27 @@ def get_values(self, varname: str, burn=0, thin=1) -> np.ndarray:
7979
"""
8080
raise NotImplementedError()
8181

82-
def get_sampler_stats(self, stat_name: str, sampler_idx: Optional[int] = None, burn=0, thin=1):
82+
def get_sampler_stats(
83+
self, stat_name: str, sampler_idx: Optional[int] = None, burn=0, thin=1
84+
) -> np.ndarray:
8385
"""Get sampler statistics from the trace.
8486
8587
Parameters
8688
----------
87-
stat_name: str
88-
sampler_idx: int or None
89-
burn: int
90-
thin: int
89+
stat_name : str
90+
Name of the stat to fetch.
91+
sampler_idx : int or None
92+
Index of the sampler to get the stat from.
93+
burn : int
94+
Draws to skip from the start.
95+
thin : int
96+
Stepsize for the slice.
9197
9298
Returns
9399
-------
94-
If the `sampler_idx` is specified, return the statistic with
95-
the given name in a numpy array. If it is not specified and there
96-
is more than one sampler that provides this statistic, return
97-
a numpy array of shape (m, n), where `m` is the number of
98-
such samplers, and `n` is the number of samples.
100+
stats : np.ndarray
101+
If `sampler_idx` was specified, the shape should be `(draws,)`.
102+
Otherwise, the shape should be `(draws, samplers)`.
99103
"""
100104
raise NotImplementedError()
101105

@@ -220,23 +224,31 @@ def __getitem__(self, idx):
220224
except (ValueError, TypeError): # Passed variable or variable name.
221225
raise ValueError("Can only index with slice or integer")
222226

223-
def get_sampler_stats(self, stat_name, sampler_idx=None, burn=0, thin=1):
227+
def get_sampler_stats(
228+
self, stat_name: str, sampler_idx: Optional[int] = None, burn=0, thin=1
229+
) -> np.ndarray:
224230
"""Get sampler statistics from the trace.
225231
232+
Note: This implementation attempts to squeeze object arrays into a consistent dtype,
233+
# which can change their shape in hard-to-predict ways.
234+
# See https://github.com/pymc-devs/pymc/issues/6207
235+
226236
Parameters
227237
----------
228-
stat_name: str
229-
sampler_idx: int or None
230-
burn: int
231-
thin: int
238+
stat_name : str
239+
Name of the stat to fetch.
240+
sampler_idx : int or None
241+
Index of the sampler to get the stat from.
242+
burn : int
243+
Draws to skip from the start.
244+
thin : int
245+
Stepsize for the slice.
232246
233247
Returns
234248
-------
235-
If the `sampler_idx` is specified, return the statistic with
236-
the given name in a numpy array. If it is not specified and there
237-
is more than one sampler that provides this statistic, return
238-
a numpy array of shape (m, n), where `m` is the number of
239-
such samplers, and `n` is the number of samples.
249+
stats : np.ndarray
250+
If `sampler_idx` was specified, the shape should be `(draws,)`.
251+
Otherwise, the shape should be `(draws, samplers)`.
240252
"""
241253
if sampler_idx is not None:
242254
return self._get_sampler_stats(stat_name, sampler_idx, burn, thin)
@@ -254,14 +266,16 @@ def get_sampler_stats(self, stat_name, sampler_idx=None, burn=0, thin=1):
254266

255267
if vals.dtype == np.dtype(object):
256268
try:
257-
vals = np.vstack(vals)
269+
vals = np.vstack(list(vals))
258270
except ValueError:
259271
# Most likely due to non-identical shapes. Just stick with the object-array.
260272
pass
261273

262274
return vals
263275

264-
def _get_sampler_stats(self, stat_name, sampler_idx, burn, thin):
276+
def _get_sampler_stats(
277+
self, stat_name: str, sampler_idx: int, burn: int, thin: int
278+
) -> np.ndarray:
265279
"""Get sampler statistics."""
266280
raise NotImplementedError()
267281

@@ -476,23 +490,34 @@ def get_sampler_stats(
476490
combine: bool = True,
477491
chains: Optional[Union[int, Sequence[int]]] = None,
478492
squeeze: bool = True,
479-
):
493+
) -> Union[List[np.ndarray], np.ndarray]:
480494
"""Get sampler statistics from the trace.
481495
496+
Note: This implementation attempts to squeeze object arrays into a consistent dtype,
497+
# which can change their shape in hard-to-predict ways.
498+
# See https://github.com/pymc-devs/pymc/issues/6207
499+
482500
Parameters
483501
----------
484-
stat_name: str
485-
sampler_idx: int or None
486-
burn: int
487-
thin: int
502+
stat_name : str
503+
Name of the stat to fetch.
504+
sampler_idx : int or None
505+
Index of the sampler to get the stat from.
506+
burn : int
507+
Draws to skip from the start.
508+
thin : int
509+
Stepsize for the slice.
510+
combine : bool
511+
If True, results from `chains` will be concatenated.
512+
squeeze : bool
513+
Return a single array element if the resulting list of
514+
values only has one element. If False, the result will
515+
always be a list of arrays, even if `combine` is True.
488516
489517
Returns
490518
-------
491-
If the `sampler_idx` is specified, return the statistic with
492-
the given name in a numpy array. If it is not specified and there
493-
is more than one sampler that provides this statistic, return
494-
a numpy array of shape (m, n), where `m` is the number of
495-
such samplers, and `n` is the number of samples.
519+
stats : np.ndarray
520+
List or ndarray depending on parameters.
496521
"""
497522
if stat_name not in self.stat_names:
498523
raise KeyError("Unknown sampler statistic %s" % stat_name)
@@ -543,7 +568,7 @@ def points(self, chains=None):
543568
return itl.chain.from_iterable(self._straces[chain] for chain in chains)
544569

545570

546-
def _squeeze_cat(results, combine, squeeze):
571+
def _squeeze_cat(results, combine: bool, squeeze: bool):
547572
"""Squeeze and concatenate the results depending on values of
548573
`combine` and `squeeze`."""
549574
if combine:

pymc/backends/ndarray.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -119,7 +119,9 @@ def record(self, point, sampler_stats=None) -> None:
119119
data[key][self.draw_idx] = val
120120
self.draw_idx += 1
121121

122-
def _get_sampler_stats(self, varname, sampler_idx, burn, thin):
122+
def _get_sampler_stats(
123+
self, varname: str, sampler_idx: int, burn: int, thin: int
124+
) -> np.ndarray:
123125
return self._stats[sampler_idx][varname][burn::thin]
124126

125127
def close(self):

pymc/sampling/mcmc.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -580,10 +580,8 @@ def sample(
580580
# count the number of tune/draw iterations that happened
581581
# ideally via the "tune" statistic, but not all samplers record it!
582582
if "tune" in mtrace.stat_names:
583-
stat = mtrace.get_sampler_stats("tune", chains=0)
584-
# when CompoundStep is used, the stat is 2 dimensional!
585-
if len(stat.shape) == 2:
586-
stat = stat[:, 0]
583+
# Get the tune stat directly from chain 0, sampler 0
584+
stat = mtrace._straces[0].get_sampler_stats("tune", sampler_idx=0)
587585
stat = tuple(stat)
588586
n_tune = stat.count(True)
589587
n_draws = stat.count(False)

0 commit comments

Comments
 (0)