From 2db70ce34f91b7139f1ecaaa9bc06bf925736d06 Mon Sep 17 00:00:00 2001 From: Michael Osthege Date: Sun, 29 May 2022 13:29:17 +0200 Subject: [PATCH 1/6] Add return type hints Closes #4880 --- pymc/distributions/logprob.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/pymc/distributions/logprob.py b/pymc/distributions/logprob.py index 9b92f8acfb..644d7e84b7 100644 --- a/pymc/distributions/logprob.py +++ b/pymc/distributions/logprob.py @@ -39,7 +39,7 @@ from pymc.aesaraf import floatX -def _get_scaling(total_size: Optional[Union[int, Sequence[int]]], shape, ndim: int): +def _get_scaling(total_size: Optional[Union[int, Sequence[int]]], shape, ndim: int) -> TensorVariable: """ Gets scaling constant for logp. @@ -288,14 +288,14 @@ def logp(rv: TensorVariable, value) -> TensorVariable: raise NotImplementedError("PyMC could not infer logp of input variable.") from exc -def logcdf(rv, value): +def logcdf(rv, value) -> TensorVariable: """Return the log-cdf graph of a Random Variable""" value = at.as_tensor_variable(value, dtype=rv.dtype) return logcdf_aeppl(rv, value) -def ignore_logprob(rv): +def ignore_logprob(rv) -> TensorVariable: """Return a duplicated variable that is ignored when creating Aeppl logprob graphs This is used in SymbolicDistributions that use other RVs as inputs but account From e8571b2160e39b0c6484628eea47b628bd3be2a4 Mon Sep 17 00:00:00 2001 From: Michael Osthege Date: Sun, 29 May 2022 13:53:19 +0200 Subject: [PATCH 2/6] Fix pre-commit --- pymc/distributions/logprob.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/pymc/distributions/logprob.py b/pymc/distributions/logprob.py index 644d7e84b7..0051aeff66 100644 --- a/pymc/distributions/logprob.py +++ b/pymc/distributions/logprob.py @@ -39,7 +39,9 @@ from pymc.aesaraf import floatX -def _get_scaling(total_size: Optional[Union[int, Sequence[int]]], shape, ndim: int) -> TensorVariable: +def _get_scaling( + total_size: Optional[Union[int, Sequence[int]]], shape, ndim: int +) -> TensorVariable: """ Gets scaling constant for logp. From 7f32c303423b8df9d85802d31152bcd5585e4949 Mon Sep 17 00:00:00 2001 From: Michael Osthege Date: Sun, 29 May 2022 17:38:58 +0200 Subject: [PATCH 3/6] Add type hints for rv arguments Co-authored-by: Ricardo Vieira <28983449+ricardoV94@users.noreply.github.com> --- pymc/distributions/logprob.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pymc/distributions/logprob.py b/pymc/distributions/logprob.py index 0051aeff66..fb2b041ff8 100644 --- a/pymc/distributions/logprob.py +++ b/pymc/distributions/logprob.py @@ -290,14 +290,14 @@ def logp(rv: TensorVariable, value) -> TensorVariable: raise NotImplementedError("PyMC could not infer logp of input variable.") from exc -def logcdf(rv, value) -> TensorVariable: +def logcdf(rv: TensorVariable, value) -> TensorVariable: """Return the log-cdf graph of a Random Variable""" value = at.as_tensor_variable(value, dtype=rv.dtype) return logcdf_aeppl(rv, value) -def ignore_logprob(rv) -> TensorVariable: +def ignore_logprob(rv: TensorVariable) -> TensorVariable: """Return a duplicated variable that is ignored when creating Aeppl logprob graphs This is used in SymbolicDistributions that use other RVs as inputs but account From 22cec6b12be9a95016074db1f5cff7bcd256f30b Mon Sep 17 00:00:00 2001 From: Michael Osthege Date: Sun, 29 May 2022 18:07:12 +0200 Subject: [PATCH 4/6] Fix distribution return type hints --- pymc/distributions/distribution.py | 20 ++++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/pymc/distributions/distribution.py b/pymc/distributions/distribution.py index 35fce9e222..22d7662575 100644 --- a/pymc/distributions/distribution.py +++ b/pymc/distributions/distribution.py @@ -198,8 +198,8 @@ def __new__( total_size=None, transform=UNSET, **kwargs, - ) -> RandomVariable: - """Adds a RandomVariable corresponding to a PyMC distribution to the current model. + ) -> TensorVariable: + """Adds a tensor variable corresponding to a PyMC distribution to the current model. Note that all remaining kwargs must be compatible with ``.dist()`` @@ -231,8 +231,8 @@ def __new__( Returns ------- - rv : RandomVariable - The created RV, registered in the Model. + rv : TensorVariable + The created random variable tensor, registered in the Model. """ try: @@ -296,8 +296,8 @@ def dist( *, shape: Optional[Shape] = None, **kwargs, - ) -> RandomVariable: - """Creates a RandomVariable corresponding to the `cls` distribution. + ) -> TensorVariable: + """Creates a tensor variable corresponding to the `cls` distribution. Parameters ---------- @@ -314,8 +314,8 @@ def dist( Returns ------- - rv : RandomVariable - The created RV. + rv : TensorVariable + The created random variable tensor. """ if "testval" in kwargs: kwargs.pop("testval") @@ -653,8 +653,8 @@ def __new__( name : str dist_params : Tuple A sequence of the distribution's parameter. These will be converted into - Aesara tensors internally. These parameters could be other ``RandomVariable`` - instances. + Aesara tensors internally. These parameters could be other ``TensorVariable`` + instances created from , optionally created via ``RandomVariable`` ``Op``s. logp : Optional[Callable] A callable that calculates the log density of some given observed ``value`` conditioned on certain distribution parameter values. It must have the From 98d767a0427a4ea4965dd3f8219f363af1e0a17a Mon Sep 17 00:00:00 2001 From: Michael Osthege Date: Sun, 29 May 2022 18:28:58 +0200 Subject: [PATCH 5/6] Fix more type issues --- pymc/aesaraf.py | 25 ++++++++++++++----------- 1 file changed, 14 insertions(+), 11 deletions(-) diff --git a/pymc/aesaraf.py b/pymc/aesaraf.py index 0160635d0b..967267af8a 100644 --- a/pymc/aesaraf.py +++ b/pymc/aesaraf.py @@ -52,6 +52,10 @@ from aesara.scalar.basic import Cast from aesara.tensor.elemwise import Elemwise from aesara.tensor.random.op import RandomVariable +from aesara.tensor.random.var import ( + RandomGeneratorSharedVariable, + RandomStateSharedVariable, +) from aesara.tensor.shape import SpecifyShape from aesara.tensor.sharedvar import SharedVariable from aesara.tensor.subtensor import AdvancedIncSubtensor, AdvancedIncSubtensor1 @@ -60,9 +64,7 @@ from pymc.exceptions import ShapeError from pymc.vartypes import continuous_types, isgenerator, typefilter -PotentialShapeType = Union[ - int, np.ndarray, Tuple[Union[int, Variable], ...], List[Union[int, Variable]], Variable -] +PotentialShapeType = Union[int, np.ndarray, Sequence[Union[int, Variable]], TensorVariable] __all__ = [ @@ -165,6 +167,7 @@ def change_rv_size( new_size = (new_size,) # Extract the RV node that is to be resized, together with its inputs, name and tag + assert rw.owner.op is not None if isinstance(rv.owner.op, SpecifyShape): rv = rv.owner.inputs[0] rv_node = rv.owner @@ -894,18 +897,14 @@ def local_check_parameter_to_ninf_switch(fgraph, node): ) -def find_rng_nodes(variables: Iterable[TensorVariable]): +def find_rng_nodes( + variables: Iterable[Variable], +) -> List[Union[RandomStateSharedVariable, RandomGeneratorSharedVariable]]: """Return RNG variables in a graph""" return [ node for node in graph_inputs(variables) - if isinstance( - node, - ( - at.random.var.RandomStateSharedVariable, - at.random.var.RandomGeneratorSharedVariable, - ), - ) + if isinstance(node, (RandomStateSharedVariable, RandomGeneratorSharedVariable)) ] @@ -921,6 +920,7 @@ def reseed_rngs( np.random.PCG64(sub_seed) for sub_seed in np.random.SeedSequence(seed).spawn(len(rngs)) ] for rng, bit_generator in zip(rngs, bit_generators): + new_rng: Union[np.random.RandomState, np.random.Generator] if isinstance(rng, at.random.var.RandomStateSharedVariable): new_rng = np.random.RandomState(bit_generator) else: @@ -980,6 +980,9 @@ def compile_pymc( and isinstance(var.owner.op, (RandomVariable, MeasurableVariable)) and var not in inputs ): + # All nodes in `vars_between(inputs, outputs)` have owners. + # But mypy doesn't know, so we just assert it: + assert random_var.owner.op is not None if isinstance(random_var.owner.op, RandomVariable): rng = random_var.owner.inputs[0] if not hasattr(rng, "default_update"): From bb569e2dd4dd9cf3acfb46813eca2c7c4b041d0d Mon Sep 17 00:00:00 2001 From: Michael Osthege Date: Sun, 29 May 2022 19:08:09 +0200 Subject: [PATCH 6/6] Fix typo --- pymc/aesaraf.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pymc/aesaraf.py b/pymc/aesaraf.py index 967267af8a..a3d19659c6 100644 --- a/pymc/aesaraf.py +++ b/pymc/aesaraf.py @@ -167,7 +167,7 @@ def change_rv_size( new_size = (new_size,) # Extract the RV node that is to be resized, together with its inputs, name and tag - assert rw.owner.op is not None + assert rv.owner.op is not None if isinstance(rv.owner.op, SpecifyShape): rv = rv.owner.inputs[0] rv_node = rv.owner