Skip to content

Latex representation for SymbolicDistributions #5793

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 15 additions & 8 deletions pymc/distributions/discrete.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@
normal_lccdf,
normal_lcdf,
)
from pymc.distributions.distribution import Discrete
from pymc.distributions.distribution import Discrete, set_print_name
from pymc.distributions.logprob import logp
from pymc.distributions.mixture import Mixture
from pymc.distributions.shape_utils import rv_size_is_none
Expand Down Expand Up @@ -1411,7 +1411,7 @@ def logcdf(value, c):
)


def _zero_inflated_mixture(*, name, nonzero_p, nonzero_dist, **kwargs):
def _zero_inflated_mixture(*, cls, name, nonzero_p, nonzero_dist, **kwargs):
"""Helper function to create a zero-inflated mixture

If name is `None`, this function returns an unregistered variable
Expand All @@ -1423,9 +1423,14 @@ def _zero_inflated_mixture(*, name, nonzero_p, nonzero_dist, **kwargs):
nonzero_dist,
]
if name is not None:
return Mixture(name, weights, comp_dists, **kwargs)
rv_out = Mixture(name, weights, comp_dists, **kwargs)
else:
return Mixture.dist(weights, comp_dists, **kwargs)
rv_out = Mixture.dist(weights, comp_dists, **kwargs)

# overriding Mixture _print_name
set_print_name(cls, rv_out)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The alternative would be to be more clever about the mixture print name and check on the spot:

  1. If it has 2 components and one is a zero constant, call it "ZeroInflatedX", where "X" is the name of the nonzero component, like "ZeroInflatedBinomial"
  2. If all the components follow the same distribution call it "XMixture" like "NormalMixture" or "GammaMixture"
  3. If there are two different components, call it "X-YMixture" like "Gamma-ExponentialMixture"
  4. Otherwise call it simply Mixture

This is where the dispatching idea becomes a bit more powerful. You can have an arbitrarily complex function that specializes on an Op and can look at its inputs at evaluation time to figure out a nice name.

The same type of logic could be used for Censored (dispatched on at.clip) to call it "CensoredX" like "CensoredNormal", or the future RandomWalk to call them "XRandomWalk" (perhaps with a special case for Normal, where we call it "Gaussian")

The dispatching itself only means we don't need to define the name at creation time (and others can overwrite it more easily). The more fundamental difference is that it uses a per-op function to decide what name to give to the distribution.

The basecase would be the RandomVariable Op which does what was already done eagerly before this PR.

This is just an idea. Feel free to investigate something like this or leave it as a separate enhancement/feature request issue!


return rv_out


class ZeroInflatedPoisson:
Expand Down Expand Up @@ -1481,13 +1486,13 @@ class ZeroInflatedPoisson:

def __new__(cls, name, psi, mu, **kwargs):
return _zero_inflated_mixture(
name=name, nonzero_p=psi, nonzero_dist=Poisson.dist(mu=mu), **kwargs
cls=cls, name=name, nonzero_p=psi, nonzero_dist=Poisson.dist(mu=mu), **kwargs
)

@classmethod
def dist(cls, psi, mu, **kwargs):
return _zero_inflated_mixture(
name=None, nonzero_p=psi, nonzero_dist=Poisson.dist(mu=mu), **kwargs
cls=cls, name=None, nonzero_p=psi, nonzero_dist=Poisson.dist(mu=mu), **kwargs
)


Expand Down Expand Up @@ -1545,13 +1550,13 @@ class ZeroInflatedBinomial:

def __new__(cls, name, psi, n, p, **kwargs):
return _zero_inflated_mixture(
name=name, nonzero_p=psi, nonzero_dist=Binomial.dist(n=n, p=p), **kwargs
cls=cls, name=name, nonzero_p=psi, nonzero_dist=Binomial.dist(n=n, p=p), **kwargs
)

@classmethod
def dist(cls, psi, n, p, **kwargs):
return _zero_inflated_mixture(
name=None, nonzero_p=psi, nonzero_dist=Binomial.dist(n=n, p=p), **kwargs
cls=cls, name=None, nonzero_p=psi, nonzero_dist=Binomial.dist(n=n, p=p), **kwargs
)


Expand Down Expand Up @@ -1638,6 +1643,7 @@ def ZeroInfNegBinom(a, m, psi, x):

def __new__(cls, name, psi, mu=None, alpha=None, p=None, n=None, **kwargs):
return _zero_inflated_mixture(
cls=cls,
name=name,
nonzero_p=psi,
nonzero_dist=NegativeBinomial.dist(mu=mu, alpha=alpha, p=p, n=n),
Expand All @@ -1647,6 +1653,7 @@ def __new__(cls, name, psi, mu=None, alpha=None, p=None, n=None, **kwargs):
@classmethod
def dist(cls, psi, mu=None, alpha=None, p=None, n=None, **kwargs):
return _zero_inflated_mixture(
cls=cls,
name=None,
nonzero_p=psi,
nonzero_dist=NegativeBinomial.dist(mu=mu, alpha=alpha, p=p, n=n),
Expand Down
21 changes: 11 additions & 10 deletions pymc/distributions/distribution.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@
resize_from_dims,
resize_from_observed,
)
from pymc.printing import str_for_dist
from pymc.printing import str_for_dist, str_for_symbolic_dist
from pymc.util import UNSET
from pymc.vartypes import string_types

Expand Down Expand Up @@ -484,15 +484,13 @@ def __new__(
initval=initval,
)

# TODO: Refactor this
# add in pretty-printing support
rv_out.str_repr = lambda *args, **kwargs: name
rv_out._repr_latex_ = f"\\text{name}"
# rv_out.str_repr = types.MethodType(str_for_dist, rv_out)
# rv_out._repr_latex_ = types.MethodType(
# functools.partial(str_for_dist, formatting="latex"), rv_out
# )
# TODO: if rv_out is a cloned variable, the line below wouldn't work
set_print_name(cls, rv_out)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If we clone a variable (e.g., during a graph rewrite) this will be lost, right? It's not a big issue for now, but if that's the case we should add a comment stating so.


rv_out.str_repr = types.MethodType(str_for_symbolic_dist, rv_out)
rv_out._repr_latex_ = types.MethodType(
functools.partial(str_for_symbolic_dist, formatting="latex"), rv_out
)
return rv_out

@classmethod
Expand Down Expand Up @@ -521,7 +519,6 @@ def dist(
-------
var : TensorVariable
"""

