Skip to content

Quick fix for pretty representation of symbolic distributions #5847

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Jun 3, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 5 additions & 9 deletions pymc/distributions/distribution.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@
resize_from_dims,
resize_from_observed,
)
from pymc.printing import str_for_dist
from pymc.printing import str_for_dist, str_for_symbolic_dist
from pymc.util import UNSET
from pymc.vartypes import string_types

Expand Down Expand Up @@ -483,15 +483,11 @@ def __new__(
transform=transform,
initval=initval,
)

# TODO: Refactor this
# add in pretty-printing support
rv_out.str_repr = lambda *args, **kwargs: name
rv_out._repr_latex_ = f"\\text{name}"
# rv_out.str_repr = types.MethodType(str_for_dist, rv_out)
# rv_out._repr_latex_ = types.MethodType(
# functools.partial(str_for_dist, formatting="latex"), rv_out
# )
rv_out.str_repr = types.MethodType(str_for_symbolic_dist, rv_out)
rv_out._repr_latex_ = types.MethodType(
functools.partial(str_for_symbolic_dist, formatting="latex"), rv_out
)

return rv_out

Expand Down
12 changes: 12 additions & 0 deletions pymc/printing.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,18 @@ def str_for_dist(rv: TensorVariable, formatting: str = "plain", include_params:
return rf"{print_name} ~ {dist_name}"


def str_for_symbolic_dist(
rv: TensorVariable, formatting: str = "plain", include_params: bool = True
) -> str:
"""Make a human-readable string representation of a SymbolicDistribution in a model,
either LaTeX or plain, optionally with distribution parameter values included."""

if "latex" in formatting:
return rf"$\text{{{rv.name}}} \sim \text{{{rv.owner.op}}}$"
else:
return rf"{rv.name} ~ {rv.owner.op}"


def str_for_model(model: Model, formatting: str = "plain", include_params: bool = True) -> str:
"""Make a human-readable string representation of Model, listing all random variables
and their distributions, optionally including parameter values."""
Expand Down
135 changes: 1 addition & 134 deletions pymc/tests/test_distributions.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,7 @@ def polyagamma_cdf(*args, **kwargs):
)
from pymc.distributions.shape_utils import to_tuple
from pymc.math import kronecker
from pymc.model import Deterministic, Model, Point, Potential
from pymc.model import Model, Point
from pymc.tests.helpers import select_by_precision
from pymc.vartypes import continuous_types, discrete_types

Expand Down Expand Up @@ -2886,139 +2886,6 @@ def test_lower_bounded_broadcasted(self):
assert upper_interval is None


class TestStrAndLatexRepr:
def setup_class(self):
# True parameter values
alpha, sigma = 1, 1
beta = [1, 2.5]

# Size of dataset
size = 100

# Predictor variable
X = np.random.normal(size=(size, 2)).dot(np.array([[1, 0], [0, 0.2]]))

# Simulate outcome variable
Y = alpha + X.dot(beta) + np.random.randn(size) * sigma
with Model() as self.model:
# TODO: some variables commented out here as they're not working properly
# in v4 yet (9-jul-2021), so doesn't make sense to test str/latex for them

# Priors for unknown model parameters
alpha = Normal("alpha", mu=0, sigma=10)
b = Normal("beta", mu=0, sigma=10, size=(2,), observed=beta)
sigma = HalfNormal("sigma", sigma=1)

# Test Cholesky parameterization
Z = MvNormal("Z", mu=np.zeros(2), chol=np.eye(2), size=(2,))

# NegativeBinomial representations to test issue 4186
# nb1 = pm.NegativeBinomial(
# "nb_with_mu_alpha", mu=pm.Normal("nbmu"), alpha=pm.Gamma("nbalpha", mu=6, sigma=1)
# )
nb2 = pm.NegativeBinomial("nb_with_p_n", p=pm.Uniform("nbp"), n=10)

# Expected value of outcome
mu = Deterministic("mu", floatX(alpha + at.dot(X, b)))

# add a bounded variable as well
# bound_var = Bound(Normal, lower=1.0)("bound_var", mu=0, sigma=10)

# KroneckerNormal
n, m = 3, 4
covs = [np.eye(n), np.eye(m)]
kron_normal = KroneckerNormal("kron_normal", mu=np.zeros(n * m), covs=covs, size=n * m)

# MatrixNormal
# matrix_normal = MatrixNormal(
# "mat_normal",
# mu=np.random.normal(size=n),
# rowcov=np.eye(n),
# colchol=np.linalg.cholesky(np.eye(n)),
# size=(n, n),
# )

# DirichletMultinomial
dm = DirichletMultinomial("dm", n=5, a=[1, 1, 1], size=(2, 3))

