Skip to content

Add option to save model graph to an image #7158

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Feb 17, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 23 additions & 2 deletions pymc/model/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -1833,7 +1833,13 @@ def debug_parameters(rv):
print_("You can set `verbose=True` for more details")

def to_graphviz(
self, *, var_names: Optional[Iterable[VarName]] = None, formatting: str = "plain"
self,
*,
var_names: Optional[Iterable[VarName]] = None,
formatting: str = "plain",
save: Optional[str] = None,
figsize: Optional[tuple[int, int]] = None,
dpi: int = 300,
):
"""Produce a graphviz Digraph from a PyMC model.

Expand All @@ -1851,6 +1857,14 @@ def to_graphviz(
Subset of variables to be plotted that identify a subgraph with respect to the entire model graph
formatting : str, optional
one of { "plain" }
save : str, optional
If provided, an image of the graph will be saved to this location. The format is inferred from
the file extension.
figsize : tuple[int, int], optional
Width and height of the figure in inches. If not provided, uses the default figure size. It only affect
the size of the saved figure.
dpi : int, optional
Dots per inch. It only affects the resolution of the saved figure. The default is 300.

Examples
--------
Expand All @@ -1877,7 +1891,14 @@ def to_graphviz(

schools.to_graphviz()
"""
return model_to_graphviz(model=self, var_names=var_names, formatting=formatting)
return model_to_graphviz(
model=self,
var_names=var_names,
formatting=formatting,
save=save,
figsize=figsize,
dpi=dpi,
)


# this is really disgusting, but it breaks a self-loop: I can't pass Model
Expand Down
41 changes: 39 additions & 2 deletions pymc/model_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@

from collections import defaultdict
from collections.abc import Iterable, Sequence
from os import path
from typing import Optional

from pytensor import function
Expand Down Expand Up @@ -238,7 +239,14 @@ def get_plates(self, var_names: Optional[Iterable[VarName]] = None) -> dict[str,

return dict(plates)

def make_graph(self, var_names: Optional[Iterable[VarName]] = None, formatting: str = "plain"):
def make_graph(
self,
var_names: Optional[Iterable[VarName]] = None,
formatting: str = "plain",
save=None,
figsize=None,
dpi=300,
):
"""Make graphviz Digraph of PyMC model

Returns
Expand Down Expand Up @@ -271,6 +279,18 @@ def make_graph(self, var_names: Optional[Iterable[VarName]] = None, formatting:
for parent in parents:
graph.edge(parent.replace(":", "&"), child.replace(":", "&"))

if save is not None:
width, height = (None, None) if figsize is None else figsize
base, ext = path.splitext(save)
if ext:
ext = ext.replace(".", "")
else:
ext = "png"
graph_c = graph.copy()
graph_c.graph_attr.update(size=f"{width},{height}!")
graph_c.graph_attr.update(dpi=str(dpi))
graph_c.render(filename=base, format=ext, cleanup=True)

return graph

def make_networkx(
Expand Down Expand Up @@ -399,6 +419,9 @@ def model_to_graphviz(
*,
var_names: Optional[Iterable[VarName]] = None,
formatting: str = "plain",
save: Optional[str] = None,
figsize: Optional[tuple[int, int]] = None,
dpi: int = 300,
):
"""Produce a graphviz Digraph from a PyMC model.

Expand All @@ -418,6 +441,14 @@ def model_to_graphviz(
Subset of variables to be plotted that identify a subgraph with respect to the entire model graph
formatting : str, optional
one of { "plain" }
save : str, optional
If provided, an image of the graph will be saved to this location. The format is inferred from
the file extension.
figsize : tuple[int, int], optional
Width and height of the figure in inches. If not provided, uses the default figure size. It only affect
the size of the saved figure.
dpi : int, optional
Dots per inch. It only affects the resolution of the saved figure. The default is 300.

Examples
--------
Expand Down Expand Up @@ -453,4 +484,10 @@ def model_to_graphviz(
stacklevel=2,
)
model = pm.modelcontext(model)
return ModelGraph(model).make_graph(var_names=var_names, formatting=formatting)
return ModelGraph(model).make_graph(
var_names=var_names,
formatting=formatting,
save=save,
figsize=figsize,
dpi=dpi,
)
16 changes: 13 additions & 3 deletions tests/model/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -1690,10 +1690,20 @@ def school_model(J: int) -> pm.Model:
@pytest.mark.parametrize(
argnames="var_names", argvalues=[None, ["mu", "tau"]], ids=["all", "subset"]
)
def test_graphviz_call_function(self, var_names) -> None:
@pytest.mark.parametrize(
argnames="filenames",
argvalues=["model.png", "model"],
ids=["ext", "no_ext"],
)
def test_graphviz_call_function(self, var_names, filenames) -> None:
model = self.school_model(J=8)
with patch("pymc.model.core.model_to_graphviz") as mock_model_to_graphviz:
model.to_graphviz(var_names=var_names)
model.to_graphviz(var_names=var_names, save=filenames)
mock_model_to_graphviz.assert_called_once_with(
model=model, var_names=var_names, formatting="plain"
model=model,
var_names=var_names,
formatting="plain",
save=filenames,
figsize=None,
dpi=300,
)
9 changes: 8 additions & 1 deletion tests/test_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@
import itertools as it
import re

from os import path

import cloudpickle
import numpy as np
import pytensor
Expand Down Expand Up @@ -268,7 +270,7 @@ def test_set_data_to_non_data_container_variables(self):
error.match("The variable `beta` must be a `SharedVariable`")

@pytest.mark.xfail(reason="Depends on ModelGraph")
def test_model_to_graphviz_for_model_with_data_container(self):
def test_model_to_graphviz_for_model_with_data_container(self, tmp_path):
with pm.Model() as model:
x = pm.ConstantData("x", [1.0, 2.0, 3.0])
y = pm.MutableData("y", [1.0, 2.0, 3.0])
Expand Down Expand Up @@ -307,6 +309,11 @@ def test_model_to_graphviz_for_model_with_data_container(self):
for expected in expected_substrings:
assert expected in g.source

pm.model_to_graphviz(model, save=tmp_path / "model.png")
assert path.exists(tmp_path / "model.png")
pm.model_to_graphviz(model, save=tmp_path / "a_model", dpi=100)
assert path.exists(tmp_path / "a_model.png")

def test_explicit_coords(self, seeded_test):
N_rows = 5
N_cols = 7
Expand Down