if "testval" in kwargs:
kwargs.pop("testval")
warnings.warn(
Expand Down Expand Up @@ -848,3 +845,7 @@ def default_moment(rv, size, *rv_inputs, rv_name=None, has_fallback=False, ndim_
f"Please provide a moment function when instantiating the {rv_name} "
"random variable."
)


def set_print_name(cls, rv):
    """Attach a pretty-print name to the Op backing ``rv``.

    Parameters
    ----------
    cls : type
        The (symbolic) distribution class whose name should be displayed,
        e.g. ``ZeroInflatedPoisson``.
    rv : TensorVariable
        The random variable whose ``owner.op._print_name`` is overridden.
        Must have an owner (``rv.owner`` is an Apply node with an ``op``).

    Notes
    -----
    ``_print_name`` is a ``(plain, latex)`` tuple consumed by the printing
    helpers in ``pymc/printing.py``.  NOTE(review): the tag lives on the Op,
    so it is lost if the variable is cloned during a graph rewrite — TODO
    confirm whether that matters for any caller.
    """
    # Direct attribute assignment instead of setattr with a constant name
    # (flake8-bugbear B010); cls.__name__ is already a str, no f-string needed.
    rv.owner.op._print_name = (cls.__name__, f"\\operatorname{{{cls.__name__}}}")
99 changes: 97 additions & 2 deletions pymc/printing.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,12 @@
# limitations under the License.

import itertools
import warnings

from typing import Union

from aesara.graph.basic import walk
from aesara.tensor.basic import TensorVariable, Variable
from aesara.tensor.basic import MakeVector, TensorVariable, Variable
from aesara.tensor.elemwise import DimShuffle
from aesara.tensor.random.basic import RandomVariable
from aesara.tensor.var import TensorConstant
Expand All @@ -26,6 +27,7 @@

__all__ = [
"str_for_dist",
"str_for_symbolic_dist",
"str_for_model",
"str_for_potential_or_deterministic",
]
Expand Down Expand Up @@ -55,6 +57,92 @@ def str_for_dist(rv: TensorVariable, formatting: str = "plain", include_params:
return rf"{print_name} ~ {dist_name}"


def str_for_symbolic_dist(
    rv: TensorVariable, formatting: str = "plain", include_params: bool = True
) -> str:
    """Make a human-readable string representation of a SymbolicDistribution
    (Mixture, ZeroInflated*, Censored, ...), dispatching on the Op's
    ``_print_name`` tag.

    Parameters
    ----------
    rv : TensorVariable
        The symbolic random variable; ``rv.owner.op._print_name`` is a
        ``(plain, latex)`` tuple whose plain entry selects the branch below.
    formatting : str
        ``"plain"`` or any string containing ``"latex"``.
    include_params : bool
        If True, render the distribution's parameters inside parentheses.

    Returns
    -------
    str
        e.g. ``"x ~ ZeroInflatedPoisson(psi, mu)"`` or the LaTeX equivalent
        wrapped in ``$...$``.
    """

    def dispatch_comp_str(var, formatting=formatting, include_params=include_params):
        # Render one parameter/component of the symbolic distribution.
        if var.name:
            # Named model variables print by their name.
            return var.name
        if isinstance(var, TensorConstant):
            if len(var.data.shape) > 1:
                # Only scalar and 1-d constants are handled here.
                raise NotImplementedError
            try:
                if var.data.shape[0] > 1:
                    # weights in mixture model
                    return "[" + ",".join([str(weight) for weight in var.data]) + "]"
            except IndexError:
                # just a scalar
                return _str_for_constant(var, formatting)
            # NOTE(review): a length-1 constant vector (shape[0] == 1) falls
            # through to the var.owner checks below, where var.owner is None
            # for constants and would raise AttributeError — confirm this
            # case cannot occur in practice.
        if isinstance(var.owner.op, MakeVector):
            # psi in some zero inflated distribution
            return dispatch_comp_str(var.owner.inputs[1])

        # else it's a Mixture component initialized by the .dist() API

        # inputs[3:] of a RandomVariable node are the distribution parameters
        # (presumably inputs[0:3] are rng/size/dtype — TODO confirm).
        dist_args = ", ".join(
            [_str_for_input_var(x, formatting=formatting) for x in var.owner.inputs[3:]]
        )
        comp_name = var.owner.op.name.capitalize()

        if "latex" in formatting:
            comp_name = r"\text{" + _latex_escape(comp_name) + "}"

        return f"{comp_name}({dist_args})"

    if include_params:
        # Select which owner inputs are the displayable parameters; the
        # layout differs per symbolic distribution.
        if "ZeroInflated" in rv.owner.op._print_name[0]:
            # position 2 is just a constant_rv{0, (0,), shape, False}.1
            assert rv.owner.inputs[2].owner.op.__class__.__name__ == "UnmeasurableConstantRV"
            dist_parameters = [rv.owner.inputs[1]] + rv.owner.inputs[3:]

        elif "Mixture" in rv.owner.op._print_name[0]:

            if len(rv.owner.inputs) == 3:
                # is a single component!
                # (rng, weights, single_component)
                # Rename the Op in place, e.g. "NormalMixture".
                # NOTE(review): this mutates rv.owner.op as a side effect of
                # printing; subsequent calls will see the new name.
                rv.owner.op._print_name = (
                    f"{rv.owner.inputs[2].owner.op.name.capitalize()}Mixture",
                    "\\operatorname{" + f"{rv.owner.inputs[2].owner.op.name.capitalize()}Mixture}}",
                )
                dist_parameters = [rv.owner.inputs[1]] + rv.owner.inputs[2].owner.inputs[3:]
            else:
                dist_parameters = rv.owner.inputs[1:]

        elif "Censored" in rv.owner.op._print_name[0]:
            dist_parameters = rv.owner.inputs
        else:
            # Latex representation for the SymbolicDistribution has not been implemented.
            # Hoping for the best here!
            dist_parameters = rv.owner.inputs[2:]
            warnings.warn(
                "Latex representation for this SymbolicDistribution has not been implemented. "
                "Please have a look at str_for_symbolic_dist in pymc/printing.py",
                FutureWarning,
                stacklevel=2,
            )

        dist_args = [
            dispatch_comp_str(dist_para, formatting=formatting, include_params=include_params)
            for dist_para in dist_parameters
        ]

    # code below copied from str_for_dist
    print_name = rv.name if rv.name is not None else "<unnamed>"
    if "latex" in formatting:
        print_name = r"\text{" + _latex_escape(print_name) + "}"
        dist_name = rv.owner.op._print_name[1]
        if include_params:
            return r"${} \sim {}({})$".format(print_name, dist_name, ",~".join(dist_args))
        else:
            return rf"${print_name} \sim {dist_name}$"
    else:  # plain
        dist_name = rv.owner.op._print_name[0]
        if include_params:
            return r"{} ~ {}({})".format(print_name, dist_name, ", ".join(dist_args))
        else:
            return rf"{print_name} ~ {dist_name}"


def str_for_model(model: Model, formatting: str = "plain", include_params: bool = True) -> str:
"""Make a human-readable string representation of Model, listing all random variables
and their distributions, optionally including parameter values."""
Expand Down Expand Up @@ -133,7 +221,14 @@ def _is_potential_or_determinstic(var: Variable) -> bool:


def _str_for_input_rv(var: Variable, formatting: str) -> str:
_str = var.name if var.name is not None else "<unnamed>"

if var.name:
_str = var.name
elif var.owner.op.name:
_str = var.owner.op.name.capitalize()
else:
_str = "<unnamed>"

if "latex" in formatting:
return r"\text{" + _latex_escape(_str) + "}"
else:
Expand Down