# Likelihood (sampling distribution) of observations
Y_obs = Normal("Y_obs", mu=mu, sigma=sigma, observed=Y)

# add a potential as well
pot = Potential("pot", mu**2)

self.distributions = [alpha, sigma, mu, b, Z, nb2, Y_obs, pot]
self.deterministics_or_potentials = [mu, pot]
# tuples of (formatting, include_params
self.formats = [("plain", True), ("plain", False), ("latex", True), ("latex", False)]
self.expected = {
("plain", True): [
r"alpha ~ N(0, 10)",
r"sigma ~ N**+(0, 1)",
r"mu ~ Deterministic(f(beta, alpha))",
r"beta ~ N(0, 10)",
r"Z ~ N(f(), f())",
r"nb_with_p_n ~ NB(10, nbp)",
r"Y_obs ~ N(mu, sigma)",
r"pot ~ Potential(f(beta, alpha))",
],
("plain", False): [
r"alpha ~ N",
r"sigma ~ N**+",
r"mu ~ Deterministic",
r"beta ~ N",
r"Z ~ N",
r"nb_with_p_n ~ NB",
r"Y_obs ~ N",
r"pot ~ Potential",
],
("latex", True): [
r"$\text{alpha} \sim \operatorname{N}(0,~10)$",
r"$\text{sigma} \sim \operatorname{N^{+}}(0,~1)$",
r"$\text{mu} \sim \operatorname{Deterministic}(f(\text{beta},~\text{alpha}))$",
r"$\text{beta} \sim \operatorname{N}(0,~10)$",
r"$\text{Z} \sim \operatorname{N}(f(),~f())$",
r"$\text{nb_with_p_n} \sim \operatorname{NB}(10,~\text{nbp})$",
r"$\text{Y_obs} \sim \operatorname{N}(\text{mu},~\text{sigma})$",
r"$\text{pot} \sim \operatorname{Potential}(f(\text{beta},~\text{alpha}))$",
],
("latex", False): [
r"$\text{alpha} \sim \operatorname{N}$",
r"$\text{sigma} \sim \operatorname{N^{+}}$",
r"$\text{mu} \sim \operatorname{Deterministic}$",
r"$\text{beta} \sim \operatorname{N}$",
r"$\text{Z} \sim \operatorname{N}$",
r"$\text{nb_with_p_n} \sim \operatorname{NB}$",
r"$\text{Y_obs} \sim \operatorname{N}$",
r"$\text{pot} \sim \operatorname{Potential}$",
],
}

def test__repr_latex_(self):
for distribution, tex in zip(self.distributions, self.expected[("latex", True)]):
assert distribution._repr_latex_() == tex

model_tex = self.model._repr_latex_()

# make sure each variable is in the model
for tex in self.expected[("latex", True)]:
for segment in tex.strip("$").split(r"\sim"):
assert segment in model_tex

def test_str_repr(self):
for str_format in self.formats:
for dist, text in zip(self.distributions, self.expected[str_format]):
assert dist.str_repr(*str_format) == text

model_text = self.model.str_repr(*str_format)
for text in self.expected[str_format]:
if str_format[0] == "latex":
for segment in text.strip("$").split(r"\sim"):
assert segment in model_text
else:
assert text in model_text


def test_discrete_trafo():
with Model():
with pytest.raises(ValueError) as err:
Expand Down
156 changes: 156 additions & 0 deletions pymc/tests/test_printing.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,156 @@
import numpy as np

from pymc.aesaraf import floatX
from pymc.distributions import (
DirichletMultinomial,
HalfNormal,
KroneckerNormal,
MvNormal,
NegativeBinomial,
Normal,
Uniform,
ZeroInflatedPoisson,
)
from pymc.math import dot
from pymc.model import Deterministic, Model, Potential


# TODO: This test is a bit too monolithic
class TestStrAndLatexRepr:
def setup_class(self):
# True parameter values
alpha, sigma = 1, 1
beta = [1, 2.5]

# Size of dataset
size = 100

# Predictor variable
X = np.random.normal(size=(size, 2)).dot(np.array([[1, 0], [0, 0.2]]))

# Simulate outcome variable
Y = alpha + X.dot(beta) + np.random.randn(size) * sigma
with Model() as self.model:
# TODO: some variables commented out here as they're not working properly
# in v4 yet (9-jul-2021), so doesn't make sense to test str/latex for them

# Priors for unknown model parameters
alpha = Normal("alpha", mu=0, sigma=10)
b = Normal("beta", mu=0, sigma=10, size=(2,), observed=beta)
sigma = HalfNormal("sigma", sigma=1)

# Test Cholesky parameterization
Z = MvNormal("Z", mu=np.zeros(2), chol=np.eye(2), size=(2,))

