From bde1bbc74b53505b9d0f790a8d4e7ee25d184f44 Mon Sep 17 00:00:00 2001 From: Ricardo Vieira Date: Mon, 17 Jul 2023 09:33:40 +0200 Subject: [PATCH 1/5] Simplify RandomVariable._infer_shape --- pytensor/tensor/random/op.py | 45 +++++++++++++++++------------------- 1 file changed, 21 insertions(+), 24 deletions(-) diff --git a/pytensor/tensor/random/op.py b/pytensor/tensor/random/op.py index 56210538c0..f3fa777a7b 100644 --- a/pytensor/tensor/random/op.py +++ b/pytensor/tensor/random/op.py @@ -191,6 +191,8 @@ def _infer_shape( """ + from pytensor.tensor.extra_ops import broadcast_shape_iter + size_len = get_vector_length(size) if size_len > 0: @@ -216,57 +218,52 @@ def _infer_shape( # Broadcast the parameters param_shapes = params_broadcast_shapes( - param_shapes or [shape_tuple(p) for p in dist_params], self.ndims_params + param_shapes or [shape_tuple(p) for p in dist_params], + self.ndims_params, ) - def slice_ind_dims(p, ps, n): + def extract_batch_shape(p, ps, n): shape = tuple(ps) if n == 0: - return (p, shape) + return shape - ind_slice = (slice(None),) * (p.ndim - n) + (0,) * n - ind_shape = [ + batch_shape = [ s if b is False else constant(1, "int64") - for s, b in zip(shape[:-n], p.broadcastable[:-n]) + for s, b in zip(shape[:-n], p.type.broadcastable[:-n]) ] - return ( - p[ind_slice], - ind_shape, - ) + return batch_shape # These are versions of our actual parameters with the anticipated # dimensions (i.e. support dimensions) removed so that only the # independent variate dimensions are left. 
- params_ind_slice = tuple( - slice_ind_dims(p, ps, n) + params_batch_shape = tuple( + extract_batch_shape(p, ps, n) for p, ps, n in zip(dist_params, param_shapes, self.ndims_params) ) - if len(params_ind_slice) == 1: - _, shape_ind = params_ind_slice[0] - elif len(params_ind_slice) > 1: + if len(params_batch_shape) == 1: + [batch_shape] = params_batch_shape + elif len(params_batch_shape) > 1: # If there are multiple parameters, the dimensions of their # independent variates should broadcast together. - p_slices, p_shapes = zip(*params_ind_slice) - - shape_ind = pytensor.tensor.extra_ops.broadcast_shape_iter( - p_shapes, arrays_are_shapes=True + batch_shape = broadcast_shape_iter( + params_batch_shape, + arrays_are_shapes=True, ) - else: # Distribution has no parameters - shape_ind = () + batch_shape = () if self.ndim_supp == 0: - shape_supp = () + supp_shape = () else: - shape_supp = self._supp_shape_from_params( + supp_shape = self._supp_shape_from_params( dist_params, param_shapes=param_shapes, ) - shape = tuple(shape_ind) + tuple(shape_supp) + shape = tuple(batch_shape) + tuple(supp_shape) if not shape: shape = constant([], dtype="int64") From 9ad29b1fd1baf1d0c4f89e85ff63c89dad4a8ec4 Mon Sep 17 00:00:00 2001 From: Ricardo Vieira Date: Mon, 17 Jul 2023 10:19:37 +0200 Subject: [PATCH 2/5] Don't implement default _supp_shape_from_params. 
The errors raised by the default when it fails are rather cryptic Also fix bug in helper function --- pytensor/tensor/random/basic.py | 34 ++++++++++--- pytensor/tensor/random/op.py | 80 +++++++------------------------ pytensor/tensor/random/utils.py | 51 +++++++++++++++++++- tests/tensor/random/test_op.py | 30 +----------- tests/tensor/random/test_utils.py | 44 ++++++++++++++++- 5 files changed, 139 insertions(+), 100 deletions(-) diff --git a/pytensor/tensor/random/basic.py b/pytensor/tensor/random/basic.py index cced395471..95fc01af7f 100644 --- a/pytensor/tensor/random/basic.py +++ b/pytensor/tensor/random/basic.py @@ -5,10 +5,13 @@ import scipy.stats as stats import pytensor -from pytensor.tensor.basic import as_tensor_variable -from pytensor.tensor.random.op import RandomVariable, default_supp_shape_from_params +from pytensor.tensor.basic import as_tensor_variable, arange +from pytensor.tensor.random.op import RandomVariable from pytensor.tensor.random.type import RandomGeneratorType, RandomStateType -from pytensor.tensor.random.utils import broadcast_params +from pytensor.tensor.random.utils import ( + broadcast_params, + supp_shape_from_ref_param_shape, +) from pytensor.tensor.random.var import ( RandomGeneratorSharedVariable, RandomStateSharedVariable, @@ -855,6 +858,14 @@ class MvNormalRV(RandomVariable): dtype = "floatX" _print_name = ("MultivariateNormal", "\\operatorname{MultivariateNormal}") + def _supp_shape_from_params(self, dist_params, param_shapes=None): + return supp_shape_from_ref_param_shape( + ndim_supp=self.ndim_supp, + dist_params=dist_params, + param_shapes=param_shapes, + ref_param_idx=0, + ) + def __call__(self, mean=None, cov=None, size=None, **kwargs): r""" "Draw samples from a multivariate normal distribution. 
@@ -933,6 +944,14 @@ class DirichletRV(RandomVariable): dtype = "floatX" _print_name = ("Dirichlet", "\\operatorname{Dirichlet}") + def _supp_shape_from_params(self, dist_params, param_shapes=None): + return supp_shape_from_ref_param_shape( + ndim_supp=self.ndim_supp, + dist_params=dist_params, + param_shapes=param_shapes, + ref_param_idx=0, + ) + def __call__(self, alphas, size=None, **kwargs): r"""Draw samples from a dirichlet distribution. @@ -1776,9 +1795,12 @@ def __call__(self, n, p, size=None, **kwargs): """ return super().__call__(n, p, size=size, **kwargs) - def _supp_shape_from_params(self, dist_params, rep_param_idx=1, param_shapes=None): - return default_supp_shape_from_params( - self.ndim_supp, dist_params, rep_param_idx, param_shapes + def _supp_shape_from_params(self, dist_params, param_shapes=None): + return supp_shape_from_ref_param_shape( + ndim_supp=self.ndim_supp, + dist_params=dist_params, + param_shapes=param_shapes, + ref_param_idx=1, ) @classmethod diff --git a/pytensor/tensor/random/op.py b/pytensor/tensor/random/op.py index f3fa777a7b..a8c47d5ee8 100644 --- a/pytensor/tensor/random/op.py +++ b/pytensor/tensor/random/op.py @@ -24,64 +24,6 @@ from pytensor.tensor.var import TensorVariable -def default_supp_shape_from_params( - ndim_supp: int, - dist_params: Sequence[Variable], - rep_param_idx: int = 0, - param_shapes: Optional[Sequence[Tuple[ScalarVariable, ...]]] = None, -) -> Union[TensorVariable, Tuple[ScalarVariable, ...]]: - """Infer the dimensions for the output of a `RandomVariable`. - - This is a function that derives a random variable's support - shape/dimensions from one of its parameters. - - XXX: It's not always possible to determine a random variable's support - shape from its parameters, so this function has fundamentally limited - applicability and must be replaced by custom logic in such cases. - - XXX: This function is not expected to handle `ndim_supp = 0` (i.e. 
- scalars), since that is already definitively handled in the `Op` that - calls this. - - TODO: Consider using `pytensor.compile.ops.shape_i` alongside `ShapeFeature`. - - Parameters - ---------- - ndim_supp: int - Total number of dimensions for a single draw of the random variable - (e.g. a multivariate normal draw is 1D, so `ndim_supp = 1`). - dist_params: list of `pytensor.graph.basic.Variable` - The distribution parameters. - rep_param_idx: int (optional) - The index of the distribution parameter to use as a reference - In other words, a parameter in `dist_param` with a shape corresponding - to the support's shape. - The default is the first parameter (i.e. the value 0). - param_shapes: list of tuple of `ScalarVariable` (optional) - Symbolic shapes for each distribution parameter. These will - be used in place of distribution parameter-generated shapes. - - Results - ------- - out: a tuple representing the support shape for a distribution with the - given `dist_params`. - - """ - if ndim_supp <= 0: - raise ValueError("ndim_supp must be greater than 0") - if param_shapes is not None: - ref_param = param_shapes[rep_param_idx] - return (ref_param[-ndim_supp],) - else: - ref_param = dist_params[rep_param_idx] - if ref_param.ndim < ndim_supp: - raise ValueError( - "Reference parameter does not match the " - f"expected dimensions; {ref_param} has less than {ndim_supp} dim(s)." - ) - return ref_param.shape[-ndim_supp:] - - class RandomVariable(Op): """An `Op` that produces a sample from a random variable. @@ -151,15 +93,29 @@ def __init__( if self.inplace: self.destroy_map = {0: [0]} - def _supp_shape_from_params(self, dist_params, **kwargs): - """Determine the support shape of a `RandomVariable`'s output given its parameters. + def _supp_shape_from_params(self, dist_params, param_shapes=None): + """Determine the support shape of a multivariate `RandomVariable`'s output given its parameters. 
This does *not* consider the extra dimensions added by the `size` parameter or independent (batched) parameters. - Defaults to `param_supp_shape_fn`. + When provided, `param_shapes` should be given preference over `[d.shape for d in dist_params]`, + as it will avoid redundancies in PyTensor shape inference. + + Examples + -------- + Common multivariate `RandomVariable`s derive their support shapes implicitly from the + last dimension of some of their parameters. For example `multivariate_normal` support shape + corresponds to the last dimension of the mean or covariance parameters, `support_shape=(mu.shape[-1],)`. + For this case the helper `pytensor.tensor.random.utils.supp_shape_from_ref_param_shape` can be used. + + Other variables have fixed support shape such as `support_shape=(2,)` or it is determined by the + values (not shapes) of some parameters. For instance, a `gaussian_random_walk(steps, size=(2,))`, + might have `support_shape=(steps,)`. """ - return default_supp_shape_from_params(self.ndim_supp, dist_params, **kwargs) + raise NotImplementedError( + "`_supp_shape_from_params` must be implemented for multivariate RVs" + ) def rng_fn(self, rng, *args, **kwargs) -> Union[int, float, np.ndarray]: """Sample a numeric random variate.""" diff --git a/pytensor/tensor/random/utils.py b/pytensor/tensor/random/utils.py index 2c2485e173..a7bf2b699c 100644 --- a/pytensor/tensor/random/utils.py +++ b/pytensor/tensor/random/utils.py @@ -1,13 +1,13 @@ -from collections.abc import Sequence from functools import wraps from itertools import zip_longest from types import ModuleType -from typing import TYPE_CHECKING, Literal, Optional, Union +from typing import TYPE_CHECKING, Literal, Optional, Sequence, Tuple, Union import numpy as np from pytensor.compile.sharedvalue import shared from pytensor.graph.basic import Constant, Variable +from pytensor.scalar import ScalarVariable from pytensor.tensor import get_vector_length from pytensor.tensor.basic import
as_tensor_variable, cast, constant from pytensor.tensor.extra_ops import broadcast_to @@ -285,3 +285,50 @@ def gen(self, op: "RandomVariable", *args, **kwargs) -> TensorVariable: rng.default_update = new_rng return out + + +def supp_shape_from_ref_param_shape( + *, + ndim_supp: int, + dist_params: Sequence[Variable], + param_shapes: Optional[Sequence[Tuple[ScalarVariable, ...]]] = None, + ref_param_idx: int, +) -> Union[TensorVariable, Tuple[ScalarVariable, ...]]: + """Extract the support shape of a multivariate `RandomVariable` from the shape of a reference parameter. + + Several multivariate `RandomVariable`s have a support shape determined by the last dimensions of a parameter. + For example `multivariate_normal(zeros(5, 3), eye(3))` has a support shape of (3,) that is determined by the + last dimension of the mean or the covariance. + + Parameters + ---------- + ndim_supp: int + Support dimensionality of the `RandomVariable`. + (e.g. a multivariate normal draw is 1D, so `ndim_supp = 1`). + dist_params: list of `pytensor.graph.basic.Variable` + The distribution parameters. + param_shapes: list of tuple of `ScalarVariable` (optional) + Symbolic shapes for each distribution parameter. These will + be used in place of distribution parameter-generated shapes. + ref_param_idx: int + The index of the distribution parameter to use as a reference + + Returns + ------- + out: tuple + Representing the support shape for a `RandomVariable` with the given `dist_params`. + + """ + if ndim_supp <= 0: + raise ValueError("ndim_supp must be greater than 0") + if param_shapes is not None: + ref_param = param_shapes[ref_param_idx] + return (ref_param[-ndim_supp],) + else: + ref_param = dist_params[ref_param_idx] + if ref_param.ndim < ndim_supp: + raise ValueError( + "Reference parameter does not match the expected dimensions; " + f"{ref_param} has less than {ndim_supp} dim(s)."
+ ) + return ref_param.shape[-ndim_supp:] diff --git a/tests/tensor/random/test_op.py b/tests/tensor/random/test_op.py index a2d55908e5..0eec50e5a6 100644 --- a/tests/tensor/random/test_op.py +++ b/tests/tensor/random/test_op.py @@ -6,12 +6,7 @@ from pytensor.gradient import NullTypeGradError, grad from pytensor.raise_op import Assert from pytensor.tensor.math import eq -from pytensor.tensor.random.op import ( - RandomState, - RandomVariable, - default_rng, - default_supp_shape_from_params, -) +from pytensor.tensor.random.op import RandomState, RandomVariable, default_rng from pytensor.tensor.shape import specify_shape from pytensor.tensor.type import all_dtypes, iscalar, tensor @@ -22,29 +17,6 @@ def set_pytensor_flags(): yield -def test_default_supp_shape_from_params(): - with pytest.raises(ValueError, match="^ndim_supp*"): - default_supp_shape_from_params(0, (np.array([1, 2]), 0)) - - res = default_supp_shape_from_params( - 1, (np.array([1, 2]), np.eye(2)), rep_param_idx=0 - ) - assert res == (2,) - - res = default_supp_shape_from_params( - 1, (np.array([1, 2]), 0), param_shapes=((2,), ()) - ) - assert res == (2,) - - with pytest.raises(ValueError, match="^Reference parameter*"): - default_supp_shape_from_params(1, (np.array(1),), rep_param_idx=0) - - res = default_supp_shape_from_params( - 2, (np.array([1, 2]), np.ones((2, 3, 4))), rep_param_idx=1 - ) - assert res == (3, 4) - - def test_RandomVariable_basics(): str_res = str( RandomVariable( diff --git a/tests/tensor/random/test_utils.py b/tests/tensor/random/test_utils.py index 989a034ade..fd8f74c875 100644 --- a/tests/tensor/random/test_utils.py +++ b/tests/tensor/random/test_utils.py @@ -4,7 +4,11 @@ from pytensor import config, function from pytensor.compile.mode import Mode from pytensor.graph.rewriting.db import RewriteDatabaseQuery -from pytensor.tensor.random.utils import RandomStream, broadcast_params +from pytensor.tensor.random.utils import ( + RandomStream, + broadcast_params, + 
supp_shape_from_ref_param_shape, +) from pytensor.tensor.type import matrix, tensor from tests import unittest_tools as utt @@ -271,3 +275,41 @@ def __init__(self, seed=123): su2[0].set_value(su1[0].get_value()) np.testing.assert_array_almost_equal(f1(), f2(), decimal=6) + + +def test_supp_shape_from_ref_param_shape(): + with pytest.raises(ValueError, match="^ndim_supp*"): + supp_shape_from_ref_param_shape( + ndim_supp=0, + dist_params=(np.array([1, 2]), 0), + ref_param_idx=0, + ) + + res = supp_shape_from_ref_param_shape( + ndim_supp=1, + dist_params=(np.array([1, 2]), np.eye(2)), + ref_param_idx=0, + ) + assert res == (2,) + + res = supp_shape_from_ref_param_shape( + ndim_supp=1, + dist_params=(np.array([1, 2]), 0), + param_shapes=((2,), ()), + ref_param_idx=0, + ) + assert res == (2,) + + with pytest.raises(ValueError, match="^Reference parameter*"): + supp_shape_from_ref_param_shape( + ndim_supp=1, + dist_params=(np.array(1),), + ref_param_idx=0, + ) + + res = supp_shape_from_ref_param_shape( + ndim_supp=2, + dist_params=(np.array([1, 2]), np.ones((2, 3, 4))), + ref_param_idx=1, + ) + assert res == (3, 4) From 66d85f4290d0d162191e530f651fcac0e91568f0 Mon Sep 17 00:00:00 2001 From: Ricardo Vieira Date: Mon, 17 Jul 2023 10:34:37 +0200 Subject: [PATCH 3/5] Fix bug in supp_shape_from_ref_param_shape --- pytensor/tensor/random/utils.py | 2 +- tests/tensor/random/test_utils.py | 8 ++++++++ 2 files changed, 9 insertions(+), 1 deletion(-) diff --git a/pytensor/tensor/random/utils.py b/pytensor/tensor/random/utils.py index a7bf2b699c..3728195751 100644 --- a/pytensor/tensor/random/utils.py +++ b/pytensor/tensor/random/utils.py @@ -323,7 +323,7 @@ def supp_shape_from_ref_param_shape( raise ValueError("ndim_supp must be greater than 0") if param_shapes is not None: ref_param = param_shapes[ref_param_idx] - return (ref_param[-ndim_supp],) + return tuple(ref_param[i] for i in range(-ndim_supp, 0)) else: ref_param = dist_params[ref_param_idx] if ref_param.ndim < ndim_supp: 
diff --git a/tests/tensor/random/test_utils.py b/tests/tensor/random/test_utils.py index fd8f74c875..a503878490 100644 --- a/tests/tensor/random/test_utils.py +++ b/tests/tensor/random/test_utils.py @@ -313,3 +313,11 @@ def test_supp_shape_from_ref_param_shape(): ref_param_idx=1, ) assert res == (3, 4) + + res = supp_shape_from_ref_param_shape( + ndim_supp=2, + dist_params=(np.array([1, 2]), np.ones((2, 3, 4))), + param_shapes=((2,), (2, 3, 4)), + ref_param_idx=1, + ) + assert res == (3, 4) From d0b19e2fd9d9947a2b207ae1f2c901112028d76a Mon Sep 17 00:00:00 2001 From: Ricardo Vieira Date: Mon, 17 Jul 2023 10:22:36 +0200 Subject: [PATCH 4/5] Fix PermutationRV ambiguous signature The RV always expects a vector input and `ndims_params` is always `[1]`. Size is no longer ignored --- pytensor/tensor/random/basic.py | 47 +++++++++++++++++++++---------- tests/tensor/random/test_basic.py | 8 ++++++ 2 files changed, 40 insertions(+), 15 deletions(-) diff --git a/pytensor/tensor/random/basic.py b/pytensor/tensor/random/basic.py index 95fc01af7f..96c7913336 100644 --- a/pytensor/tensor/random/basic.py +++ b/pytensor/tensor/random/basic.py @@ -5,7 +5,7 @@ import scipy.stats as stats import pytensor -from pytensor.tensor.basic import as_tensor_variable, arange +from pytensor.tensor.basic import arange, as_tensor_variable from pytensor.tensor.random.op import RandomVariable from pytensor.tensor.random.type import RandomGeneratorType, RandomStateType from pytensor.tensor.random.utils import ( @@ -2072,18 +2072,15 @@ class PermutationRV(RandomVariable): @classmethod def rng_fn(cls, rng, x, size): - return rng.permutation(x if x.ndim > 0 else x.item()) + return rng.permutation(x) - def _infer_shape(self, size, dist_params, param_shapes=None): - param_shapes = param_shapes or [p.shape for p in dist_params] - - (x,) = dist_params - (x_shape,) = param_shapes - - if x.ndim == 0: - return (x,) - else: - return x_shape + def _supp_shape_from_params(self, dist_params, param_shapes=None): +
return supp_shape_from_ref_param_shape( + ndim_supp=self.ndim_supp, + dist_params=dist_params, + param_shapes=param_shapes, + ref_param_idx=0, + ) def __call__(self, x, **kwargs): r"""Randomly permute a sequence or a range of values. @@ -2096,15 +2093,35 @@ def __call__(self, x, **kwargs): Parameters ---------- x - If `x` is an integer, randomly permute `np.arange(x)`. If `x` is a sequence, - shuffle its elements randomly. + Elements to be shuffled. """ x = as_tensor_variable(x) return super().__call__(x, dtype=x.dtype, **kwargs) -permutation = PermutationRV() +_permutation = PermutationRV() + + +def permutation(x, **kwargs): + r"""Randomly permute a sequence or a range of values. + + Signature + --------- + + `(x) -> (x)` + + Parameters + ---------- + x + If `x` is an integer, randomly permute `np.arange(x)`. If `x` is a sequence, + shuffle its elements randomly. + + """ + x = as_tensor_variable(x) + if x.type.ndim == 0: + x = arange(x) + return _permutation(x, **kwargs) __all__ = [ diff --git a/tests/tensor/random/test_basic.py b/tests/tensor/random/test_basic.py index f68357a98a..4032b9a673 100644 --- a/tests/tensor/random/test_basic.py +++ b/tests/tensor/random/test_basic.py @@ -1413,6 +1413,14 @@ def test_permutation_samples(): compare_sample_values(permutation, np.array([1.0, 2.0, 3.0], dtype=config.floatX)) +def test_permutation_shape(): + assert tuple(permutation(5).shape.eval()) == (5,) + assert tuple(permutation(np.arange(5)).shape.eval()) == (5,) + assert tuple(permutation(np.arange(10).reshape(2, 5)).shape.eval()) == (2, 5) + assert tuple(permutation(5, size=(2, 3)).shape.eval()) == (2, 3, 5) + assert tuple(permutation(np.arange(5), size=(2, 3)).shape.eval()) == (2, 3, 5) + + @config.change_flags(compute_test_value="off") def test_pickle(): # This is an interesting `Op` case, because it has `None` types and a From c6895496e5f7b34aad4d5660cb5a719283662c12 Mon Sep 17 00:00:00 2001 From: Ricardo Vieira Date: Mon, 17 Jul 2023 14:34:46 +0200 Subject: [PATCH 
5/5] Seed test_generic_solve_to_solve_triangular --- tests/tensor/rewriting/test_linalg.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/tensor/rewriting/test_linalg.py b/tests/tensor/rewriting/test_linalg.py index c28388d11c..04cece3ff4 100644 --- a/tests/tensor/rewriting/test_linalg.py +++ b/tests/tensor/rewriting/test_linalg.py @@ -94,7 +94,8 @@ def test_generic_solve_to_solve_triangular(): b2 = solve(U, x) f = pytensor.function([A, x], b1) - X = np.random.normal(size=(10, 10)).astype(config.floatX) + rng = np.random.default_rng(97) + X = rng.normal(size=(10, 10)).astype(config.floatX) X = X @ X.T X_chol = np.linalg.cholesky(X) eye = np.eye(10, dtype=config.floatX)