Skip to content

Commit 9978e9e

Browse files
committed
Add missing functions in shape_utils doc, improve return formatting, and include type hints in docstring #7004
1 parent d7415de commit 9978e9e

File tree

2 files changed

+53
-26
lines changed

2 files changed

+53
-26
lines changed

docs/source/api/shape_utils.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,3 +16,4 @@ This module introduces functions that are made aware of the requested `size_tupl
1616
to_tuple
1717
rv_size_is_none
1818
change_dist_size
19+
broadcast_dist_samples_shape

pymc/distributions/shape_utils.py

Lines changed: 52 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
import warnings
2121

2222
from functools import singledispatch
23-
from typing import Any, Optional, Sequence, Tuple, Union, cast
23+
from typing import Any, Iterable, Optional, Sequence, Tuple, Union, cast
2424

2525
import numpy as np
2626

@@ -49,18 +49,28 @@
4949
from pymc.util import _add_future_warning_tag
5050

5151

52-
def to_tuple(shape):
53-
"""Convert ints, arrays, and Nones to tuples
52+
def to_tuple(shape: Optional[Union[None, int, np.ndarray]]) -> Tuple:
53+
"""Convert integers, arrays, and Nones to tuples.
5454
5555
Parameters
5656
----------
57-
shape: None, int or array-like
58-
Represents the shape to convert to tuple.
57+
shape : None, int, or array-like
58+
Represents the shape to convert to a tuple.
5959
6060
Returns
6161
-------
62-
If `shape` is None, returns an empty tuple. If it's an int, (shape,) is
63-
returned. If it is array-like, tuple(shape) is returned.
62+
tuple
63+
If `shape` is None, returns an empty tuple. If it's an int, (shape,) is
64+
returned. If it is array-like, `tuple(shape)` is returned.
65+
66+
Examples
67+
--------
68+
>>> to_tuple(None)
69+
()
70+
>>> to_tuple(5)
71+
(5,)
72+
>>> to_tuple([1, 2, 3])
73+
(1, 2, 3)
6474
"""
6575
if shape is None:
6676
return tuple()
@@ -87,42 +97,45 @@ def _check_shape_type(shape):
8797
return tuple(out)
8898

8999

90-
def broadcast_dist_samples_shape(shapes, size=None):
91-
"""Apply shape broadcasting to shape tuples but assuming that the shapes
92-
correspond to draws from random variables, with the `size` tuple possibly
93-
prepended to it. The `size` prepend is ignored to consider if the supplied
94-
`shapes` can broadcast or not. It is prepended to the resulting broadcasted
95-
`shapes`, if any of the shape tuples had the `size` prepend.
100+
def broadcast_dist_samples_shape(shapes: Iterable[Tuple[int, ...]], size: Optional[int] = None) -> Tuple[int, ...]:
101+
"""Apply shape broadcasting to shape tuples for random variables.
96102
97103
Parameters
98104
----------
99-
shapes: Iterable of tuples holding the distribution samples shapes
100-
size: None, int or tuple (optional)
101-
size of the sample set requested.
105+
shapes : Iterable of tuples
106+
Tuples holding the distribution samples shapes.
107+
size : None, int, or tuple, optional
108+
Size of the sample set requested.
102109
103110
Returns
104111
-------
105-
tuple of the resulting shape
112+
tuple
113+
The resulting broadcasted shape.
106114
107115
Examples
108116
--------
109117
.. code-block:: python
118+
110119
size = 100
111120
shape0 = (size,)
112121
shape1 = (size, 5)
113122
shape2 = (size, 4, 5)
114123
out = broadcast_dist_samples_shape([shape0, shape1, shape2],
115124
size=size)
116125
assert out == (size, 4, 5)
126+
117127
.. code-block:: python
128+
118129
size = 100
119130
shape0 = (size,)
120131
shape1 = (5,)
121132
shape2 = (4, 5)
122133
out = broadcast_dist_samples_shape([shape0, shape1, shape2],
123134
size=size)
124135
assert out == (size, 4, 5)
136+
125137
.. code-block:: python
138+
126139
size = 100
127140
shape0 = (1,)
128141
shape1 = (5,)
@@ -291,7 +304,18 @@ def find_size(
291304

292305

293306
def rv_size_is_none(size: Variable) -> bool:
294-
"""Check whether an rv size is None (ie., pt.Constant([]))"""
307+
"""Check whether the size of a random variable is None.
308+
309+
Parameters
310+
----------
311+
size : Variable
312+
The size variable to check.
313+
314+
Returns
315+
-------
316+
bool
317+
True if the size is None (i.e., pt.Constant([])), False otherwise.
318+
"""
295319
return size.type.shape == (0,) # type: ignore [attr-defined]
296320

297321

@@ -311,19 +335,21 @@ def change_dist_size(
311335
312336
Parameters
313337
----------
314-
dist:
338+
dist : TensorVariable
315339
The old distribution to be resized.
316-
new_size:
340+
new_size : Union[int, Tuple[int, ...]]
317341
The new size of the distribution.
318-
expand: bool, optional
319-
If True, `new_size` is prepended to the existing distribution `size`, so that
320-
the final size is equal to (*new_size, *dist.size). Defaults to false.
342+
expand : bool, optional
343+
If True, `new_size` is prepended to the existing distribution `size`,
344+
so that the final size is equal to (*new_size, *dist.size).
345+
Defaults to False.
321346
322347
Returns
323348
-------
324-
A new distribution variable that is equivalent to the original distribution with
325-
the new size. The new distribution will not reuse the old RandomState/Generator
326-
input, so it will be independent from the original distribution.
349+
TensorVariable
350+
A new distribution variable equivalent to the original distribution
351+
with the new size. The new distribution will not reuse the old
352+
RandomState/Generator input, making it independent from the original.
327353
328354
Examples
329355
--------

0 commit comments

Comments
 (0)