# NegativeBinomial representations to test issue 4186
# nb1 = pm.NegativeBinomial(
# "nb_with_mu_alpha", mu=pm.Normal("nbmu"), alpha=pm.Gamma("nbalpha", mu=6, sigma=1)
# )
nb2 = NegativeBinomial("nb_with_p_n", p=Uniform("nbp"), n=10)

# Symbolic distribution
zip = ZeroInflatedPoisson("zip", 0.5, 5)

# Expected value of outcome
mu = Deterministic("mu", floatX(alpha + dot(X, b)))

# add a bounded variable as well
# bound_var = Bound(Normal, lower=1.0)("bound_var", mu=0, sigma=10)

# KroneckerNormal
n, m = 3, 4
covs = [np.eye(n), np.eye(m)]
kron_normal = KroneckerNormal("kron_normal", mu=np.zeros(n * m), covs=covs, size=n * m)

# MatrixNormal
# matrix_normal = MatrixNormal(
# "mat_normal",
# mu=np.random.normal(size=n),
# rowcov=np.eye(n),
# colchol=np.linalg.cholesky(np.eye(n)),
# size=(n, n),
# )

# DirichletMultinomial
dm = DirichletMultinomial("dm", n=5, a=[1, 1, 1], size=(2, 3))

# Likelihood (sampling distribution) of observations
Y_obs = Normal("Y_obs", mu=mu, sigma=sigma, observed=Y)

# add a potential as well
pot = Potential("pot", mu**2)

self.distributions = [alpha, sigma, mu, b, Z, nb2, zip, Y_obs, pot]
self.deterministics_or_potentials = [mu, pot]
# tuples of (formatting, include_params
self.formats = [("plain", True), ("plain", False), ("latex", True), ("latex", False)]
self.expected = {
("plain", True): [
r"alpha ~ N(0, 10)",
r"sigma ~ N**+(0, 1)",
r"mu ~ Deterministic(f(beta, alpha))",
r"beta ~ N(0, 10)",
r"Z ~ N(f(), f())",
r"nb_with_p_n ~ NB(10, nbp)",
r"zip ~ MarginalMixtureRV{inline=False}",
r"Y_obs ~ N(mu, sigma)",
r"pot ~ Potential(f(beta, alpha))",
],
("plain", False): [
r"alpha ~ N",
r"sigma ~ N**+",
r"mu ~ Deterministic",
r"beta ~ N",
r"Z ~ N",
r"nb_with_p_n ~ NB",
r"zip ~ MarginalMixtureRV{inline=False}",
r"Y_obs ~ N",
r"pot ~ Potential",
],
("latex", True): [
r"$\text{alpha} \sim \operatorname{N}(0,~10)$",
r"$\text{sigma} \sim \operatorname{N^{+}}(0,~1)$",
r"$\text{mu} \sim \operatorname{Deterministic}(f(\text{beta},~\text{alpha}))$",
r"$\text{beta} \sim \operatorname{N}(0,~10)$",
r"$\text{Z} \sim \operatorname{N}(f(),~f())$",
r"$\text{nb_with_p_n} \sim \operatorname{NB}(10,~\text{nbp})$",
r"$\text{zip} \sim \text{MarginalMixtureRV{inline=False}}$",
r"$\text{Y_obs} \sim \operatorname{N}(\text{mu},~\text{sigma})$",
r"$\text{pot} \sim \operatorname{Potential}(f(\text{beta},~\text{alpha}))$",
],
("latex", False): [
r"$\text{alpha} \sim \operatorname{N}$",
r"$\text{sigma} \sim \operatorname{N^{+}}$",
r"$\text{mu} \sim \operatorname{Deterministic}$",
r"$\text{beta} \sim \operatorname{N}$",
r"$\text{Z} \sim \operatorname{N}$",
r"$\text{nb_with_p_n} \sim \operatorname{NB}$",
r"$\text{zip} \sim \text{MarginalMixtureRV{inline=False}}$",
r"$\text{Y_obs} \sim \operatorname{N}$",
r"$\text{pot} \sim \operatorname{Potential}$",
],
}

def test__repr_latex_(self):
for distribution, tex in zip(self.distributions, self.expected[("latex", True)]):
assert distribution._repr_latex_() == tex

model_tex = self.model._repr_latex_()

# make sure each variable is in the model
for tex in self.expected[("latex", True)]:
for segment in tex.strip("$").split(r"\sim"):
assert segment in model_tex

def test_str_repr(self):
for str_format in self.formats:
for dist, text in zip(self.distributions, self.expected[str_format]):
assert dist.str_repr(*str_format) == text

model_text = self.model.str_repr(*str_format)
for text in self.expected[str_format]:
if str_format[0] == "latex":
for segment in text.strip("$").split(r"\sim"):
assert segment in model_text
else:
assert text in model_text