Commit a74d193

Deprecate SymbolicDistribution in favor of Distribution class
* Also fixes index error in AR when init_dist is scalar and size is None
* Remove unused rv_class attribute
1 parent 05bba72 commit a74d193
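
In short, classes that previously subclassed SymbolicDistribution now subclass Distribution directly and point to their SymbolicRandomVariable op class through an rv_type attribute; the per-class ndim_supp classmethod goes away because the base Distribution now infers it from the created variable. A rough sketch of the new pattern in Python (MyRV and MyDist are hypothetical names used only for illustration, not part of this commit):

from pymc.distributions.distribution import Distribution, SymbolicRandomVariable


class MyRV(SymbolicRandomVariable):
    """Hypothetical SymbolicRandomVariable subclass."""

    _print_name = ("MyDist", "\\operatorname{MyDist}")


class MyDist(Distribution):
    # New style: declare the symbolic op class; no ndim_supp classmethod needed.
    rv_type = MyRV

    @classmethod
    def dist(cls, param, **kwargs):
        # Validate/convert parameters, then defer to Distribution.dist
        return super().dist([param], **kwargs)

    @classmethod
    def rv_op(cls, param, size=None):
        # Build and return the MyRV-based graph here (omitted in this sketch)
        raise NotImplementedError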

7 files changed: +24 additions, -221 deletions


docs/source/api/distributions/utilities.rst

Lines changed: 0 additions & 1 deletion
@@ -7,7 +7,6 @@ Distribution utilities
    :toctree: generated/
 
    Distribution
-   SymbolicDistribution
    Discrete
    Continuous
    NoDistribution

pymc/distributions/__init__.py

Lines changed: 0 additions & 2 deletions
@@ -82,7 +82,6 @@
     Discrete,
     Distribution,
     NoDistribution,
-    SymbolicDistribution,
     SymbolicRandomVariable,
 )
 from pymc.distributions.mixture import Mixture, NormalMixture
@@ -156,7 +155,6 @@
     "OrderedProbit",
     "DensityDist",
     "Distribution",
-    "SymbolicDistribution",
     "SymbolicRandomVariable",
     "Continuous",
     "Discrete",

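With SymbolicDistribution deleted from pymc/distributions/distribution.py and dropped from __all__, the name is no longer importable. A quick check, assuming a PyMC build that contains this commit:

from pymc.distributions import Distribution, SymbolicRandomVariable  # still available

try:
    from pymc.distributions import SymbolicDistribution  # removed by this commit
except ImportError as err:
    print("SymbolicDistribution is gone:", err)
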
pymc/distributions/censored.py

Lines changed: 4 additions & 6 deletions
@@ -18,7 +18,7 @@
 from aesara.tensor.random.op import RandomVariable
 
 from pymc.distributions.distribution import (
-    SymbolicDistribution,
+    Distribution,
     SymbolicRandomVariable,
     _moment,
 )
@@ -33,7 +33,7 @@ class CensoredRV(SymbolicRandomVariable):
     _print_name = ("Censored", "\\operatorname{Censored}")
 
 
-class Censored(SymbolicDistribution):
+class Censored(Distribution):
     r"""
     Censored distribution
 
@@ -82,6 +82,8 @@ class Censored(SymbolicDistribution):
         censored_normal = pm.Censored("censored_normal", normal_dist, lower=-1, upper=1)
     """
 
+    rv_type = CensoredRV
+
     @classmethod
     def dist(cls, dist, lower, upper, **kwargs):
         if not isinstance(dist, TensorVariable) or not isinstance(dist.owner.op, RandomVariable):
@@ -95,10 +97,6 @@ def dist(cls, dist, lower, upper, **kwargs):
         check_dist_not_registered(dist)
         return super().dist([dist, lower, upper], **kwargs)
 
-    @classmethod
-    def ndim_supp(cls, *dist_params):
-        return 0
-
     @classmethod
     def rv_op(cls, dist, lower=None, upper=None, size=None):
 
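User-facing construction of Censored should be unaffected by the switch to the Distribution base class; a small smoke test along the lines of the docstring example above, assuming a PyMC version that includes this commit:

import pymc as pm

with pm.Model():
    normal_dist = pm.Normal.dist(mu=0.0, sigma=1.0)
    censored_normal = pm.Censored("censored_normal", normal_dist, lower=-1, upper=1)
    print(pm.draw(censored_normal))  # one draw, clipped to the interval [-1, 1]
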
pymc/distributions/distribution.py

Lines changed: 9 additions & 201 deletions
@@ -32,6 +32,7 @@
 from aesara.graph import node_rewriter
 from aesara.graph.basic import Node, Variable, clone_replace
 from aesara.graph.rewriting.basic import in2out
+from aesara.graph.utils import MetaType
 from aesara.tensor.basic import as_tensor_variable
 from aesara.tensor.random.op import RandomVariable
 from aesara.tensor.random.type import RandomType
@@ -42,7 +43,6 @@
 from pymc.distributions.shape_utils import (
     Dims,
     Shape,
-    Size,
     StrongDims,
     StrongShape,
     change_dist_size,
@@ -60,7 +60,6 @@
     "DensityDistRV",
     "DensityDist",
     "Distribution",
-    "SymbolicDistribution",
     "Continuous",
     "Discrete",
     "NoDistribution",
@@ -112,6 +111,7 @@ def _random(*args, **kwargs):
 
         if isinstance(rv_op, RandomVariable):
             rv_type = type(rv_op)
+            clsdict["rv_type"] = rv_type
 
         new_cls = super().__new__(cls, name, bases, clsdict)
 
@@ -232,8 +232,8 @@ def update(self, node: Node):
 class Distribution(metaclass=DistributionMeta):
     """Statistical distribution"""
 
-    rv_class = None
-    rv_op: RandomVariable = None
+    rv_op: [RandomVariable, SymbolicRandomVariable] = None
+    rv_type: MetaType = None
 
     def __new__(
         cls,
@@ -321,7 +321,7 @@ def __new__(
         # Resize variable based on `dims` information
         if resize_shape_from_dims:
             resize_size_from_dims = find_size(
-                shape=resize_shape_from_dims, size=None, ndim_supp=cls.rv_op.ndim_supp
+                shape=resize_shape_from_dims, size=None, ndim_supp=rv_out.owner.op.ndim_supp
             )
             rv_out = change_dist_size(dist=rv_out, new_size=resize_size_from_dims, expand=False)
 
@@ -397,202 +397,10 @@ def dist(
         shape = convert_shape(shape)
         size = convert_size(size)
 
-        create_size = find_size(shape=shape, size=size, ndim_supp=cls.rv_op.ndim_supp)
-        rv_out = cls.rv_op(*dist_params, size=create_size, **kwargs)
-
-        rv_out.logp = _make_nice_attr_error("rv.logp(x)", "pm.logp(rv, x)")
-        rv_out.logcdf = _make_nice_attr_error("rv.logcdf(x)", "pm.logcdf(rv, x)")
-        rv_out.random = _make_nice_attr_error("rv.random()", "pm.draw(rv)")
-        return rv_out
-
-
-class SymbolicDistribution:
-    """Symbolic statistical distribution
-
-    While traditional PyMC distributions are represented by a single RandomVariable
-    graph, Symbolic distributions correspond to a larger graph that contains one or
-    more RandomVariables and an arbitrary number of deterministic operations, which
-    represent their own kind of distribution.
-
-    The graphs returned by symbolic distributions can be evaluated directly to
-    obtain valid draws and can further be parsed by Aeppl to derive the
-    corresponding logp at runtime.
-
-    Check pymc.distributions.Censored for an example of a symbolic distribution.
-
-    Symbolic distributions must implement the following classmethods:
-    cls.dist
-        Performs input validation and converts optional alternative parametrizations
-        to a canonical parametrization. It should call `super().dist()`, passing a
-        list with the default parameters as the first and only non keyword argument,
-        followed by other keyword arguments like size and rngs, and return the result
-    cls.ndim_supp
-        Returns the support of the symbolic distribution, given the default set of
-        parameters. This may not always be constant, for instance if the symbolic
-        distribution can be defined based on an arbitrary base distribution.
-    cls.rv_op
-        Returns a TensorVariable that represents the symbolic distribution
-        parametrized by a default set of parameters and a size and rngs arguments
-    """
-
-    def __new__(
-        cls,
-        name: str,
-        *args,
-        dims: Optional[Dims] = None,
-        initval=None,
-        observed=None,
-        total_size=None,
-        transform=UNSET,
-        **kwargs,
-    ) -> TensorVariable:
-        """Adds a TensorVariable corresponding to a PyMC symbolic distribution to the
-        current model.
-
-        Parameters
-        ----------
-        cls : type
-            A distribution class that inherits from SymbolicDistribution.
-        name : str
-            Name for the new model variable.
-        dims : tuple, optional
-            A tuple of dimension names known to the model. When shape is not provided,
-            the shape of dims is used to define the shape of the variable.
-        initval : optional
-            Numeric or symbolic untransformed initial value of matching shape,
-            or one of the following initial value strategies: "moment", "prior".
-            Depending on the sampler's settings, a random jitter may be added to numeric,
-            symbolic or moment-based initial values in the transformed space.
-        observed : optional
-            Observed data to be passed when registering the random variable in the model.
-            When neither shape nor dims is provided, the shape of observed is used to
-            define the shape of the variable.
-            See ``Model.register_rv``.
-        total_size : float, optional
-            See ``Model.register_rv``.
-        transform : optional
-            See ``Model.register_rv``.
-        **kwargs
-            Keyword arguments that will be forwarded to ``.dist()``.
-            Most prominently: ``shape`` and ``size``
-
-        Returns
-        -------
-        var : TensorVariable
-            The created variable, registered in the Model.
-        """
-
-        try:
-            from pymc.model import Model
-
-            model = Model.get_context()
-        except TypeError:
-            raise TypeError(
-                "No model on context stack, which is needed to "
-                "instantiate distributions. Add variable inside "
-                "a 'with model:' block, or use the '.dist' syntax "
-                "for a standalone distribution."
-            )
-
-        if "testval" in kwargs:
-            initval = kwargs.pop("testval")
-            warnings.warn(
-                "The `testval` argument is deprecated; use `initval`.",
-                FutureWarning,
-                stacklevel=2,
-            )
-
-        if not isinstance(name, string_types):
-            raise TypeError(f"Name needs to be a string but got: {name}")
-
-        dims = convert_dims(dims)
-        if observed is not None:
-            observed = convert_observed_data(observed)
-
-        # Create the RV, without taking `dims` into consideration
-        rv_out, resize_shape_from_dims = _make_rv_and_resize_shape_from_dims(
-            cls=cls, dims=dims, model=model, observed=observed, args=args, **kwargs
-        )
-
-        # Resize variable based on `dims` information
-        if resize_shape_from_dims:
-            resize_size_from_dims = find_size(
-                shape=resize_shape_from_dims, size=None, ndim_supp=rv_out.owner.op.ndim_supp
-            )
-            rv_out = change_dist_size(rv_out, new_size=resize_size_from_dims, expand=False)
-
-        rv_out = model.register_rv(
-            rv_out,
-            name,
-            observed,
-            total_size,
-            dims=dims,
-            transform=transform,
-            initval=initval,
-        )
-        # add in pretty-printing support
-        rv_out.str_repr = types.MethodType(str_for_dist, rv_out)
-        rv_out._repr_latex_ = types.MethodType(
-            functools.partial(str_for_dist, formatting="latex"), rv_out
-        )
-
-        rv_out.logp = _make_nice_attr_error("rv.logp(x)", "pm.logp(rv, x)")
-        rv_out.logcdf = _make_nice_attr_error("rv.logcdf(x)", "pm.logcdf(rv, x)")
-        rv_out.random = _make_nice_attr_error("rv.random()", "pm.draw(rv)")
-
-        return rv_out
-
-    @classmethod
-    def dist(
-        cls,
-        dist_params,
-        *,
-        shape: Optional[Shape] = None,
-        size: Optional[Size] = None,
-        **kwargs,
-    ) -> TensorVariable:
-        """Creates a TensorVariable corresponding to the `cls` symbolic distribution.
-
-        Parameters
-        ----------
-        dist_params : array-like
-            The inputs to the `RandomVariable` `Op`.
-        shape : int, tuple, Variable, optional
-            A tuple of sizes for each dimension of the new RV.
-        size : int, tuple, Variable, optional
-            For creating the RV like in Aesara/NumPy.
-
-        Returns
-        -------
-        var : TensorVariable
-        """
-
-        if "testval" in kwargs:
-            kwargs.pop("testval")
-            warnings.warn(
-                "The `.dist(testval=...)` argument is deprecated and has no effect. "
-                "Initial values for sampling/optimization can be specified with `initval` in a modelcontext. "
-                "For using Aesara's test value features, you must assign the `.tag.test_value` yourself.",
-                FutureWarning,
-                stacklevel=2,
-            )
-        if "initval" in kwargs:
-            raise TypeError(
-                "Unexpected keyword argument `initval`. "
-                "This argument is not available for the `.dist()` API."
-            )
-
-        if "dims" in kwargs:
-            raise NotImplementedError("The use of a `.dist(dims=...)` API is not supported.")
-        if shape is not None and size is not None:
-            raise ValueError(
-                f"Passing both `shape` ({shape}) and `size` ({size}) is not supported!"
-            )
-
-        shape = convert_shape(shape)
-        size = convert_size(size)
-
-        ndim_supp = cls.ndim_supp(*dist_params)
+        # SymbolicRVs don't have `ndim_supp` until they are created
+        ndim_supp = getattr(cls.rv_op, "ndim_supp", None)
+        if ndim_supp is None:
+            ndim_supp = cls.rv_op(*dist_params, **kwargs).owner.op.ndim_supp
         create_size = find_size(shape=shape, size=size, ndim_supp=ndim_supp)
         rv_out = cls.rv_op(*dist_params, size=create_size, **kwargs)
 
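The functional core of the change sits at the end of Distribution.dist(): ndim_supp is read from cls.rv_op when that is a plain RandomVariable, and otherwise inferred by building the symbolic variable once. Plain and symbolic distributions therefore share one code path; a usage sketch, assuming a PyMC version that includes this commit:

import pymc as pm

norm = pm.Normal.dist(0.0, 1.0, size=(3,))  # rv_op is a RandomVariable with a fixed ndim_supp
cens = pm.Censored.dist(pm.Normal.dist(0.0, 1.0), lower=-1, upper=1, size=(3,))  # symbolic rv_op
print(pm.draw(norm).shape, pm.draw(cens).shape)  # expected: (3,) (3,)
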
pymc/distributions/mixture.py

Lines changed: 4 additions & 7 deletions
@@ -27,7 +27,7 @@
 from pymc.distributions.continuous import Normal, get_tau_sigma
 from pymc.distributions.dist_math import check_parameters
 from pymc.distributions.distribution import (
-    SymbolicDistribution,
+    Distribution,
     SymbolicRandomVariable,
     _moment,
     moment,
@@ -52,7 +52,7 @@ def update(self, node: Node):
         return {node.inputs[0]: node.outputs[0]}
 
 
-class Mixture(SymbolicDistribution):
+class Mixture(Distribution):
     R"""
     Mixture log-likelihood
 
@@ -161,6 +161,8 @@ class Mixture(SymbolicDistribution):
         like = pm.Mixture('like', w=w, comp_dists=components, observed=data)
     """
 
+    rv_type = MarginalMixtureRV
+
     @classmethod
     def dist(cls, w, comp_dists, **kwargs):
         if not isinstance(comp_dists, (tuple, list)):
@@ -205,11 +207,6 @@ def dist(cls, w, comp_dists, **kwargs):
         w = at.as_tensor_variable(w)
         return super().dist([w, *comp_dists], **kwargs)
 
-    @classmethod
-    def ndim_supp(cls, weights, *components):
-        # We already checked that all components have the same support dimensionality
-        return components[0].owner.op.ndim_supp
-
     @classmethod
     def rv_op(cls, weights, *components, size=None):
         # Create new rng for the mix_indexes internal RV
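
Like Censored, Mixture now only declares rv_type = MarginalMixtureRV and keeps its dist/rv_op classmethods; the support dimensionality is read from the created variable instead of the removed ndim_supp classmethod. A minimal construction sketch, assuming a PyMC version that includes this commit:

import pymc as pm

components = [pm.Normal.dist(mu=-5.0, sigma=1.0), pm.Normal.dist(mu=5.0, sigma=1.0)]
mix = pm.Mixture.dist(w=[0.5, 0.5], comp_dists=components, size=(4,))
print(pm.draw(mix).shape)  # expected: (4,)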
