Skip to content

Commit a1dce17

Browse files
ricardoV94michaelosthege
authored andcommitted
Use dataset.sizes instead of dataset.dims
A FutureWarning is issued in 'xarray==2023.12.0'
1 parent 4de076d commit a1dce17

File tree

4 files changed

+22
-22
lines changed

4 files changed

+22
-22
lines changed

pymc/backends/arviz.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -582,8 +582,8 @@ def predictions_to_inference_data(
582582
)
583583
if hasattr(idata_orig, "posterior"):
584584
assert idata_orig is not None
585-
converter.nchains = idata_orig["posterior"].dims["chain"]
586-
converter.ndraws = idata_orig["posterior"].dims["draw"]
585+
converter.nchains = idata_orig["posterior"].sizes["chain"]
586+
converter.ndraws = idata_orig["posterior"].sizes["draw"]
587587
else:
588588
aelem = next(iter(predictions.values()))
589589
converter.nchains, converter.ndraws = aelem.shape[:2]

tests/backends/test_arviz.py

Lines changed: 16 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -156,8 +156,8 @@ def test_to_idata(self, data, eight_schools_params, chains, draws):
156156
}
157157
fails = check_multiple_attrs(test_dict, inference_data)
158158
assert not fails
159-
chains = inference_data.posterior.dims["chain"]
160-
draws = inference_data.posterior.dims["draw"]
159+
chains = inference_data.posterior.sizes["chain"]
160+
draws = inference_data.posterior.sizes["draw"]
161161
obs = inference_data.observed_data["obs"]
162162
assert inference_data.log_likelihood["obs"].shape == (chains, draws) + obs.shape
163163

@@ -177,7 +177,7 @@ def test_predictions_to_idata(self, data, eight_schools_params):
177177
assert not fails
178178
for key, ivalues in inference_data.predictions.items():
179179
assert (
180-
len(ivalues["chain"]) == inference_data.posterior.dims["chain"]
180+
len(ivalues["chain"]) == inference_data.posterior.sizes["chain"]
181181
) # same chains as in posterior
182182

183183
# check adding in place
@@ -188,7 +188,7 @@ def test_predictions_to_idata(self, data, eight_schools_params):
188188
assert not fails
189189
for key, ivalues in inference_data.predictions.items():
190190
assert (
191-
len(ivalues["chain"]) == inference_data.posterior.dims["chain"]
191+
len(ivalues["chain"]) == inference_data.posterior.sizes["chain"]
192192
) # same chains as in posterior
193193

194194
def test_predictions_to_idata_new(self, data, eight_schools_params):
@@ -241,10 +241,10 @@ def test_posterior_predictive_thinned(self, data):
241241
}
242242
fails = check_multiple_attrs(test_dict, idata)
243243
assert not fails
244-
assert idata.posterior.dims["chain"] == 2
245-
assert idata.posterior.dims["draw"] == draws
246-
assert idata.posterior_predictive.dims["chain"] == 2
247-
assert idata.posterior_predictive.dims["draw"] == draws / thin_by
244+
assert idata.posterior.sizes["chain"] == 2
245+
assert idata.posterior.sizes["draw"] == draws
246+
assert idata.posterior_predictive.sizes["chain"] == 2
247+
assert idata.posterior_predictive.sizes["draw"] == draws / thin_by
248248
assert np.allclose(idata.posterior["draw"], np.arange(draws))
249249
assert np.allclose(idata.posterior_predictive["draw"], np.arange(draws, step=thin_by))
250250

@@ -723,11 +723,11 @@ def test_save_warmup(self, save_warmup, chains, tune, draws):
723723
fails = check_multiple_attrs(test_dict, idata)
724724
assert not fails
725725
if hasattr(idata, "posterior"):
726-
assert idata.posterior.dims["chain"] == chains
727-
assert idata.posterior.dims["draw"] == draws
726+
assert idata.posterior.sizes["chain"] == chains
727+
assert idata.posterior.sizes["draw"] == draws
728728
if hasattr(idata, "warmup_posterior"):
729-
assert idata.warmup_posterior.dims["chain"] == chains
730-
assert idata.warmup_posterior.dims["draw"] == tune
729+
assert idata.warmup_posterior.sizes["chain"] == chains
730+
assert idata.warmup_posterior.sizes["draw"] == tune
731731

732732
def test_save_warmup_issue_1208_after_3_9(self):
733733
with pm.Model():
@@ -757,8 +757,8 @@ def test_save_warmup_issue_1208_after_3_9(self):
757757
}
758758
fails = check_multiple_attrs(test_dict, idata)
759759
assert not fails
760-
assert idata.posterior.dims["chain"] == 2
761-
assert idata.posterior.dims["draw"] == 200
760+
assert idata.posterior.sizes["chain"] == 2
761+
assert idata.posterior.sizes["draw"] == 200
762762

763763
# manually sliced trace triggers the same warning as <=3.8
764764
with pytest.warns(UserWarning, match="Warmup samples"):
@@ -771,5 +771,5 @@ def test_save_warmup_issue_1208_after_3_9(self):
771771
}
772772
fails = check_multiple_attrs(test_dict, idata)
773773
assert not fails
774-
assert idata.posterior.dims["chain"] == 2
775-
assert idata.posterior.dims["draw"] == 30
774+
assert idata.posterior.sizes["chain"] == 2
775+
assert idata.posterior.sizes["draw"] == 30

tests/sampling/test_jax.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -433,8 +433,8 @@ def test_idata_contains_stats(sampler_name: str):
433433

434434
stats = idata.get("sample_stats")
435435
assert stats is not None
436-
n_chains = stats.dims["chain"]
437-
n_draws = stats.dims["draw"]
436+
n_chains = stats.sizes["chain"]
437+
n_draws = stats.sizes["draw"]
438438

439439
# Stats vars expected for both samplers
440440
expected_stat_vars = {

tests/smc/test_smc.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -222,8 +222,8 @@ def test_return_datatype(self, chains):
222222

223223
assert isinstance(idata, InferenceData)
224224
assert "sample_stats" in idata
225-
assert idata.posterior.dims["chain"] == chains
226-
assert idata.posterior.dims["draw"] == draws
225+
assert idata.posterior.sizes["chain"] == chains
226+
assert idata.posterior.sizes["draw"] == draws
227227

228228
assert isinstance(mt, MultiTrace)
229229
assert mt.nchains == chains

0 commit comments

Comments
 (0)