Skip to content

Commit 9ad29b1

Browse files
committed
Don't implement default _supp_shape_from_params.
The errors raised by the default when it fails are rather cryptic. Also fix a bug in the helper function.
1 parent bde1bbc commit 9ad29b1

File tree

5 files changed

+139
-100
lines changed

5 files changed

+139
-100
lines changed

pytensor/tensor/random/basic.py

Lines changed: 28 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -5,10 +5,13 @@
55
import scipy.stats as stats
66

77
import pytensor
8-
from pytensor.tensor.basic import as_tensor_variable
9-
from pytensor.tensor.random.op import RandomVariable, default_supp_shape_from_params
8+
from pytensor.tensor.basic import as_tensor_variable, arange
9+
from pytensor.tensor.random.op import RandomVariable
1010
from pytensor.tensor.random.type import RandomGeneratorType, RandomStateType
11-
from pytensor.tensor.random.utils import broadcast_params
11+
from pytensor.tensor.random.utils import (
12+
broadcast_params,
13+
supp_shape_from_ref_param_shape,
14+
)
1215
from pytensor.tensor.random.var import (
1316
RandomGeneratorSharedVariable,
1417
RandomStateSharedVariable,
@@ -855,6 +858,14 @@ class MvNormalRV(RandomVariable):
855858
dtype = "floatX"
856859
_print_name = ("MultivariateNormal", "\\operatorname{MultivariateNormal}")
857860

861+
def _supp_shape_from_params(self, dist_params, param_shapes=None):
    """Return the support shape implied by the first distribution parameter.

    Delegates to `supp_shape_from_ref_param_shape`, using the parameter at
    index 0 as the reference. When `param_shapes` is given it is forwarded
    so the helper can use it instead of the parameters' own shapes.
    """
    supp_shape = supp_shape_from_ref_param_shape(
        ref_param_idx=0,
        param_shapes=param_shapes,
        dist_params=dist_params,
        ndim_supp=self.ndim_supp,
    )
    return supp_shape
868+
858869
def __call__(self, mean=None, cov=None, size=None, **kwargs):
859870
r""" "Draw samples from a multivariate normal distribution.
860871
@@ -933,6 +944,14 @@ class DirichletRV(RandomVariable):
933944
dtype = "floatX"
934945
_print_name = ("Dirichlet", "\\operatorname{Dirichlet}")
935946

947+
def _supp_shape_from_params(self, dist_params, param_shapes=None):
    """Return the support shape implied by the first distribution parameter.

    Delegates to `supp_shape_from_ref_param_shape`, using the parameter at
    index 0 as the reference. When `param_shapes` is given it is forwarded
    so the helper can use it instead of the parameters' own shapes.
    """
    supp_shape = supp_shape_from_ref_param_shape(
        ref_param_idx=0,
        param_shapes=param_shapes,
        dist_params=dist_params,
        ndim_supp=self.ndim_supp,
    )
    return supp_shape
954+
936955
def __call__(self, alphas, size=None, **kwargs):
937956
r"""Draw samples from a dirichlet distribution.
938957
@@ -1776,9 +1795,12 @@ def __call__(self, n, p, size=None, **kwargs):
17761795
"""
17771796
return super().__call__(n, p, size=size, **kwargs)
17781797

1779-
def _supp_shape_from_params(self, dist_params, rep_param_idx=1, param_shapes=None):
1780-
return default_supp_shape_from_params(
1781-
self.ndim_supp, dist_params, rep_param_idx, param_shapes
1798+
def _supp_shape_from_params(self, dist_params, param_shapes=None):
    """Return the support shape implied by the second distribution parameter.

    Delegates to `supp_shape_from_ref_param_shape`, using the parameter at
    index 1 as the reference. When `param_shapes` is given it is forwarded
    so the helper can use it instead of the parameters' own shapes.
    """
    supp_shape = supp_shape_from_ref_param_shape(
        ref_param_idx=1,
        param_shapes=param_shapes,
        dist_params=dist_params,
        ndim_supp=self.ndim_supp,
    )
    return supp_shape
17831805

17841806
@classmethod

pytensor/tensor/random/op.py

Lines changed: 18 additions & 62 deletions
Original file line numberDiff line numberDiff line change
@@ -24,64 +24,6 @@
2424
from pytensor.tensor.var import TensorVariable
2525

2626

27-
def default_supp_shape_from_params(
28-
ndim_supp: int,
29-
dist_params: Sequence[Variable],
30-
rep_param_idx: int = 0,
31-
param_shapes: Optional[Sequence[Tuple[ScalarVariable, ...]]] = None,
32-
) -> Union[TensorVariable, Tuple[ScalarVariable, ...]]:
33-
"""Infer the dimensions for the output of a `RandomVariable`.
34-
35-
This is a function that derives a random variable's support
36-
shape/dimensions from one of its parameters.
37-
38-
XXX: It's not always possible to determine a random variable's support
39-
shape from its parameters, so this function has fundamentally limited
40-
applicability and must be replaced by custom logic in such cases.
41-
42-
XXX: This function is not expected to handle `ndim_supp = 0` (i.e.
43-
scalars), since that is already definitively handled in the `Op` that
44-
calls this.
45-
46-
TODO: Consider using `pytensor.compile.ops.shape_i` alongside `ShapeFeature`.
47-
48-
Parameters
49-
----------
50-
ndim_supp: int
51-
Total number of dimensions for a single draw of the random variable
52-
(e.g. a multivariate normal draw is 1D, so `ndim_supp = 1`).
53-
dist_params: list of `pytensor.graph.basic.Variable`
54-
The distribution parameters.
55-
rep_param_idx: int (optional)
56-
The index of the distribution parameter to use as a reference
57-
In other words, a parameter in `dist_param` with a shape corresponding
58-
to the support's shape.
59-
The default is the first parameter (i.e. the value 0).
60-
param_shapes: list of tuple of `ScalarVariable` (optional)
61-
Symbolic shapes for each distribution parameter. These will
62-
be used in place of distribution parameter-generated shapes.
63-
64-
Results
65-
-------
66-
out: a tuple representing the support shape for a distribution with the
67-
given `dist_params`.
68-
69-
"""
70-
if ndim_supp <= 0:
71-
raise ValueError("ndim_supp must be greater than 0")
72-
if param_shapes is not None:
73-
ref_param = param_shapes[rep_param_idx]
74-
return (ref_param[-ndim_supp],)
75-
else:
76-
ref_param = dist_params[rep_param_idx]
77-
if ref_param.ndim < ndim_supp:
78-
raise ValueError(
79-
"Reference parameter does not match the "
80-
f"expected dimensions; {ref_param} has less than {ndim_supp} dim(s)."
81-
)
82-
return ref_param.shape[-ndim_supp:]
83-
84-
8527
class RandomVariable(Op):
8628
"""An `Op` that produces a sample from a random variable.
8729
@@ -151,15 +93,29 @@ def __init__(
15193
if self.inplace:
15294
self.destroy_map = {0: [0]}
15395

154-
def _supp_shape_from_params(self, dist_params, **kwargs):
155-
"""Determine the support shape of a `RandomVariable`'s output given its parameters.
96+
def _supp_shape_from_params(self, dist_params, param_shapes=None):
97+
"""Determine the support shape of a multivariate `RandomVariable`'s output given its parameters.
15698
15799
This does *not* consider the extra dimensions added by the `size` parameter
158100
or independent (batched) parameters.
159101
160-
Defaults to `param_supp_shape_fn`.
102+
When provided, `param_shapes` should be given preference over `[d.shape for d in dist_params]`,
103+
as it will avoid redundancies in PyTensor shape inference.
104+
105+
Examples
106+
--------
107+
Common multivariate `RandomVariable`s derive their support shapes implicitly from the
108+
last dimension of some of their parameters. For example `multivariate_normal` support shape
109+
corresponds to the last dimension of the mean or covariance parameters, `support_shape=(mu.shape[-1],)`.
110+
For this case the helper `pytensor.tensor.random.utils.supp_shape_from_ref_param_shape` can be used.
111+
112+
Other variables have fixed support shape such as `support_shape=(2,)` or it is determined by the
113+
values (not shapes) of some parameters. For instance, a `gaussian_random_walk(steps, size=(2,))`,
114+
might have `support_shape=(steps,)`.
161115
"""
162-
return default_supp_shape_from_params(self.ndim_supp, dist_params, **kwargs)
116+
raise NotImplementedError(
117+
"`_supp_shape_from_params` must be implemented for multivariate RVs"
118+
)
163119

164120
def rng_fn(self, rng, *args, **kwargs) -> Union[int, float, np.ndarray]:
165121
"""Sample a numeric random variate."""

pytensor/tensor/random/utils.py

Lines changed: 49 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,13 @@
1-
from collections.abc import Sequence
21
from functools import wraps
32
from itertools import zip_longest
43
from types import ModuleType
5-
from typing import TYPE_CHECKING, Literal, Optional, Union
4+
from typing import TYPE_CHECKING, Literal, Optional, Sequence, Tuple, Union
65

76
import numpy as np
87

98
from pytensor.compile.sharedvalue import shared
109
from pytensor.graph.basic import Constant, Variable
10+
from pytensor.scalar import ScalarVariable
1111
from pytensor.tensor import get_vector_length
1212
from pytensor.tensor.basic import as_tensor_variable, cast, constant
1313
from pytensor.tensor.extra_ops import broadcast_to
@@ -285,3 +285,50 @@ def gen(self, op: "RandomVariable", *args, **kwargs) -> TensorVariable:
285285
rng.default_update = new_rng
286286

287287
return out
288+
289+
290+
def supp_shape_from_ref_param_shape(
    *,
    ndim_supp: int,
    dist_params: Sequence["Variable"],
    param_shapes: Optional[Sequence[Tuple["ScalarVariable", ...]]] = None,
    ref_param_idx: int,
) -> Union["TensorVariable", Tuple["ScalarVariable", ...]]:
    """Extract the support shape of a multivariate `RandomVariable` from the shape of a reference parameter.

    Several multivariate `RandomVariable`s have a support shape determined by the last dimensions of a parameter.
    For example `multivariate_normal(zeros(5, 3), eye(3))` has a support shape of `(3,)` that is determined by the
    last dimension of the mean or the covariance.

    Parameters
    ----------
    ndim_supp: int
        Support dimensionality of the `RandomVariable`.
        (e.g. a multivariate normal draw is 1D, so `ndim_supp = 1`).
    dist_params: list of `pytensor.graph.basic.Variable`
        The distribution parameters.
    param_shapes: list of tuple of `ScalarVariable` (optional)
        Symbolic shapes for each distribution parameter. These will
        be used in place of distribution parameter-generated shapes.
    ref_param_idx: int
        The index of the distribution parameter to use as a reference.

    Returns
    -------
    out: tuple
        Representing the support shape for a `RandomVariable` with the given `dist_params`.
        It holds the last `ndim_supp` entries of the reference parameter's shape.

    Raises
    ------
    ValueError
        If `ndim_supp` is not positive, or if `param_shapes` is not given and the
        reference parameter has fewer than `ndim_supp` dimensions.
    IndexError
        If `param_shapes` is given and the reference shape has fewer than
        `ndim_supp` entries.

    """
    if ndim_supp <= 0:
        raise ValueError("ndim_supp must be greater than 0")

    if param_shapes is not None:
        ref_shape = param_shapes[ref_param_idx]
        # Return *all* trailing `ndim_supp` entries, not only the entry at
        # `-ndim_supp`, so that this branch agrees with the `dist_params`
        # branch below when `ndim_supp > 1`. Indexing entry by entry keeps
        # the `IndexError` for shapes that are too short.
        return tuple(ref_shape[i] for i in range(-ndim_supp, 0))

    ref_param = dist_params[ref_param_idx]
    if ref_param.ndim < ndim_supp:
        raise ValueError(
            "Reference parameter does not match the expected dimensions; "
            f"{ref_param} has less than {ndim_supp} dim(s)."
        )
    return ref_param.shape[-ndim_supp:]

tests/tensor/random/test_op.py

Lines changed: 1 addition & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -6,12 +6,7 @@
66
from pytensor.gradient import NullTypeGradError, grad
77
from pytensor.raise_op import Assert
88
from pytensor.tensor.math import eq
9-
from pytensor.tensor.random.op import (
10-
RandomState,
11-
RandomVariable,
12-
default_rng,
13-
default_supp_shape_from_params,
14-
)
9+
from pytensor.tensor.random.op import RandomState, RandomVariable, default_rng
1510
from pytensor.tensor.shape import specify_shape
1611
from pytensor.tensor.type import all_dtypes, iscalar, tensor
1712

@@ -22,29 +17,6 @@ def set_pytensor_flags():
2217
yield
2318

2419

25-
def test_default_supp_shape_from_params():
26-
with pytest.raises(ValueError, match="^ndim_supp*"):
27-
default_supp_shape_from_params(0, (np.array([1, 2]), 0))
28-
29-
res = default_supp_shape_from_params(
30-
1, (np.array([1, 2]), np.eye(2)), rep_param_idx=0
31-
)
32-
assert res == (2,)
33-
34-
res = default_supp_shape_from_params(
35-
1, (np.array([1, 2]), 0), param_shapes=((2,), ())
36-
)
37-
assert res == (2,)
38-
39-
with pytest.raises(ValueError, match="^Reference parameter*"):
40-
default_supp_shape_from_params(1, (np.array(1),), rep_param_idx=0)
41-
42-
res = default_supp_shape_from_params(
43-
2, (np.array([1, 2]), np.ones((2, 3, 4))), rep_param_idx=1
44-
)
45-
assert res == (3, 4)
46-
47-
4820
def test_RandomVariable_basics():
4921
str_res = str(
5022
RandomVariable(

tests/tensor/random/test_utils.py

Lines changed: 43 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,11 @@
44
from pytensor import config, function
55
from pytensor.compile.mode import Mode
66
from pytensor.graph.rewriting.db import RewriteDatabaseQuery
7-
from pytensor.tensor.random.utils import RandomStream, broadcast_params
7+
from pytensor.tensor.random.utils import (
8+
RandomStream,
9+
broadcast_params,
10+
supp_shape_from_ref_param_shape,
11+
)
812
from pytensor.tensor.type import matrix, tensor
913
from tests import unittest_tools as utt
1014

@@ -271,3 +275,41 @@ def __init__(self, seed=123):
271275
su2[0].set_value(su1[0].get_value())
272276

273277
np.testing.assert_array_almost_equal(f1(), f2(), decimal=6)
278+
279+
280+
def test_supp_shape_from_ref_param_shape():
    """Exercise `supp_shape_from_ref_param_shape` on valid and invalid inputs."""
    vector = np.array([1, 2])

    # A non-positive support dimensionality is rejected.
    with pytest.raises(ValueError, match="^ndim_supp*"):
        supp_shape_from_ref_param_shape(
            ref_param_idx=0,
            dist_params=(vector, 0),
            ndim_supp=0,
        )

    # Support shape read from the reference parameter itself.
    from_param = supp_shape_from_ref_param_shape(
        ref_param_idx=0,
        dist_params=(vector, np.eye(2)),
        ndim_supp=1,
    )
    assert from_param == (2,)

    # Support shape read from the explicitly supplied shapes.
    from_shapes = supp_shape_from_ref_param_shape(
        ref_param_idx=0,
        param_shapes=((2,), ()),
        dist_params=(vector, 0),
        ndim_supp=1,
    )
    assert from_shapes == (2,)

    # A reference parameter with too few dimensions is rejected.
    with pytest.raises(ValueError, match="^Reference parameter*"):
        supp_shape_from_ref_param_shape(
            ref_param_idx=0,
            dist_params=(np.array(1),),
            ndim_supp=1,
        )

    # Two support dimensions taken from a non-default reference parameter.
    two_dim = supp_shape_from_ref_param_shape(
        ref_param_idx=1,
        dist_params=(vector, np.ones((2, 3, 4))),
        ndim_supp=2,
    )
    assert two_dim == (3, 4)

0 commit comments

Comments
 (0)