@@ -58,7 +58,7 @@ class IBaseTrace(ABC, Sized):
58
58
varnames : List [str ]
59
59
"""Names of tracked variables."""
60
60
61
- sampler_vars : List [Dict [str , type ]]
61
+ sampler_vars : List [Dict [str , Union [ type , np . dtype ] ]]
62
62
"""Sampler stats for each sampler."""
63
63
64
64
def __len__ (self ):
@@ -79,23 +79,27 @@ def get_values(self, varname: str, burn=0, thin=1) -> np.ndarray:
79
79
"""
80
80
raise NotImplementedError ()
81
81
82
- def get_sampler_stats (self , stat_name : str , sampler_idx : Optional [int ] = None , burn = 0 , thin = 1 ):
82
+ def get_sampler_stats (
83
+ self , stat_name : str , sampler_idx : Optional [int ] = None , burn = 0 , thin = 1
84
+ ) -> np .ndarray :
83
85
"""Get sampler statistics from the trace.
84
86
85
87
Parameters
86
88
----------
87
- stat_name: str
88
- sampler_idx: int or None
89
- burn: int
90
- thin: int
89
+ stat_name : str
90
+ Name of the stat to fetch.
91
+ sampler_idx : int or None
92
+ Index of the sampler to get the stat from.
93
+ burn : int
94
+ Draws to skip from the start.
95
+ thin : int
96
+ Stepsize for the slice.
91
97
92
98
Returns
93
99
-------
94
- If the `sampler_idx` is specified, return the statistic with
95
- the given name in a numpy array. If it is not specified and there
96
- is more than one sampler that provides this statistic, return
97
- a numpy array of shape (m, n), where `m` is the number of
98
- such samplers, and `n` is the number of samples.
100
+ stats : np.ndarray
101
+ If `sampler_idx` was specified, the shape should be `(draws, samplers)`.
102
+ Otherwise, the shape should be `(draws,)`.
99
103
"""
100
104
raise NotImplementedError ()
101
105
@@ -220,23 +224,31 @@ def __getitem__(self, idx):
220
224
except (ValueError , TypeError ): # Passed variable or variable name.
221
225
raise ValueError ("Can only index with slice or integer" )
222
226
223
- def get_sampler_stats (self , stat_name , sampler_idx = None , burn = 0 , thin = 1 ):
227
+ def get_sampler_stats (
228
+ self , stat_name : str , sampler_idx : Optional [int ] = None , burn = 0 , thin = 1
229
+ ) -> np .ndarray :
224
230
"""Get sampler statistics from the trace.
225
231
232
+ Note: This implementation attempts to squeeze object arrays into a consistent dtype,
233
+ # which can change their shape in hard-to-predict ways.
234
+ # See https://github.com/pymc-devs/pymc/issues/6207
235
+
226
236
Parameters
227
237
----------
228
- stat_name: str
229
- sampler_idx: int or None
230
- burn: int
231
- thin: int
238
+ stat_name : str
239
+ Name of the stat to fetch.
240
+ sampler_idx : int or None
241
+ Index of the sampler to get the stat from.
242
+ burn : int
243
+ Draws to skip from the start.
244
+ thin : int
245
+ Stepsize for the slice.
232
246
233
247
Returns
234
248
-------
235
- If the `sampler_idx` is specified, return the statistic with
236
- the given name in a numpy array. If it is not specified and there
237
- is more than one sampler that provides this statistic, return
238
- a numpy array of shape (m, n), where `m` is the number of
239
- such samplers, and `n` is the number of samples.
249
+ stats : np.ndarray
250
+ If `sampler_idx` was specified, the shape should be `(draws, samplers)`.
251
+ Otherwise, the shape should be `(draws,)`.
240
252
"""
241
253
if sampler_idx is not None :
242
254
return self ._get_sampler_stats (stat_name , sampler_idx , burn , thin )
@@ -254,14 +266,16 @@ def get_sampler_stats(self, stat_name, sampler_idx=None, burn=0, thin=1):
254
266
255
267
if vals .dtype == np .dtype (object ):
256
268
try :
257
- vals = np .vstack (vals )
269
+ vals = np .vstack (list ( vals ) )
258
270
except ValueError :
259
271
# Most likely due to non-identical shapes. Just stick with the object-array.
260
272
pass
261
273
262
274
return vals
263
275
264
- def _get_sampler_stats (self , stat_name , sampler_idx , burn , thin ):
276
+ def _get_sampler_stats (
277
+ self , stat_name : str , sampler_idx : int , burn : int , thin : int
278
+ ) -> np .ndarray :
265
279
"""Get sampler statistics."""
266
280
raise NotImplementedError ()
267
281
@@ -476,23 +490,34 @@ def get_sampler_stats(
476
490
combine : bool = True ,
477
491
chains : Optional [Union [int , Sequence [int ]]] = None ,
478
492
squeeze : bool = True ,
479
- ):
493
+ ) -> Union [ List [ np . ndarray ], np . ndarray ] :
480
494
"""Get sampler statistics from the trace.
481
495
496
+ Note: This implementation attempts to squeeze object arrays into a consistent dtype,
497
+ # which can change their shape in hard-to-predict ways.
498
+ # See https://github.com/pymc-devs/pymc/issues/6207
499
+
482
500
Parameters
483
501
----------
484
- stat_name: str
485
- sampler_idx: int or None
486
- burn: int
487
- thin: int
502
+ stat_name : str
503
+ Name of the stat to fetch.
504
+ sampler_idx : int or None
505
+ Index of the sampler to get the stat from.
506
+ burn : int
507
+ Draws to skip from the start.
508
+ thin : int
509
+ Stepsize for the slice.
510
+ combine : bool
511
+ If True, results from `chains` will be concatenated.
512
+ squeeze : bool
513
+ Return a single array element if the resulting list of
514
+ values only has one element. If False, the result will
515
+ always be a list of arrays, even if `combine` is True.
488
516
489
517
Returns
490
518
-------
491
- If the `sampler_idx` is specified, return the statistic with
492
- the given name in a numpy array. If it is not specified and there
493
- is more than one sampler that provides this statistic, return
494
- a numpy array of shape (m, n), where `m` is the number of
495
- such samplers, and `n` is the number of samples.
519
+ stats : np.ndarray
520
+ List or ndarray depending on parameters.
496
521
"""
497
522
if stat_name not in self .stat_names :
498
523
raise KeyError ("Unknown sampler statistic %s" % stat_name )
@@ -543,7 +568,7 @@ def points(self, chains=None):
543
568
return itl .chain .from_iterable (self ._straces [chain ] for chain in chains )
544
569
545
570
546
- def _squeeze_cat (results , combine , squeeze ):
571
+ def _squeeze_cat (results , combine : bool , squeeze : bool ):
547
572
"""Squeeze and concatenate the results depending on values of
548
573
`combine` and `squeeze`."""
549
574
if combine :
0 commit comments