diff --git a/mcbackend/__init__.py b/mcbackend/__init__.py index a044c75..74c0718 100644 --- a/mcbackend/__init__.py +++ b/mcbackend/__init__.py @@ -20,4 +20,4 @@ pass -__version__ = "0.2.5" +__version__ = "0.2.6" diff --git a/mcbackend/core.py b/mcbackend/core.py index b962437..bb1b118 100644 --- a/mcbackend/core.py +++ b/mcbackend/core.py @@ -9,6 +9,7 @@ from .meta import ChainMeta, RunMeta, Variable from .npproto.utils import ndarray_to_numpy +from .utils import as_array_from_ragged InferenceData = TypeVar("InferenceData") try: @@ -252,7 +253,15 @@ def to_inferencedata(self, *, equalize_chain_lengths: bool = True, **kwargs) -> warmup_sample_stats[svar.name].append(stats[tune]) sample_stats[svar.name].append(stats[~tune]) - kwargs.setdefault("save_warmup", True) + if not equalize_chain_lengths: + # Convert ragged arrays to object-dtyped ndarray because NumPy >=1.24.0 no longer does that automatically + warmup_posterior = {k: as_array_from_ragged(v) for k, v in warmup_posterior.items()} + warmup_sample_stats = { + k: as_array_from_ragged(v) for k, v in warmup_sample_stats.items() + } + posterior = {k: as_array_from_ragged(v) for k, v in posterior.items()} + sample_stats = {k: as_array_from_ragged(v) for k, v in sample_stats.items()} + idata = from_dict( warmup_posterior=warmup_posterior, warmup_sample_stats=warmup_sample_stats, @@ -263,6 +272,7 @@ def to_inferencedata(self, *, equalize_chain_lengths: bool = True, **kwargs) -> attrs=self.meta.attributes, constant_data=self.constant_data, observed_data=self.observed_data, + save_warmup=True, **kwargs, ) return idata diff --git a/mcbackend/test_utils.py b/mcbackend/test_utils.py index 13a90b2..2704785 100644 --- a/mcbackend/test_utils.py +++ b/mcbackend/test_utils.py @@ -10,6 +10,7 @@ import pytest import mcbackend +from mcbackend import utils as mutils from mcbackend.meta import ChainMeta, DataVariable, RunMeta, Variable from mcbackend.npproto import utils @@ -407,3 +408,15 @@ def test__big_variables(self): speed = self.measure_big_variables() assert speed.draws_per_second > 500 or speed.mib_per_second > 5 pass + + +def test_as_array_from_ragged(): + even = mutils.as_array_from_ragged( + [ + numpy.ones(2), + numpy.ones(3), + ] + ) + assert isinstance(even, numpy.ndarray) + assert even.dtype == numpy.dtype(object) + pass diff --git a/mcbackend/utils.py b/mcbackend/utils.py new file mode 100644 index 0000000..5af8126 --- /dev/null +++ b/mcbackend/utils.py @@ -0,0 +1,11 @@ +"""Contains helper functions that are independent of McBackend components.""" +from typing import Sequence + +import numpy as np + + +def as_array_from_ragged(arrs: Sequence[np.ndarray]) -> np.ndarray: + shapes = {np.shape(arr) for arr in arrs} + if len(shapes) > 1: + return np.array(arrs, dtype=object) + return np.array(arrs)