diff --git a/pymc/distributions/discrete.py b/pymc/distributions/discrete.py index 98b946337d..642b285de5 100644 --- a/pymc/distributions/discrete.py +++ b/pymc/distributions/discrete.py @@ -42,7 +42,7 @@ normal_lccdf, normal_lcdf, ) -from pymc.distributions.distribution import Discrete +from pymc.distributions.distribution import Discrete, set_print_name from pymc.distributions.logprob import logp from pymc.distributions.mixture import Mixture from pymc.distributions.shape_utils import rv_size_is_none @@ -1411,7 +1411,7 @@ def logcdf(value, c): ) -def _zero_inflated_mixture(*, name, nonzero_p, nonzero_dist, **kwargs): +def _zero_inflated_mixture(*, cls, name, nonzero_p, nonzero_dist, **kwargs): """Helper function to create a zero-inflated mixture If name is `None`, this function returns an unregistered variable @@ -1423,9 +1423,14 @@ def _zero_inflated_mixture(*, name, nonzero_p, nonzero_dist, **kwargs): nonzero_dist, ] if name is not None: - return Mixture(name, weights, comp_dists, **kwargs) + rv_out = Mixture(name, weights, comp_dists, **kwargs) else: - return Mixture.dist(weights, comp_dists, **kwargs) + rv_out = Mixture.dist(weights, comp_dists, **kwargs) + + # overriding Mixture _print_name + set_print_name(cls, rv_out) + + return rv_out class ZeroInflatedPoisson: @@ -1481,13 +1486,13 @@ class ZeroInflatedPoisson: def __new__(cls, name, psi, mu, **kwargs): return _zero_inflated_mixture( - name=name, nonzero_p=psi, nonzero_dist=Poisson.dist(mu=mu), **kwargs + cls=cls, name=name, nonzero_p=psi, nonzero_dist=Poisson.dist(mu=mu), **kwargs ) @classmethod def dist(cls, psi, mu, **kwargs): return _zero_inflated_mixture( - name=None, nonzero_p=psi, nonzero_dist=Poisson.dist(mu=mu), **kwargs + cls=cls, name=None, nonzero_p=psi, nonzero_dist=Poisson.dist(mu=mu), **kwargs ) @@ -1545,13 +1550,13 @@ class ZeroInflatedBinomial: def __new__(cls, name, psi, n, p, **kwargs): return _zero_inflated_mixture( - name=name, nonzero_p=psi, nonzero_dist=Binomial.dist(n=n, p=p), **kwargs + cls=cls, name=name, nonzero_p=psi, nonzero_dist=Binomial.dist(n=n, p=p), **kwargs ) @classmethod def dist(cls, psi, n, p, **kwargs): return _zero_inflated_mixture( - name=None, nonzero_p=psi, nonzero_dist=Binomial.dist(n=n, p=p), **kwargs + cls=cls, name=None, nonzero_p=psi, nonzero_dist=Binomial.dist(n=n, p=p), **kwargs ) @@ -1638,6 +1643,7 @@ def ZeroInfNegBinom(a, m, psi, x): def __new__(cls, name, psi, mu=None, alpha=None, p=None, n=None, **kwargs): return _zero_inflated_mixture( + cls=cls, name=name, nonzero_p=psi, nonzero_dist=NegativeBinomial.dist(mu=mu, alpha=alpha, p=p, n=n), @@ -1647,6 +1653,7 @@ def __new__(cls, name, psi, mu=None, alpha=None, p=None, n=None, **kwargs): @classmethod def dist(cls, psi, mu=None, alpha=None, p=None, n=None, **kwargs): return _zero_inflated_mixture( + cls=cls, name=None, nonzero_p=psi, nonzero_dist=NegativeBinomial.dist(mu=mu, alpha=alpha, p=p, n=n), diff --git a/pymc/distributions/distribution.py b/pymc/distributions/distribution.py index 22d7662575..d38e505cc1 100644 --- a/pymc/distributions/distribution.py +++ b/pymc/distributions/distribution.py @@ -47,7 +47,7 @@ resize_from_dims, resize_from_observed, ) -from pymc.printing import str_for_dist +from pymc.printing import str_for_dist, str_for_symbolic_dist from pymc.util import UNSET from pymc.vartypes import string_types @@ -484,15 +484,13 @@ def __new__( initval=initval, ) - # TODO: Refactor this - # add in pretty-printing support - rv_out.str_repr = lambda *args, **kwargs: name - rv_out._repr_latex_ = f"\\text{name}" - # rv_out.str_repr = types.MethodType(str_for_dist, rv_out) - # rv_out._repr_latex_ = types.MethodType( - # functools.partial(str_for_dist, formatting="latex"), rv_out - # ) + # TODO: if rv_out is a cloned variable, the line below wouldn't work + set_print_name(cls, rv_out) + rv_out.str_repr = types.MethodType(str_for_symbolic_dist, rv_out) + rv_out._repr_latex_ = types.MethodType( + functools.partial(str_for_symbolic_dist, formatting="latex"), rv_out + ) return rv_out @classmethod @@ -521,7 +519,6 @@ def dist( ------- var : TensorVariable """ - if "testval" in kwargs: kwargs.pop("testval") warnings.warn( @@ -848,3 +845,7 @@ def default_moment(rv, size, *rv_inputs, rv_name=None, has_fallback=False, ndim_ f"Please provide a moment function when instantiating the {rv_name} " "random variable." ) + + +def set_print_name(cls, rv): + setattr(rv.owner.op, "_print_name", (f"{cls.__name__}", f"\\operatorname{{{cls.__name__}}}")) diff --git a/pymc/printing.py b/pymc/printing.py index 828d07cee3..ea4555c7ac 100644 --- a/pymc/printing.py +++ b/pymc/printing.py @@ -13,11 +13,12 @@ # limitations under the License. import itertools +import warnings from typing import Union from aesara.graph.basic import walk -from aesara.tensor.basic import TensorVariable, Variable +from aesara.tensor.basic import MakeVector, TensorVariable, Variable from aesara.tensor.elemwise import DimShuffle from aesara.tensor.random.basic import RandomVariable from aesara.tensor.var import TensorConstant @@ -26,6 +27,7 @@ __all__ = [ "str_for_dist", + "str_for_symbolic_dist", "str_for_model", "str_for_potential_or_deterministic", ] @@ -55,6 +57,92 @@ def str_for_dist(rv: TensorVariable, formatting: str = "plain", include_params: return rf"{print_name} ~ {dist_name}" +def str_for_symbolic_dist( + rv: TensorVariable, formatting: str = "plain", include_params: bool = True +) -> str: + def dispatch_comp_str(var, formatting=formatting, include_params=include_params): + if var.name: + return var.name + if isinstance(var, TensorConstant): + if len(var.data.shape) > 1: + raise NotImplementedError + try: + if var.data.shape[0] > 1: + # weights in mixture model + return "[" + ",".join([str(weight) for weight in var.data]) + "]" + except IndexError: + # just a scalar + return _str_for_constant(var, formatting) + if isinstance(var.owner.op, MakeVector): + # psi in some zero inflated distribution + return dispatch_comp_str(var.owner.inputs[1]) + + # else it's a Mixture component initialized by the .dist() API + + dist_args = ", ".join( + [_str_for_input_var(x, formatting=formatting) for x in var.owner.inputs[3:]] + ) + comp_name = var.owner.op.name.capitalize() + + if "latex" in formatting: + comp_name = r"\text{" + _latex_escape(comp_name) + "}" + + return f"{comp_name}({dist_args})" + + if include_params: + if "ZeroInflated" in rv.owner.op._print_name[0]: + # position 2 is just a constant_rv{0, (0,), shape, False}.1 + assert rv.owner.inputs[2].owner.op.__class__.__name__ == "UnmeasurableConstantRV" + dist_parameters = [rv.owner.inputs[1]] + rv.owner.inputs[3:] + + elif "Mixture" in rv.owner.op._print_name[0]: + + if len(rv.owner.inputs) == 3: + # is a single component! + # (rng, weights, single_component) + rv.owner.op._print_name = ( + f"{rv.owner.inputs[2].owner.op.name.capitalize()}Mixture", + "\\operatorname{" + f"{rv.owner.inputs[2].owner.op.name.capitalize()}Mixture}}", + ) + dist_parameters = [rv.owner.inputs[1]] + rv.owner.inputs[2].owner.inputs[3:] + else: + dist_parameters = rv.owner.inputs[1:] + + elif "Censored" in rv.owner.op._print_name[0]: + dist_parameters = rv.owner.inputs + else: + # Latex representation for the SymbolicDistribution has not been implemented. + # Hoping for the best here! + dist_parameters = rv.owner.inputs[2:] + warnings.warn( + "Latex representation for this SymbolicDistribution has not been implemented. " + "Please have a look at str_for_symbolic_dist in pymc/printing.py", + FutureWarning, + stacklevel=2, + ) + + dist_args = [ + dispatch_comp_str(dist_para, formatting=formatting, include_params=include_params) + for dist_para in dist_parameters + ] + + # code below copied from str_for_dist + print_name = rv.name if rv.name is not None else "" + if "latex" in formatting: + print_name = r"\text{" + _latex_escape(print_name) + "}" + dist_name = rv.owner.op._print_name[1] + if include_params: + return r"${} \sim {}({})$".format(print_name, dist_name, ",~".join(dist_args)) + else: + return rf"${print_name} \sim {dist_name}$" + else: # plain + dist_name = rv.owner.op._print_name[0] + if include_params: + return r"{} ~ {}({})".format(print_name, dist_name, ", ".join(dist_args)) + else: + return rf"{print_name} ~ {dist_name}" + + def str_for_model(model: Model, formatting: str = "plain", include_params: bool = True) -> str: """Make a human-readable string representation of Model, listing all random variables and their distributions, optionally including parameter values.""" @@ -133,7 +221,14 @@ def _is_potential_or_determinstic(var: Variable) -> bool: def _str_for_input_rv(var: Variable, formatting: str) -> str: - _str = var.name if var.name is not None else "" + + if var.name: + _str = var.name + elif var.owner.op.name: + _str = var.owner.op.name.capitalize() + else: + _str = "" + if "latex" in formatting: return r"\text{" + _latex_escape(_str) + "}" else: