|
32 | 32 | from aesara.graph import node_rewriter
|
33 | 33 | from aesara.graph.basic import Node, Variable, clone_replace
|
34 | 34 | from aesara.graph.rewriting.basic import in2out
|
| 35 | +from aesara.graph.utils import MetaType |
35 | 36 | from aesara.tensor.basic import as_tensor_variable
|
36 | 37 | from aesara.tensor.random.op import RandomVariable
|
37 | 38 | from aesara.tensor.random.type import RandomType
|
|
42 | 43 | from pymc.distributions.shape_utils import (
|
43 | 44 | Dims,
|
44 | 45 | Shape,
|
45 |
| - Size, |
46 | 46 | StrongDims,
|
47 | 47 | StrongShape,
|
48 | 48 | change_dist_size,
|
|
60 | 60 | "DensityDistRV",
|
61 | 61 | "DensityDist",
|
62 | 62 | "Distribution",
|
63 |
| - "SymbolicDistribution", |
64 | 63 | "Continuous",
|
65 | 64 | "Discrete",
|
66 | 65 | "NoDistribution",
|
@@ -112,6 +111,7 @@ def _random(*args, **kwargs):
|
112 | 111 |
|
113 | 112 | if isinstance(rv_op, RandomVariable):
|
114 | 113 | rv_type = type(rv_op)
|
| 114 | + clsdict["rv_type"] = rv_type |
115 | 115 |
|
116 | 116 | new_cls = super().__new__(cls, name, bases, clsdict)
|
117 | 117 |
|
@@ -232,8 +232,8 @@ def update(self, node: Node):
|
232 | 232 | class Distribution(metaclass=DistributionMeta):
|
233 | 233 | """Statistical distribution"""
|
234 | 234 |
|
235 |
| - rv_class = None |
236 |
| - rv_op: RandomVariable = None |
| 235 | + rv_op: [RandomVariable, SymbolicRandomVariable] = None |
| 236 | + rv_type: MetaType = None |
237 | 237 |
|
238 | 238 | def __new__(
|
239 | 239 | cls,
|
@@ -321,7 +321,7 @@ def __new__(
|
321 | 321 | # Resize variable based on `dims` information
|
322 | 322 | if resize_shape_from_dims:
|
323 | 323 | resize_size_from_dims = find_size(
|
324 |
| - shape=resize_shape_from_dims, size=None, ndim_supp=cls.rv_op.ndim_supp |
| 324 | + shape=resize_shape_from_dims, size=None, ndim_supp=rv_out.owner.op.ndim_supp |
325 | 325 | )
|
326 | 326 | rv_out = change_dist_size(dist=rv_out, new_size=resize_size_from_dims, expand=False)
|
327 | 327 |
|
@@ -397,202 +397,10 @@ def dist(
|
397 | 397 | shape = convert_shape(shape)
|
398 | 398 | size = convert_size(size)
|
399 | 399 |
|
400 |
| - create_size = find_size(shape=shape, size=size, ndim_supp=cls.rv_op.ndim_supp) |
401 |
| - rv_out = cls.rv_op(*dist_params, size=create_size, **kwargs) |
402 |
| - |
403 |
| - rv_out.logp = _make_nice_attr_error("rv.logp(x)", "pm.logp(rv, x)") |
404 |
| - rv_out.logcdf = _make_nice_attr_error("rv.logcdf(x)", "pm.logcdf(rv, x)") |
405 |
| - rv_out.random = _make_nice_attr_error("rv.random()", "pm.draw(rv)") |
406 |
| - return rv_out |
407 |
| - |
408 |
| - |
409 |
| -class SymbolicDistribution: |
410 |
| - """Symbolic statistical distribution |
411 |
| -
|
412 |
| - While traditional PyMC distributions are represented by a single RandomVariable |
413 |
| - graph, Symbolic distributions correspond to a larger graph that contains one or |
414 |
| - more RandomVariables and an arbitrary number of deterministic operations, which |
415 |
| - represent their own kind of distribution. |
416 |
| -
|
417 |
| - The graphs returned by symbolic distributions can be evaluated directly to |
418 |
| - obtain valid draws and can further be parsed by Aeppl to derive the |
419 |
| - corresponding logp at runtime. |
420 |
| -
|
421 |
| - Check pymc.distributions.Censored for an example of a symbolic distribution. |
422 |
| -
|
423 |
| - Symbolic distributions must implement the following classmethods: |
424 |
| - cls.dist |
425 |
| - Performs input validation and converts optional alternative parametrizations |
426 |
| - to a canonical parametrization. It should call `super().dist()`, passing a |
427 |
| - list with the default parameters as the first and only non keyword argument, |
428 |
| - followed by other keyword arguments like size and rngs, and return the result |
429 |
| - cls.ndim_supp |
430 |
| - Returns the support of the symbolic distribution, given the default set of |
431 |
| - parameters. This may not always be constant, for instance if the symbolic |
432 |
| - distribution can be defined based on an arbitrary base distribution. |
433 |
| - cls.rv_op |
434 |
| - Returns a TensorVariable that represents the symbolic distribution |
435 |
| - parametrized by a default set of parameters and a size and rngs arguments |
436 |
| - """ |
437 |
| - |
438 |
| - def __new__( |
439 |
| - cls, |
440 |
| - name: str, |
441 |
| - *args, |
442 |
| - dims: Optional[Dims] = None, |
443 |
| - initval=None, |
444 |
| - observed=None, |
445 |
| - total_size=None, |
446 |
| - transform=UNSET, |
447 |
| - **kwargs, |
448 |
| - ) -> TensorVariable: |
449 |
| - """Adds a TensorVariable corresponding to a PyMC symbolic distribution to the |
450 |
| - current model. |
451 |
| -
|
452 |
| - Parameters |
453 |
| - ---------- |
454 |
| - cls : type |
455 |
| - A distribution class that inherits from SymbolicDistribution. |
456 |
| - name : str |
457 |
| - Name for the new model variable. |
458 |
| - dims : tuple, optional |
459 |
| - A tuple of dimension names known to the model. When shape is not provided, |
460 |
| - the shape of dims is used to define the shape of the variable. |
461 |
| - initval : optional |
462 |
| - Numeric or symbolic untransformed initial value of matching shape, |
463 |
| - or one of the following initial value strategies: "moment", "prior". |
464 |
| - Depending on the sampler's settings, a random jitter may be added to numeric, |
465 |
| - symbolic or moment-based initial values in the transformed space. |
466 |
| - observed : optional |
467 |
| - Observed data to be passed when registering the random variable in the model. |
468 |
| - When neither shape nor dims is provided, the shape of observed is used to |
469 |
| - define the shape of the variable. |
470 |
| - See ``Model.register_rv``. |
471 |
| - total_size : float, optional |
472 |
| - See ``Model.register_rv``. |
473 |
| - transform : optional |
474 |
| - See ``Model.register_rv``. |
475 |
| - **kwargs |
476 |
| - Keyword arguments that will be forwarded to ``.dist()``. |
477 |
| - Most prominently: ``shape`` and ``size`` |
478 |
| -
|
479 |
| - Returns |
480 |
| - ------- |
481 |
| - var : TensorVariable |
482 |
| - The created variable, registered in the Model. |
483 |
| - """ |
484 |
| - |
485 |
| - try: |
486 |
| - from pymc.model import Model |
487 |
| - |
488 |
| - model = Model.get_context() |
489 |
| - except TypeError: |
490 |
| - raise TypeError( |
491 |
| - "No model on context stack, which is needed to " |
492 |
| - "instantiate distributions. Add variable inside " |
493 |
| - "a 'with model:' block, or use the '.dist' syntax " |
494 |
| - "for a standalone distribution." |
495 |
| - ) |
496 |
| - |
497 |
| - if "testval" in kwargs: |
498 |
| - initval = kwargs.pop("testval") |
499 |
| - warnings.warn( |
500 |
| - "The `testval` argument is deprecated; use `initval`.", |
501 |
| - FutureWarning, |
502 |
| - stacklevel=2, |
503 |
| - ) |
504 |
| - |
505 |
| - if not isinstance(name, string_types): |
506 |
| - raise TypeError(f"Name needs to be a string but got: {name}") |
507 |
| - |
508 |
| - dims = convert_dims(dims) |
509 |
| - if observed is not None: |
510 |
| - observed = convert_observed_data(observed) |
511 |
| - |
512 |
| - # Create the RV, without taking `dims` into consideration |
513 |
| - rv_out, resize_shape_from_dims = _make_rv_and_resize_shape_from_dims( |
514 |
| - cls=cls, dims=dims, model=model, observed=observed, args=args, **kwargs |
515 |
| - ) |
516 |
| - |
517 |
| - # Resize variable based on `dims` information |
518 |
| - if resize_shape_from_dims: |
519 |
| - resize_size_from_dims = find_size( |
520 |
| - shape=resize_shape_from_dims, size=None, ndim_supp=rv_out.owner.op.ndim_supp |
521 |
| - ) |
522 |
| - rv_out = change_dist_size(rv_out, new_size=resize_size_from_dims, expand=False) |
523 |
| - |
524 |
| - rv_out = model.register_rv( |
525 |
| - rv_out, |
526 |
| - name, |
527 |
| - observed, |
528 |
| - total_size, |
529 |
| - dims=dims, |
530 |
| - transform=transform, |
531 |
| - initval=initval, |
532 |
| - ) |
533 |
| - # add in pretty-printing support |
534 |
| - rv_out.str_repr = types.MethodType(str_for_dist, rv_out) |
535 |
| - rv_out._repr_latex_ = types.MethodType( |
536 |
| - functools.partial(str_for_dist, formatting="latex"), rv_out |
537 |
| - ) |
538 |
| - |
539 |
| - rv_out.logp = _make_nice_attr_error("rv.logp(x)", "pm.logp(rv, x)") |
540 |
| - rv_out.logcdf = _make_nice_attr_error("rv.logcdf(x)", "pm.logcdf(rv, x)") |
541 |
| - rv_out.random = _make_nice_attr_error("rv.random()", "pm.draw(rv)") |
542 |
| - |
543 |
| - return rv_out |
544 |
| - |
545 |
| - @classmethod |
546 |
| - def dist( |
547 |
| - cls, |
548 |
| - dist_params, |
549 |
| - *, |
550 |
| - shape: Optional[Shape] = None, |
551 |
| - size: Optional[Size] = None, |
552 |
| - **kwargs, |
553 |
| - ) -> TensorVariable: |
554 |
| - """Creates a TensorVariable corresponding to the `cls` symbolic distribution. |
555 |
| -
|
556 |
| - Parameters |
557 |
| - ---------- |
558 |
| - dist_params : array-like |
559 |
| - The inputs to the `RandomVariable` `Op`. |
560 |
| - shape : int, tuple, Variable, optional |
561 |
| - A tuple of sizes for each dimension of the new RV. |
562 |
| - size : int, tuple, Variable, optional |
563 |
| - For creating the RV like in Aesara/NumPy. |
564 |
| -
|
565 |
| - Returns |
566 |
| - ------- |
567 |
| - var : TensorVariable |
568 |
| - """ |
569 |
| - |
570 |
| - if "testval" in kwargs: |
571 |
| - kwargs.pop("testval") |
572 |
| - warnings.warn( |
573 |
| - "The `.dist(testval=...)` argument is deprecated and has no effect. " |
574 |
| - "Initial values for sampling/optimization can be specified with `initval` in a modelcontext. " |
575 |
| - "For using Aesara's test value features, you must assign the `.tag.test_value` yourself.", |
576 |
| - FutureWarning, |
577 |
| - stacklevel=2, |
578 |
| - ) |
579 |
| - if "initval" in kwargs: |
580 |
| - raise TypeError( |
581 |
| - "Unexpected keyword argument `initval`. " |
582 |
| - "This argument is not available for the `.dist()` API." |
583 |
| - ) |
584 |
| - |
585 |
| - if "dims" in kwargs: |
586 |
| - raise NotImplementedError("The use of a `.dist(dims=...)` API is not supported.") |
587 |
| - if shape is not None and size is not None: |
588 |
| - raise ValueError( |
589 |
| - f"Passing both `shape` ({shape}) and `size` ({size}) is not supported!" |
590 |
| - ) |
591 |
| - |
592 |
| - shape = convert_shape(shape) |
593 |
| - size = convert_size(size) |
594 |
| - |
595 |
| - ndim_supp = cls.ndim_supp(*dist_params) |
| 400 | + # SymbolicRVs don't have `ndim_supp` until they are created |
| 401 | + ndim_supp = getattr(cls.rv_op, "ndim_supp", None) |
| 402 | + if ndim_supp is None: |
| 403 | + ndim_supp = cls.rv_op(*dist_params, **kwargs).owner.op.ndim_supp |
596 | 404 | create_size = find_size(shape=shape, size=size, ndim_supp=ndim_supp)
|
597 | 405 | rv_out = cls.rv_op(*dist_params, size=create_size, **kwargs)
|
598 | 406 |
|
|
0 commit comments