Skip to content

Fix shape_utils missing functions and misformatted returns #7006

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/source/api/shape_utils.rst
Original file line number Diff line number Diff line change
Expand Up @@ -16,3 +16,4 @@ This module introduces functions that are made aware of the requested `size_tupl
to_tuple
rv_size_is_none
change_dist_size
broadcast_dist_samples_shape
80 changes: 54 additions & 26 deletions pymc/distributions/shape_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
import warnings

from functools import singledispatch
from typing import Any, Optional, Sequence, Tuple, Union, cast
from typing import Any, Iterable, Optional, Sequence, Tuple, Union, cast

import numpy as np

Expand Down Expand Up @@ -49,18 +49,28 @@
from pymc.util import _add_future_warning_tag


def to_tuple(shape):
"""Convert ints, arrays, and Nones to tuples
def to_tuple(shape: Optional[Union[None, int, np.ndarray]]) -> Tuple:
"""Convert integers, arrays, and Nones to tuples.

Parameters
----------
shape: None, int or array-like
Represents the shape to convert to tuple.
shape : None, int, or array-like
Represents the shape to convert to a tuple.

Returns
-------
If `shape` is None, returns an empty tuple. If it's an int, (shape,) is
returned. If it is array-like, tuple(shape) is returned.
tuple
If `shape` is None, returns an empty tuple. If it's an int, (shape,) is
returned. If it is array-like, `tuple(shape)` is returned.

Examples
--------
>>> to_tuple(None)
()
>>> to_tuple(5)
(5,)
>>> to_tuple([1, 2, 3])
(1, 2, 3)
"""
if shape is None:
return tuple()
Expand All @@ -87,42 +97,47 @@ def _check_shape_type(shape):
return tuple(out)


def broadcast_dist_samples_shape(shapes, size=None):
"""Apply shape broadcasting to shape tuples but assuming that the shapes
correspond to draws from random variables, with the `size` tuple possibly
prepended to it. The `size` prepend is ignored to consider if the supplied
`shapes` can broadcast or not. It is prepended to the resulting broadcasted
`shapes`, if any of the shape tuples had the `size` prepend.
def broadcast_dist_samples_shape(
shapes: Iterable[Tuple[int, ...]], size: Optional[int] = None
) -> Tuple[int, ...]:
"""Apply shape broadcasting to shape tuples for random variables.

Parameters
----------
shapes: Iterable of tuples holding the distribution samples shapes
size: None, int or tuple (optional)
size of the sample set requested.
shapes : Iterable of tuples
Tuples holding the distribution samples shapes.
size : None, int, or tuple, optional
Size of the sample set requested.

Returns
-------
tuple of the resulting shape
tuple
The resulting broadcasted shape.

Examples
--------
.. code-block:: python

size = 100
shape0 = (size,)
shape1 = (size, 5)
shape2 = (size, 4, 5)
out = broadcast_dist_samples_shape([shape0, shape1, shape2],
size=size)
assert out == (size, 4, 5)

.. code-block:: python

size = 100
shape0 = (size,)
shape1 = (5,)
shape2 = (4, 5)
out = broadcast_dist_samples_shape([shape0, shape1, shape2],
size=size)
assert out == (size, 4, 5)

.. code-block:: python

size = 100
shape0 = (1,)
shape1 = (5,)
Expand Down Expand Up @@ -291,7 +306,18 @@ def find_size(


def rv_size_is_none(size: Variable) -> bool:
"""Check whether an rv size is None (ie., pt.Constant([]))"""
"""Check whether the size of a random variable is None.

Parameters
----------
size : Variable
The size variable to check.

Returns
-------
bool
True if the size is None (i.e., pt.Constant([])), False otherwise.
"""
return size.type.shape == (0,) # type: ignore [attr-defined]


Expand All @@ -311,19 +337,21 @@ def change_dist_size(

Parameters
----------
dist:
dist : TensorVariable
The old distribution to be resized.
new_size:
new_size : Union[int, Tuple[int, ...]]
The new size of the distribution.
expand: bool, optional
If True, `new_size` is prepended to the existing distribution `size`, so that
the final size is equal to (*new_size, *dist.size). Defaults to false.
expand : bool, optional
If True, `new_size` is prepended to the existing distribution `size`,
so that the final size is equal to (*new_size, *dist.size).
Defaults to False.

Returns
-------
A new distribution variable that is equivalent to the original distribution with
the new size. The new distribution will not reuse the old RandomState/Generator
input, so it will be independent from the original distribution.
TensorVariable
A new distribution variable equivalent to the original distribution
with the new size. The new distribution will not reuse the old
RandomState/Generator input, making it independent from the original.

Examples
--------
Expand Down