Skip to content

Commit c3e0d86

Browse files
authored
Add option to save model graph to an image (#7158)
1 parent 3a304d6 commit c3e0d86

File tree

4 files changed

+83
-8
lines changed

4 files changed

+83
-8
lines changed

pymc/model/core.py

Lines changed: 23 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1833,7 +1833,13 @@ def debug_parameters(rv):
18331833
print_("You can set `verbose=True` for more details")
18341834

18351835
def to_graphviz(
1836-
self, *, var_names: Optional[Iterable[VarName]] = None, formatting: str = "plain"
1836+
self,
1837+
*,
1838+
var_names: Optional[Iterable[VarName]] = None,
1839+
formatting: str = "plain",
1840+
save: Optional[str] = None,
1841+
figsize: Optional[tuple[int, int]] = None,
1842+
dpi: int = 300,
18371843
):
18381844
"""Produce a graphviz Digraph from a PyMC model.
18391845
@@ -1851,6 +1857,14 @@ def to_graphviz(
18511857
Subset of variables to be plotted that identify a subgraph with respect to the entire model graph
18521858
formatting : str, optional
18531859
one of { "plain" }
1860+
save : str, optional
1861+
If provided, an image of the graph will be saved to this location. The format is inferred from
1862+
the file extension.
1863+
figsize : tuple[int, int], optional
1864+
Width and height of the figure in inches. If not provided, uses the default figure size. It only affect
1865+
the size of the saved figure.
1866+
dpi : int, optional
1867+
Dots per inch. It only affects the resolution of the saved figure. The default is 300.
18541868
18551869
Examples
18561870
--------
@@ -1877,7 +1891,14 @@ def to_graphviz(
18771891
18781892
schools.to_graphviz()
18791893
"""
1880-
return model_to_graphviz(model=self, var_names=var_names, formatting=formatting)
1894+
return model_to_graphviz(
1895+
model=self,
1896+
var_names=var_names,
1897+
formatting=formatting,
1898+
save=save,
1899+
figsize=figsize,
1900+
dpi=dpi,
1901+
)
18811902

18821903

18831904
# this is really disgusting, but it breaks a self-loop: I can't pass Model

pymc/model_graph.py

Lines changed: 39 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515

1616
from collections import defaultdict
1717
from collections.abc import Iterable, Sequence
18+
from os import path
1819
from typing import Optional
1920

2021
from pytensor import function
@@ -238,7 +239,14 @@ def get_plates(self, var_names: Optional[Iterable[VarName]] = None) -> dict[str,
238239

239240
return dict(plates)
240241

241-
def make_graph(self, var_names: Optional[Iterable[VarName]] = None, formatting: str = "plain"):
242+
def make_graph(
243+
self,
244+
var_names: Optional[Iterable[VarName]] = None,
245+
formatting: str = "plain",
246+
save=None,
247+
figsize=None,
248+
dpi=300,
249+
):
242250
"""Make graphviz Digraph of PyMC model
243251
244252
Returns
@@ -271,6 +279,18 @@ def make_graph(self, var_names: Optional[Iterable[VarName]] = None, formatting:
271279
for parent in parents:
272280
graph.edge(parent.replace(":", "&"), child.replace(":", "&"))
273281

282+
if save is not None:
283+
width, height = (None, None) if figsize is None else figsize
284+
base, ext = path.splitext(save)
285+
if ext:
286+
ext = ext.replace(".", "")
287+
else:
288+
ext = "png"
289+
graph_c = graph.copy()
290+
graph_c.graph_attr.update(size=f"{width},{height}!")
291+
graph_c.graph_attr.update(dpi=str(dpi))
292+
graph_c.render(filename=base, format=ext, cleanup=True)
293+
274294
return graph
275295

276296
def make_networkx(
@@ -399,6 +419,9 @@ def model_to_graphviz(
399419
*,
400420
var_names: Optional[Iterable[VarName]] = None,
401421
formatting: str = "plain",
422+
save: Optional[str] = None,
423+
figsize: Optional[tuple[int, int]] = None,
424+
dpi: int = 300,
402425
):
403426
"""Produce a graphviz Digraph from a PyMC model.
404427
@@ -418,6 +441,14 @@ def model_to_graphviz(
418441
Subset of variables to be plotted that identify a subgraph with respect to the entire model graph
419442
formatting : str, optional
420443
one of { "plain" }
444+
save : str, optional
445+
If provided, an image of the graph will be saved to this location. The format is inferred from
446+
the file extension.
447+
figsize : tuple[int, int], optional
448+
Width and height of the figure in inches. If not provided, uses the default figure size. It only affect
449+
the size of the saved figure.
450+
dpi : int, optional
451+
Dots per inch. It only affects the resolution of the saved figure. The default is 300.
421452
422453
Examples
423454
--------
@@ -453,4 +484,10 @@ def model_to_graphviz(
453484
stacklevel=2,
454485
)
455486
model = pm.modelcontext(model)
456-
return ModelGraph(model).make_graph(var_names=var_names, formatting=formatting)
487+
return ModelGraph(model).make_graph(
488+
var_names=var_names,
489+
formatting=formatting,
490+
save=save,
491+
figsize=figsize,
492+
dpi=dpi,
493+
)

tests/model/test_core.py

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1690,10 +1690,20 @@ def school_model(J: int) -> pm.Model:
16901690
@pytest.mark.parametrize(
16911691
argnames="var_names", argvalues=[None, ["mu", "tau"]], ids=["all", "subset"]
16921692
)
1693-
def test_graphviz_call_function(self, var_names) -> None:
1693+
@pytest.mark.parametrize(
1694+
argnames="filenames",
1695+
argvalues=["model.png", "model"],
1696+
ids=["ext", "no_ext"],
1697+
)
1698+
def test_graphviz_call_function(self, var_names, filenames) -> None:
16941699
model = self.school_model(J=8)
16951700
with patch("pymc.model.core.model_to_graphviz") as mock_model_to_graphviz:
1696-
model.to_graphviz(var_names=var_names)
1701+
model.to_graphviz(var_names=var_names, save=filenames)
16971702
mock_model_to_graphviz.assert_called_once_with(
1698-
model=model, var_names=var_names, formatting="plain"
1703+
model=model,
1704+
var_names=var_names,
1705+
formatting="plain",
1706+
save=filenames,
1707+
figsize=None,
1708+
dpi=300,
16991709
)

tests/test_data.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,8 @@
1616
import itertools as it
1717
import re
1818

19+
from os import path
20+
1921
import cloudpickle
2022
import numpy as np
2123
import pytensor
@@ -268,7 +270,7 @@ def test_set_data_to_non_data_container_variables(self):
268270
error.match("The variable `beta` must be a `SharedVariable`")
269271

270272
@pytest.mark.xfail(reason="Depends on ModelGraph")
271-
def test_model_to_graphviz_for_model_with_data_container(self):
273+
def test_model_to_graphviz_for_model_with_data_container(self, tmp_path):
272274
with pm.Model() as model:
273275
x = pm.ConstantData("x", [1.0, 2.0, 3.0])
274276
y = pm.MutableData("y", [1.0, 2.0, 3.0])
@@ -307,6 +309,11 @@ def test_model_to_graphviz_for_model_with_data_container(self):
307309
for expected in expected_substrings:
308310
assert expected in g.source
309311

312+
pm.model_to_graphviz(model, save=tmp_path / "model.png")
313+
assert path.exists(tmp_path / "model.png")
314+
pm.model_to_graphviz(model, save=tmp_path / "a_model", dpi=100)
315+
assert path.exists(tmp_path / "a_model.png")
316+
310317
def test_explicit_coords(self, seeded_test):
311318
N_rows = 5
312319
N_cols = 7

0 commit comments

Comments
 (0)