Skip to content

Commit 902b1ec

Browse files
committed
Enforce dims to be strings
1 parent f4986a4 commit 902b1ec

File tree

3 files changed

+21
-4
lines changed

3 files changed

+21
-4
lines changed

pymc/backends/arviz.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -214,11 +214,9 @@ def __init__(
214214
}
215215

216216
self.dims = {} if dims is None else dims
217-
model_dims = {
218-
var_name: [dim for dim in dims if dim is not None]
219-
for var_name, dims in self.model.named_vars_to_dims.items()
220-
}
217+
model_dims = {k: list(v) for k, v in self.model.named_vars_to_dims.items()}
221218
self.dims = {**model_dims, **self.dims}
219+
222220
if sample_dims is None:
223221
sample_dims = ["chain", "draw"]
224222
self.sample_dims = sample_dims

pymc/model.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1317,6 +1317,8 @@ def register_rv(
13171317
# the length of the corresponding RV dimension.
13181318
if dims is not None:
13191319
for d, dname in enumerate(dims):
1320+
if not isinstance(dname, str):
1321+
raise TypeError(f"Dims must be string. Got {dname} of type {type(dname)}")
13201322
if dname not in self.dim_lengths:
13211323
self.add_coord(dname, values=None, length=rv_var.shape[d])
13221324

pymc/tests/test_model.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -871,6 +871,23 @@ def test_add_named_variable_checks_dim_name():
871871
pmodel.add_named_variable(rv2, dims=("nomnom", None))
872872

873873

874+
def test_dims_type_check():
875+
with pm.Model(coords={"a": range(5)}) as m:
876+
with pytest.raises(TypeError, match="Dims must be string"):
877+
x = pm.Normal("x", shape=(10, 5), dims=(None, "a"))
878+
879+
880+
def test_none_coords_autonumbering():
881+
with pm.Model() as m:
882+
m.add_coord(name="a", values=None, length=3)
883+
m.add_coord(name="b", values=range(5))
884+
x = pm.Normal("x", dims=("a", "b"))
885+
prior = pm.sample_prior_predictive(samples=2).prior
886+
assert prior["x"].shape == (1, 2, 3, 5)
887+
assert list(prior.coords["a"].values) == list(range(3))
888+
assert list(prior.coords["b"].values) == list(range(5))
889+
890+
874891
def test_set_data_indirect_resize():
875892
with pm.Model() as pmodel:
876893
pmodel.add_coord("mdim", mutable=True, length=2)

0 commit comments

Comments
 (0)