Skip to content

Commit 3ff4e7a

Browse files
Bugfixes to increase robustness against unnamed dims (#6339)
* Extract `ModelGraph._eval` to a function * More robustness against unlabeled `dims` entries Closes #6335
1 parent e0d25c8 commit 3ff4e7a

File tree

4 files changed

+72
-11
lines changed

4 files changed

+72
-11
lines changed

pymc/model.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1498,6 +1498,8 @@ def add_named_variable(self, var, dims: Optional[Tuple[Union[str, None], ...]] =
14981498
This can include several types of variables such as basic_RVs, Data, Deterministics,
14991499
and Potentials.
15001500
"""
1501+
if var.name is None:
1502+
raise ValueError("Variable is unnamed.")
15011503
if self.named_vars.tree_contains(var.name):
15021504
raise ValueError(f"Variable name {var.name} already exists.")
15031505

@@ -1507,7 +1509,7 @@ def add_named_variable(self, var, dims: Optional[Tuple[Union[str, None], ...]] =
15071509
for dim in dims:
15081510
if dim not in self.coords and dim is not None:
15091511
raise ValueError(f"Dimension {dim} is not specified in `coords`.")
1510-
if any(var.name == dim for dim in dims):
1512+
if any(var.name == dim for dim in dims if dim is not None):
15111513
raise ValueError(f"Variable `{var.name}` has the same name as its dimension label.")
15121514
self.named_vars_to_dims[var.name] = dims
15131515

pymc/model_graph.py

Lines changed: 32 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
import warnings
1515

1616
from collections import defaultdict
17-
from typing import Dict, Iterable, List, NewType, Optional, Set
17+
from typing import Dict, Iterable, List, NewType, Optional, Sequence, Set
1818

1919
from aesara import function
2020
from aesara.compile.sharedvalue import SharedVariable
@@ -32,6 +32,17 @@
3232
VarName = NewType("VarName", str)
3333

3434

35+
__all__ = (
36+
"ModelGraph",
37+
"model_to_graphviz",
38+
"model_to_networkx",
39+
)
40+
41+
42+
def fast_eval(var):
43+
return function([], var, mode="FAST_COMPILE")()
44+
45+
3546
class ModelGraph:
3647
def __init__(self, model):
3748
self.model = model
@@ -183,9 +194,6 @@ def _make_node(self, var_name, graph, *, nx=False, cluster=False, formatting: st
183194
else:
184195
graph.node(var_name.replace(":", "&"), **kwargs)
185196

186-
def _eval(self, var):
187-
return function([], var, mode="FAST_COMPILE")()
188-
189197
def get_plates(self, var_names: Optional[Iterable[VarName]] = None) -> Dict[str, Set[VarName]]:
190198
"""Rough but surprisingly accurate plate detection.
191199
@@ -198,18 +206,32 @@ def get_plates(self, var_names: Optional[Iterable[VarName]] = None) -> Dict[str,
198206
"""
199207
plates = defaultdict(set)
200208

209+
# TODO: Evaluate all RV shapes and dim_length at once.
210+
# This should help to find discrepancies, and
211+
# avoids unnecessary function compiles for determining labels.
212+
201213
for var_name in self.vars_to_plot(var_names):
202214
v = self.model[var_name]
215+
shape: Sequence[int] = fast_eval(v.shape)
216+
dim_labels = []
203217
if var_name in self.model.named_vars_to_dims:
204-
plate_label = " x ".join(
205-
f"{d} ({self._eval(self.model.dim_lengths[d])})"
206-
for d in self.model.named_vars_to_dims[var_name]
207-
)
218+
# The RV is associated with `dims` information.
219+
for d, dname in enumerate(self.model.named_vars_to_dims[var_name]):
220+
if dname is None:
221+
# Unnamed dimension in a `dims` tuple!
222+
dlen = shape[d]
223+
dname = f"{var_name}_dim{d}"
224+
else:
225+
dlen = fast_eval(self.model.dim_lengths[dname])
226+
dim_labels.append(f"{dname} ({dlen})")
227+
plate_label = " x ".join(dim_labels)
208228
else:
209-
plate_label = " x ".join(map(str, self._eval(v.shape)))
229+
# The RV has no `dims` information.
230+
dim_labels = map(str, shape)
231+
plate_label = " x ".join(map(str, shape))
210232
plates[plate_label].add(var_name)
211233

212-
return plates
234+
return dict(plates)
213235

214236
def make_graph(self, var_names: Optional[Iterable[VarName]] = None, formatting: str = "plain"):
215237
"""Make graphviz Digraph of PyMC model

pymc/tests/test_model.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -844,6 +844,30 @@ def test_set_dim_with_coords():
844844
assert pmodel.coords["mdim"] == ("A", "B", "C")
845845

846846

847+
def test_add_named_variable_checks_dim_name():
848+
with pm.Model() as pmodel:
849+
rv = pm.Normal.dist(mu=[1, 2])
850+
851+
# Checks that vars are named
852+
with pytest.raises(ValueError, match="is unnamed"):
853+
pmodel.add_named_variable(rv)
854+
rv.name = "nomnom"
855+
856+
# Coords must be available already
857+
with pytest.raises(ValueError, match="not specified in `coords`"):
858+
pmodel.add_named_variable(rv, dims="nomnom")
859+
pmodel.add_coord("nomnom", [1, 2])
860+
861+
# No name collisions
862+
with pytest.raises(ValueError, match="same name as"):
863+
pmodel.add_named_variable(rv, dims="nomnom")
864+
865+
# This should work (regression test against #6335)
866+
rv2 = rv[:, None]
867+
rv2.name = "yumyum"
868+
pmodel.add_named_variable(rv2, dims=("nomnom", None))
869+
870+
847871
def test_set_data_indirect_resize():
848872
with pm.Model() as pmodel:
849873
pmodel.add_coord("mdim", mutable=True, length=2)

pymc/tests/test_model_graph.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
import warnings
1515

1616
import aesara
17+
import aesara.tensor as at
1718
import numpy as np
1819
import pytest
1920

@@ -340,6 +341,18 @@ class TestImputationModel(BaseModelGraphTest):
340341
class TestModelWithDims(BaseModelGraphTest):
341342
model_func = model_with_dims
342343

344+
def test_issue_6335_dims_containing_none(self):
345+
with pm.Model(coords=dict(time=np.arange(5))) as pmodel:
346+
data = at.as_tensor(np.ones((3, 5)))
347+
pm.Deterministic("n", data, dims=(None, "time"))
348+
349+
mg = ModelGraph(pmodel)
350+
plates_actual = mg.get_plates()
351+
plates_expected = {
352+
"n_dim0 (3) x time (5)": {"n"},
353+
}
354+
assert plates_actual == plates_expected
355+
343356

344357
class TestUnnamedObservedNodes(BaseModelGraphTest):
345358
model_func = model_unnamed_observed_node

0 commit comments

Comments
 (0)