From 773cea20687f5e25c2bc525fd03f01984db3c771 Mon Sep 17 00:00:00 2001 From: Ricardo Vieira Date: Tue, 23 Apr 2024 09:49:21 +0200 Subject: [PATCH 01/15] Rename file elemwise_codegen to vectorize_codegen --- pytensor/link/numba/dispatch/elemwise.py | 8 ++++---- .../{elemwise_codegen.py => vectorize_codegen.py} | 0 2 files changed, 4 insertions(+), 4 deletions(-) rename pytensor/link/numba/dispatch/{elemwise_codegen.py => vectorize_codegen.py} (100%) diff --git a/pytensor/link/numba/dispatch/elemwise.py b/pytensor/link/numba/dispatch/elemwise.py index 5710414116..aae00675e4 100644 --- a/pytensor/link/numba/dispatch/elemwise.py +++ b/pytensor/link/numba/dispatch/elemwise.py @@ -18,7 +18,7 @@ from pytensor.graph.basic import Apply from pytensor.graph.op import Op from pytensor.link.numba.dispatch import basic as numba_basic -from pytensor.link.numba.dispatch import elemwise_codegen +from pytensor.link.numba.dispatch import vectorize_codegen from pytensor.link.numba.dispatch.basic import ( create_numba_signature, create_tuple_creator, @@ -558,14 +558,14 @@ def codegen( ] in_shapes = [cgutils.unpack_tuple(builder, obj.shape) for obj in inputs] - iter_shape = elemwise_codegen.compute_itershape( + iter_shape = vectorize_codegen.compute_itershape( ctx, builder, in_shapes, input_bc_patterns_val, ) - outputs, output_types = elemwise_codegen.make_outputs( + outputs, output_types = vectorize_codegen.make_outputs( ctx, builder, iter_shape, @@ -576,7 +576,7 @@ def codegen( input_types, ) - elemwise_codegen.make_loop_call( + vectorize_codegen.make_loop_call( typingctx, ctx, builder, diff --git a/pytensor/link/numba/dispatch/elemwise_codegen.py b/pytensor/link/numba/dispatch/vectorize_codegen.py similarity index 100% rename from pytensor/link/numba/dispatch/elemwise_codegen.py rename to pytensor/link/numba/dispatch/vectorize_codegen.py From 77446546f0e36f5a071903111ae3753554a73bbc Mon Sep 17 00:00:00 2001 From: Ricardo Vieira Date: Tue, 23 Apr 2024 09:55:16 +0200 Subject: [PATCH 02/15] Move vectorize wrapper to vectorize_codegen --- pytensor/link/numba/dispatch/elemwise.py | 189 +----------------- .../link/numba/dispatch/vectorize_codegen.py | 165 ++++++++++++++- 2 files changed, 171 insertions(+), 183 deletions(-) diff --git a/pytensor/link/numba/dispatch/elemwise.py b/pytensor/link/numba/dispatch/elemwise.py index aae00675e4..52bc581df6 100644 --- a/pytensor/link/numba/dispatch/elemwise.py +++ b/pytensor/link/numba/dispatch/elemwise.py @@ -8,17 +8,13 @@ import numba import numpy as np -from numba import TypingError, types -from numba.core import cgutils from numba.core.extending import overload -from numba.np import arrayobj from numpy.core.numeric import normalize_axis_index, normalize_axis_tuple from pytensor import config from pytensor.graph.basic import Apply from pytensor.graph.op import Op from pytensor.link.numba.dispatch import basic as numba_basic -from pytensor.link.numba.dispatch import vectorize_codegen from pytensor.link.numba.dispatch.basic import ( create_numba_signature, create_tuple_creator, @@ -26,6 +22,7 @@ numba_njit, use_optimized_cheap_pass, ) +from pytensor.link.numba.dispatch.vectorize_codegen import _jit_options, _vectorized from pytensor.link.utils import compile_function_src, get_name_for_object from pytensor.scalar.basic import ( AND, @@ -463,167 +460,6 @@ def axis_apply_fn(x): return axis_apply_fn -_jit_options = { - "fastmath": { - "arcp", # Allow Reciprocal - "contract", # Allow floating-point contraction - "afn", # Approximate functions - "reassoc", - "nsz", # TODO Do we 
want this one? - }, - "no_cpython_wrapper": True, - "no_cfunc_wrapper": True, -} - - -@numba.extending.intrinsic(jit_options=_jit_options, prefer_literal=True) -def _vectorized( - typingctx, - scalar_func, - input_bc_patterns, - output_bc_patterns, - output_dtypes, - inplace_pattern, - inputs, -): - arg_types = [ - scalar_func, - input_bc_patterns, - output_bc_patterns, - output_dtypes, - inplace_pattern, - inputs, - ] - - if not isinstance(input_bc_patterns, types.Literal): - raise TypingError("input_bc_patterns must be literal.") - input_bc_patterns = input_bc_patterns.literal_value - input_bc_patterns = pickle.loads(base64.decodebytes(input_bc_patterns.encode())) - - if not isinstance(output_bc_patterns, types.Literal): - raise TypeError("output_bc_patterns must be literal.") - output_bc_patterns = output_bc_patterns.literal_value - output_bc_patterns = pickle.loads(base64.decodebytes(output_bc_patterns.encode())) - - if not isinstance(output_dtypes, types.Literal): - raise TypeError("output_dtypes must be literal.") - output_dtypes = output_dtypes.literal_value - output_dtypes = pickle.loads(base64.decodebytes(output_dtypes.encode())) - - if not isinstance(inplace_pattern, types.Literal): - raise TypeError("inplace_pattern must be literal.") - inplace_pattern = inplace_pattern.literal_value - inplace_pattern = pickle.loads(base64.decodebytes(inplace_pattern.encode())) - - n_outputs = len(output_bc_patterns) - - if not len(inputs) > 0: - raise TypingError("Empty argument list to elemwise op.") - - if not n_outputs > 0: - raise TypingError("Empty list of outputs for elemwise op.") - - if not all(isinstance(input, types.Array) for input in inputs): - raise TypingError("Inputs to elemwise must be arrays.") - ndim = inputs[0].ndim - - if not all(input.ndim == ndim for input in inputs): - raise TypingError("Inputs to elemwise must have the same rank.") - - if not all(len(pattern) == ndim for pattern in output_bc_patterns): - raise TypingError("Invalid output broadcasting pattern.") - - scalar_signature = typingctx.resolve_function_type( - scalar_func, [in_type.dtype for in_type in inputs], {} - ) - - # So we can access the constant values in codegen... 
- input_bc_patterns_val = input_bc_patterns - output_bc_patterns_val = output_bc_patterns - output_dtypes_val = output_dtypes - inplace_pattern_val = inplace_pattern - input_types = inputs - - def codegen( - ctx, - builder, - sig, - args, - ): - [_, _, _, _, _, inputs] = args - inputs = cgutils.unpack_tuple(builder, inputs) - inputs = [ - arrayobj.make_array(ty)(ctx, builder, val) - for ty, val in zip(input_types, inputs) - ] - in_shapes = [cgutils.unpack_tuple(builder, obj.shape) for obj in inputs] - - iter_shape = vectorize_codegen.compute_itershape( - ctx, - builder, - in_shapes, - input_bc_patterns_val, - ) - - outputs, output_types = vectorize_codegen.make_outputs( - ctx, - builder, - iter_shape, - output_bc_patterns_val, - output_dtypes_val, - inplace_pattern_val, - inputs, - input_types, - ) - - vectorize_codegen.make_loop_call( - typingctx, - ctx, - builder, - scalar_func, - scalar_signature, - iter_shape, - inputs, - outputs, - input_bc_patterns_val, - output_bc_patterns_val, - input_types, - output_types, - ) - - if len(outputs) == 1: - if inplace_pattern: - assert inplace_pattern[0][0] == 0 - ctx.nrt.incref(builder, sig.return_type, outputs[0]._getvalue()) - return outputs[0]._getvalue() - - for inplace_idx in dict(inplace_pattern): - ctx.nrt.incref( - builder, - sig.return_type.types[inplace_idx], - outputs[inplace_idx]._get_value(), - ) - return ctx.make_tuple( - builder, sig.return_type, [out._getvalue() for out in outputs] - ) - - ret_types = [ - types.Array(numba.from_dtype(np.dtype(dtype)), ndim, "C") - for dtype in output_dtypes - ] - - for output_idx, input_idx in inplace_pattern: - ret_types[output_idx] = input_types[input_idx] - - ret_type = types.Tuple(ret_types) - - if len(output_dtypes) == 1: - ret_type = ret_type.types[0] - sig = ret_type(*arg_types) - - return sig, codegen - - @numba_funcify.register(Elemwise) def numba_funcify_Elemwise(op, node, **kwargs): # Creating a new scalar node is more involved and unnecessary @@ -634,16 +470,12 @@ def numba_funcify_Elemwise(op, node, **kwargs): scalar_inputs = [scalar(dtype=input.dtype) for input in node.inputs] scalar_node = op.scalar_op.make_node(*scalar_inputs) - flags = { - "arcp", # Allow Reciprocal - "contract", # Allow floating-point contraction - "afn", # Approximate functions - "reassoc", - "nsz", # TODO Do we want this one? 
- } - scalar_op_fn = numba_funcify( - op.scalar_op, node=scalar_node, parent_node=node, fastmath=flags, **kwargs + op.scalar_op, + node=scalar_node, + parent_node=node, + fastmath=_jit_options["fastmath"], + **kwargs, ) ndim = node.outputs[0].ndim @@ -700,14 +532,7 @@ def elemwise(*inputs): return tuple(outputs_summed) return outputs_summed[0] - @overload( - elemwise, - jit_options={ - "fastmath": flags, - "no_cpython_wrapper": True, - "no_cfunc_wrapper": True, - }, - ) + @overload(elemwise, jit_options=_jit_options) def ov_elemwise(*inputs): return elemwise_wrapper diff --git a/pytensor/link/numba/dispatch/vectorize_codegen.py b/pytensor/link/numba/dispatch/vectorize_codegen.py index 431d1e8ce1..2272646052 100644 --- a/pytensor/link/numba/dispatch/vectorize_codegen.py +++ b/pytensor/link/numba/dispatch/vectorize_codegen.py @@ -1,16 +1,179 @@ from __future__ import annotations +import base64 +import pickle from typing import Any import numba import numpy as np from llvmlite import ir -from numba import types +from numba import TypingError, types from numba.core import cgutils from numba.core.base import BaseContext from numba.np import arrayobj +_jit_options = { + "fastmath": { + "arcp", # Allow Reciprocal + "contract", # Allow floating-point contraction + "afn", # Approximate functions + "reassoc", + "nsz", # TODO Do we want this one? + }, + "no_cpython_wrapper": True, + "no_cfunc_wrapper": True, +} + + +@numba.extending.intrinsic(jit_options=_jit_options, prefer_literal=True) +def _vectorized( + typingctx, + scalar_func, + input_bc_patterns, + output_bc_patterns, + output_dtypes, + inplace_pattern, + inputs, +): + arg_types = [ + scalar_func, + input_bc_patterns, + output_bc_patterns, + output_dtypes, + inplace_pattern, + inputs, + ] + + if not isinstance(input_bc_patterns, types.Literal): + raise TypingError("input_bc_patterns must be literal.") + input_bc_patterns = input_bc_patterns.literal_value + input_bc_patterns = pickle.loads(base64.decodebytes(input_bc_patterns.encode())) + + if not isinstance(output_bc_patterns, types.Literal): + raise TypeError("output_bc_patterns must be literal.") + output_bc_patterns = output_bc_patterns.literal_value + output_bc_patterns = pickle.loads(base64.decodebytes(output_bc_patterns.encode())) + + if not isinstance(output_dtypes, types.Literal): + raise TypeError("output_dtypes must be literal.") + output_dtypes = output_dtypes.literal_value + output_dtypes = pickle.loads(base64.decodebytes(output_dtypes.encode())) + + if not isinstance(inplace_pattern, types.Literal): + raise TypeError("inplace_pattern must be literal.") + inplace_pattern = inplace_pattern.literal_value + inplace_pattern = pickle.loads(base64.decodebytes(inplace_pattern.encode())) + + n_outputs = len(output_bc_patterns) + + if not len(inputs) > 0: + raise TypingError("Empty argument list to elemwise op.") + + if not n_outputs > 0: + raise TypingError("Empty list of outputs for elemwise op.") + + if not all(isinstance(input, types.Array) for input in inputs): + raise TypingError("Inputs to elemwise must be arrays.") + ndim = inputs[0].ndim + + if not all(input.ndim == ndim for input in inputs): + raise TypingError("Inputs to elemwise must have the same rank.") + + if not all(len(pattern) == ndim for pattern in output_bc_patterns): + raise TypingError("Invalid output broadcasting pattern.") + + scalar_signature = typingctx.resolve_function_type( + scalar_func, [in_type.dtype for in_type in inputs], {} + ) + + # So we can access the constant values in codegen... 
+ input_bc_patterns_val = input_bc_patterns + output_bc_patterns_val = output_bc_patterns + output_dtypes_val = output_dtypes + inplace_pattern_val = inplace_pattern + input_types = inputs + + def codegen( + ctx, + builder, + sig, + args, + ): + [_, _, _, _, _, inputs] = args + inputs = cgutils.unpack_tuple(builder, inputs) + inputs = [ + arrayobj.make_array(ty)(ctx, builder, val) + for ty, val in zip(input_types, inputs) + ] + in_shapes = [cgutils.unpack_tuple(builder, obj.shape) for obj in inputs] + + iter_shape = compute_itershape( + ctx, + builder, + in_shapes, + input_bc_patterns_val, + ) + + outputs, output_types = make_outputs( + ctx, + builder, + iter_shape, + output_bc_patterns_val, + output_dtypes_val, + inplace_pattern_val, + inputs, + input_types, + ) + + make_loop_call( + typingctx, + ctx, + builder, + scalar_func, + scalar_signature, + iter_shape, + inputs, + outputs, + input_bc_patterns_val, + output_bc_patterns_val, + input_types, + output_types, + ) + + if len(outputs) == 1: + if inplace_pattern: + assert inplace_pattern[0][0] == 0 + ctx.nrt.incref(builder, sig.return_type, outputs[0]._getvalue()) + return outputs[0]._getvalue() + + for inplace_idx in dict(inplace_pattern): + ctx.nrt.incref( + builder, + sig.return_type.types[inplace_idx], + outputs[inplace_idx]._get_value(), + ) + return ctx.make_tuple( + builder, sig.return_type, [out._getvalue() for out in outputs] + ) + + ret_types = [ + types.Array(numba.from_dtype(np.dtype(dtype)), ndim, "C") + for dtype in output_dtypes + ] + + for output_idx, input_idx in inplace_pattern: + ret_types[output_idx] = input_types[input_idx] + + ret_type = types.Tuple(ret_types) + + if len(output_dtypes) == 1: + ret_type = ret_type.types[0] + sig = ret_type(*arg_types) + + return sig, codegen + + def compute_itershape( ctx: BaseContext, builder: ir.IRBuilder, From e8dd61181a46dfc8fc9315fb875089616b0e6825 Mon Sep 17 00:00:00 2001 From: Ricardo Vieira Date: Tue, 23 Apr 2024 10:00:38 +0200 Subject: [PATCH 03/15] Refactor vectorize literal encoding helper --- pytensor/link/numba/dispatch/elemwise.py | 25 +++++++++---------- .../link/numba/dispatch/vectorize_codegen.py | 5 ++++ 2 files changed, 17 insertions(+), 13 deletions(-) diff --git a/pytensor/link/numba/dispatch/elemwise.py b/pytensor/link/numba/dispatch/elemwise.py index 52bc581df6..d73e1bf73d 100644 --- a/pytensor/link/numba/dispatch/elemwise.py +++ b/pytensor/link/numba/dispatch/elemwise.py @@ -1,5 +1,3 @@ -import base64 -import pickle from collections.abc import Callable from functools import singledispatch from numbers import Number @@ -22,7 +20,11 @@ numba_njit, use_optimized_cheap_pass, ) -from pytensor.link.numba.dispatch.vectorize_codegen import _jit_options, _vectorized +from pytensor.link.numba.dispatch.vectorize_codegen import ( + _jit_options, + _vectorized, + encode_literals, +) from pytensor.link.utils import compile_function_src, get_name_for_object from pytensor.scalar.basic import ( AND, @@ -478,19 +480,16 @@ def numba_funcify_Elemwise(op, node, **kwargs): **kwargs, ) - ndim = node.outputs[0].ndim - output_bc_patterns = tuple([(False,) * ndim for _ in node.outputs]) - input_bc_patterns = tuple([input_var.broadcastable for input_var in node.inputs]) - output_dtypes = tuple(variable.dtype for variable in node.outputs) + input_bc_patterns = tuple([inp.type.broadcastable for inp in node.inputs]) + output_bc_patterns = tuple([out.type.broadcastable for out in node.outputs]) + output_dtypes = tuple(out.type.dtype for out in node.outputs) inplace_pattern = 
tuple(op.inplace_pattern.items()) # numba doesn't support nested literals right now... - input_bc_patterns_enc = base64.encodebytes(pickle.dumps(input_bc_patterns)).decode() - output_bc_patterns_enc = base64.encodebytes( - pickle.dumps(output_bc_patterns) - ).decode() - output_dtypes_enc = base64.encodebytes(pickle.dumps(output_dtypes)).decode() - inplace_pattern_enc = base64.encodebytes(pickle.dumps(inplace_pattern)).decode() + input_bc_patterns_enc = encode_literals(input_bc_patterns) + output_bc_patterns_enc = encode_literals(output_bc_patterns) + output_dtypes_enc = encode_literals(output_dtypes) + inplace_pattern_enc = encode_literals(inplace_pattern) def elemwise_wrapper(*inputs): return _vectorized( diff --git a/pytensor/link/numba/dispatch/vectorize_codegen.py b/pytensor/link/numba/dispatch/vectorize_codegen.py index 2272646052..14c846c4e4 100644 --- a/pytensor/link/numba/dispatch/vectorize_codegen.py +++ b/pytensor/link/numba/dispatch/vectorize_codegen.py @@ -2,6 +2,7 @@ import base64 import pickle +from collections.abc import Sequence from typing import Any import numba @@ -13,6 +14,10 @@ from numba.np import arrayobj +def encode_literals(literals: Sequence) -> str: + return base64.encodebytes(pickle.dumps(literals)).decode() + + _jit_options = { "fastmath": { "arcp", # Allow Reciprocal From ba10b315fe12c574360d22ca43b2b88c06202189 Mon Sep 17 00:00:00 2001 From: Ricardo Vieira Date: Tue, 23 Apr 2024 18:25:15 +0200 Subject: [PATCH 04/15] Refactor helper to create safe gufunc signature --- pytensor/tensor/blockwise.py | 25 ++++++------------------- pytensor/tensor/utils.py | 23 ++++++++++++++++++++++- 2 files changed, 28 insertions(+), 20 deletions(-) diff --git a/pytensor/tensor/blockwise.py b/pytensor/tensor/blockwise.py index 9eac2d1d76..0511a4ce47 100644 --- a/pytensor/tensor/blockwise.py +++ b/pytensor/tensor/blockwise.py @@ -6,7 +6,7 @@ from pytensor import config from pytensor.gradient import DisconnectedType -from pytensor.graph.basic import Apply, Constant, Variable +from pytensor.graph.basic import Apply, Constant from pytensor.graph.null_type import NullType from pytensor.graph.op import Op from pytensor.graph.replace import ( @@ -22,27 +22,11 @@ _parse_gufunc_signature, broadcast_static_dim_lengths, import_func_from_string, + safe_signature, ) from pytensor.tensor.variable import TensorVariable -def safe_signature( - core_inputs: Sequence[Variable], - core_outputs: Sequence[Variable], -) -> str: - def operand_sig(operand: Variable, prefix: str) -> str: - operands = ",".join(f"{prefix}{i}" for i in range(operand.type.ndim)) - return f"({operands})" - - inputs_sig = ",".join( - operand_sig(i, prefix=f"i{n}") for n, i in enumerate(core_inputs) - ) - outputs_sig = ",".join( - operand_sig(o, prefix=f"o{n}") for n, o in enumerate(core_outputs) - ) - return f"{inputs_sig}->{outputs_sig}" - - class Blockwise(Op): """Generalizes a core `Op` to work with batched dimensions. @@ -385,7 +369,10 @@ def vectorize_node_fallback(op: Op, node: Apply, *bached_inputs) -> Apply: else: # TODO: This is pretty bad for shape inference and merge optimization! 
# Should get better as we add signatures to our Ops - signature = safe_signature(node.inputs, node.outputs) + signature = safe_signature( + [inp.type.ndim for inp in node.inputs], + [out.type.ndim for out in node.outputs], + ) return cast(Apply, Blockwise(op, signature=signature).make_node(*bached_inputs)) diff --git a/pytensor/tensor/utils.py b/pytensor/tensor/utils.py index 1f86ad5dbe..55aa7fc836 100644 --- a/pytensor/tensor/utils.py +++ b/pytensor/tensor/utils.py @@ -172,7 +172,11 @@ def broadcast_static_dim_lengths( _SIGNATURE = f"^{_ARGUMENT_LIST}->{_ARGUMENT_LIST}$" -def _parse_gufunc_signature(signature): +def _parse_gufunc_signature( + signature, +) -> tuple[ + list[tuple[str, ...]], ... +]: # mypy doesn't know it's alwayl a length two tuple """ Parse string signatures for a generalized universal function. @@ -198,3 +202,20 @@ def _parse_gufunc_signature(signature): ] for arg_list in signature.split("->") ) + + +def safe_signature( + core_inputs_ndim: Sequence[int], + core_outputs_ndim: Sequence[int], +) -> str: + def operand_sig(operand_ndim: int, prefix: str) -> str: + operands = ",".join(f"{prefix}{i}" for i in range(operand_ndim)) + return f"({operands})" + + inputs_sig = ",".join( + operand_sig(ndim, prefix=f"i{n}") for n, ndim in enumerate(core_inputs_ndim) + ) + outputs_sig = ",".join( + operand_sig(ndim, prefix=f"o{n}") for n, ndim in enumerate(core_outputs_ndim) + ) + return f"{inputs_sig}->{outputs_sig}" From 6224aad83c23694403d8c22cdf9ca0aaa469ff4a Mon Sep 17 00:00:00 2001 From: Ricardo Vieira Date: Tue, 23 Apr 2024 18:28:59 +0200 Subject: [PATCH 05/15] StandandardNormalRV is now just a helper function --- pytensor/link/jax/dispatch/random.py | 1 - pytensor/tensor/random/basic.py | 40 +++++++++------------------- pytensor/tensor/random/utils.py | 24 ++++++++--------- tests/tensor/random/test_utils.py | 2 +- 4 files changed, 25 insertions(+), 42 deletions(-) diff --git a/pytensor/link/jax/dispatch/random.py b/pytensor/link/jax/dispatch/random.py index 824d728faf..d07091d099 100644 --- a/pytensor/link/jax/dispatch/random.py +++ b/pytensor/link/jax/dispatch/random.py @@ -162,7 +162,6 @@ def sample_fn(rng, size, dtype, *parameters): @jax_sample_fn.register(ptr.LaplaceRV) @jax_sample_fn.register(ptr.LogisticRV) @jax_sample_fn.register(ptr.NormalRV) -@jax_sample_fn.register(ptr.StandardNormalRV) def jax_sample_fn_loc_scale(op, node): """JAX implementation of random variables in the loc-scale families. diff --git a/pytensor/tensor/random/basic.py b/pytensor/tensor/random/basic.py index 0f1dfa9e35..8bd20d809b 100644 --- a/pytensor/tensor/random/basic.py +++ b/pytensor/tensor/random/basic.py @@ -281,38 +281,24 @@ def __call__(self, loc=0.0, scale=1.0, size=None, **kwargs): normal = NormalRV() -class StandardNormalRV(NormalRV): - r"""A standard normal continuous random variable. +def standard_normal(*, size=None, rng=None, dtype=None): + """Draw samples from a standard normal distribution. - The probability density function for `standard_normal` is: + Signature + --------- - .. math:: + `nil -> ()` - f(x) = \frac{1}{\sqrt{2 \pi}} e^{-\frac{x^2}{2}} + Parameters + ---------- + size + Sample shape. If the given size is, e.g. `(m, n, k)` then `m * n * k` + independent, identically distributed random variables are + returned. Default is `None` in which case a single random variable + is returned. """ - - def __call__(self, size=None, **kwargs): - """Draw samples from a standard normal distribution. 
- - Signature - --------- - - `nil -> ()` - - Parameters - ---------- - size - Sample shape. If the given size is, e.g. `(m, n, k)` then `m * n * k` - independent, identically distributed random variables are - returned. Default is `None` in which case a single random variable - is returned. - - """ - return super().__call__(loc=0.0, scale=1.0, size=size, **kwargs) - - -standard_normal = StandardNormalRV() + return normal(0.0, 1.0, size=size, rng=rng, dtype=dtype) class HalfNormalRV(ScipyRandomVariable): diff --git a/pytensor/tensor/random/utils.py b/pytensor/tensor/random/utils.py index 62b0787a4e..700daf91fe 100644 --- a/pytensor/tensor/random/utils.py +++ b/pytensor/tensor/random/utils.py @@ -218,9 +218,9 @@ def __init__( if namespace is None: from pytensor.tensor.random import basic # pylint: disable=import-self - self.namespaces = [basic] + self.namespaces = [(basic, set(basic.__all__))] else: - self.namespaces = [namespace] + self.namespaces = [(namespace, set(namespace.__all__))] self.default_instance_seed = seed self.state_updates = [] @@ -235,22 +235,20 @@ def rng_ctor(seed): def __getattr__(self, obj): ns_obj = next( - (getattr(ns, obj) for ns in self.namespaces if hasattr(ns, obj)), None + ( + getattr(ns, obj) + for ns, all_ in self.namespaces + if obj in all_ and hasattr(ns, obj) + ), + None, ) if ns_obj is None: raise AttributeError(f"No attribute {obj}.") - from pytensor.tensor.random.op import RandomVariable - - if isinstance(ns_obj, RandomVariable): - - @wraps(ns_obj) - def meta_obj(*args, **kwargs): - return self.gen(ns_obj, *args, **kwargs) - - else: - raise AttributeError(f"No attribute {obj}.") + @wraps(ns_obj) + def meta_obj(*args, **kwargs): + return self.gen(ns_obj, *args, **kwargs) setattr(self, obj, meta_obj) return getattr(self, obj) diff --git a/tests/tensor/random/test_utils.py b/tests/tensor/random/test_utils.py index a503878490..28ee2b94e0 100644 --- a/tests/tensor/random/test_utils.py +++ b/tests/tensor/random/test_utils.py @@ -114,7 +114,7 @@ def test_basics(self, rng_ctor): assert hasattr(random, "standard_normal") with pytest.raises(AttributeError): - np_random = RandomStream(namespace=np, rng_ctor=rng_ctor) + np_random = RandomStream(namespace=np.random, rng_ctor=rng_ctor) np_random.ndarray fn = function([], random.uniform(0, 1, size=(2, 2)), updates=random.updates()) From 630e642ef88488936708c094b7284d3898bedabb Mon Sep 17 00:00:00 2001 From: Ricardo Vieira Date: Fri, 24 May 2024 12:19:57 +0200 Subject: [PATCH 06/15] Fix parameter ordering in test_uniform_samples --- tests/tensor/random/test_basic.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/tensor/random/test_basic.py b/tests/tensor/random/test_basic.py index 1eb2afe6e0..080f99e36d 100644 --- a/tests/tensor/random/test_basic.py +++ b/tests/tensor/random/test_basic.py @@ -141,7 +141,7 @@ def test_fn(*args, random_state=None, **kwargs): @pytest.mark.parametrize( - "u, l, size", + "l, u, size", [ (np.array(10, dtype=config.floatX), np.array(20, dtype=config.floatX), None), (np.array(10, dtype=config.floatX), np.array(20, dtype=config.floatX), []), @@ -152,8 +152,8 @@ def test_fn(*args, random_state=None, **kwargs): ), ], ) -def test_uniform_samples(u, l, size): - compare_sample_values(uniform, u, l, size=size) +def test_uniform_samples(l, u, size): + compare_sample_values(uniform, l, u, size=size) def test_uniform_default_args(): From adc6f624b9ac78a5d8da0711eba72d1c925cba92 Mon Sep 17 00:00:00 2001 From: Ricardo Vieira Date: Fri, 24 May 2024 12:17:29 +0200 Subject: 
[PATCH 07/15] Work-around for numpy bug in choice with size=() --- pytensor/tensor/random/basic.py | 4 ++++ tests/tensor/random/test_basic.py | 9 +++++++++ 2 files changed, 13 insertions(+) diff --git a/pytensor/tensor/random/basic.py b/pytensor/tensor/random/basic.py index 8bd20d809b..91fe8fdd8e 100644 --- a/pytensor/tensor/random/basic.py +++ b/pytensor/tensor/random/basic.py @@ -2084,6 +2084,10 @@ def rng_fn(self, *params): batch_ndim = max(batch_ndim, size_ndim) if batch_ndim == 0: + # Numpy choice fails with size=() if a.ndim > 1 is batched + # https://github.com/numpy/numpy/issues/26518 + if core_shape == (): + core_shape = None return rng.choice(a, p=p, size=core_shape, replace=False) # Numpy choice doesn't have a concept of batch dims diff --git a/tests/tensor/random/test_basic.py b/tests/tensor/random/test_basic.py index 080f99e36d..38ccc8d270 100644 --- a/tests/tensor/random/test_basic.py +++ b/tests/tensor/random/test_basic.py @@ -1422,6 +1422,15 @@ def test_choice_samples(): compare_sample_values(choice, pt.as_tensor_variable([1, 2, 3]), 2, replace=True) +def test_choice_scalar_size(): + np.testing.assert_array_equal( + choice([[1, 2, 3]], size=(), replace=True).eval(), [1, 2, 3] + ) + np.testing.assert_array_equal( + choice([[1, 2, 3]], size=(), replace=False).eval(), [1, 2, 3] + ) + + def test_permutation_samples(): compare_sample_values( permutation, From 93638d3eafd6446d25a1ba3a90ef1ed16c58a3a1 Mon Sep 17 00:00:00 2001 From: Ricardo Vieira Date: Tue, 23 Apr 2024 18:54:38 +0200 Subject: [PATCH 08/15] Introduce signature instead of ndim_supp and ndims_params --- pytensor/tensor/random/basic.py | 165 +++++++------------- pytensor/tensor/random/op.py | 119 ++++++++++---- pytensor/tensor/random/rewriting/jax.py | 8 +- pytensor/tensor/random/utils.py | 9 +- pytensor/tensor/utils.py | 5 +- tests/link/jax/test_random.py | 6 +- tests/tensor/random/rewriting/test_basic.py | 67 +++----- tests/tensor/random/test_basic.py | 26 +-- tests/tensor/random/test_op.py | 15 +- 9 files changed, 198 insertions(+), 222 deletions(-) diff --git a/pytensor/tensor/random/basic.py b/pytensor/tensor/random/basic.py index 91fe8fdd8e..3290f22510 100644 --- a/pytensor/tensor/random/basic.py +++ b/pytensor/tensor/random/basic.py @@ -13,7 +13,6 @@ from pytensor.tensor.random.utils import ( broadcast_params, normalize_size_param, - supp_shape_from_ref_param_shape, ) from pytensor.tensor.random.var import ( RandomGeneratorSharedVariable, @@ -91,8 +90,7 @@ class UniformRV(RandomVariable): """ name = "uniform" - ndim_supp = 0 - ndims_params = [0, 0] + signature = "(),()->()" dtype = "floatX" _print_name = ("Uniform", "\\operatorname{Uniform}") @@ -146,8 +144,7 @@ class TriangularRV(RandomVariable): """ name = "triangular" - ndim_supp = 0 - ndims_params = [0, 0, 0] + signature = "(),(),()->()" dtype = "floatX" _print_name = ("Triangular", "\\operatorname{Triangular}") @@ -202,8 +199,7 @@ class BetaRV(RandomVariable): """ name = "beta" - ndim_supp = 0 - ndims_params = [0, 0] + signature = "(),()->()" dtype = "floatX" _print_name = ("Beta", "\\operatorname{Beta}") @@ -249,8 +245,7 @@ class NormalRV(RandomVariable): """ name = "normal" - ndim_supp = 0 - ndims_params = [0, 0] + signature = "(),()->()" dtype = "floatX" _print_name = ("Normal", "\\operatorname{Normal}") @@ -316,8 +311,7 @@ class HalfNormalRV(ScipyRandomVariable): """ name = "halfnormal" - ndim_supp = 0 - ndims_params = [0, 0] + signature = "(),()->()" dtype = "floatX" _print_name = ("HalfNormal", "\\operatorname{HalfNormal}") @@ -382,8 +376,7 @@ class 
LogNormalRV(RandomVariable): """ name = "lognormal" - ndim_supp = 0 - ndims_params = [0, 0] + signature = "(),()->()" dtype = "floatX" _print_name = ("LogNormal", "\\operatorname{LogNormal}") @@ -434,8 +427,7 @@ class GammaRV(RandomVariable): """ name = "gamma" - ndim_supp = 0 - ndims_params = [0, 0] + signature = "(),()->()" dtype = "floatX" _print_name = ("Gamma", "\\operatorname{Gamma}") @@ -567,8 +559,7 @@ class ParetoRV(ScipyRandomVariable): """ name = "pareto" - ndim_supp = 0 - ndims_params = [0, 0] + signature = "(),()->()" dtype = "floatX" _print_name = ("Pareto", "\\operatorname{Pareto}") @@ -618,8 +609,7 @@ class GumbelRV(ScipyRandomVariable): """ name = "gumbel" - ndim_supp = 0 - ndims_params = [0, 0] + signature = "(),()->()" dtype = "floatX" _print_name = ("Gumbel", "\\operatorname{Gumbel}") @@ -680,8 +670,7 @@ class ExponentialRV(RandomVariable): """ name = "exponential" - ndim_supp = 0 - ndims_params = [0] + signature = "()->()" dtype = "floatX" _print_name = ("Exponential", "\\operatorname{Exponential}") @@ -724,8 +713,7 @@ class WeibullRV(RandomVariable): """ name = "weibull" - ndim_supp = 0 - ndims_params = [0] + signature = "()->()" dtype = "floatX" _print_name = ("Weibull", "\\operatorname{Weibull}") @@ -769,8 +757,7 @@ class LogisticRV(RandomVariable): """ name = "logistic" - ndim_supp = 0 - ndims_params = [0, 0] + signature = "(),()->()" dtype = "floatX" _print_name = ("Logistic", "\\operatorname{Logistic}") @@ -818,8 +805,7 @@ class VonMisesRV(RandomVariable): """ name = "vonmises" - ndim_supp = 0 - ndims_params = [0, 0] + signature = "(),()->()" dtype = "floatX" _print_name = ("VonMises", "\\operatorname{VonMises}") @@ -886,19 +872,10 @@ class MvNormalRV(RandomVariable): """ name = "multivariate_normal" - ndim_supp = 1 - ndims_params = [1, 2] + signature = "(n),(n,n)->(n)" dtype = "floatX" _print_name = ("MultivariateNormal", "\\operatorname{MultivariateNormal}") - def _supp_shape_from_params(self, dist_params, param_shapes=None): - return supp_shape_from_ref_param_shape( - ndim_supp=self.ndim_supp, - dist_params=dist_params, - param_shapes=param_shapes, - ref_param_idx=0, - ) - def __call__(self, mean=None, cov=None, size=None, **kwargs): r""" "Draw samples from a multivariate normal distribution. @@ -942,7 +919,7 @@ def rng_fn(cls, rng, mean, cov, size): mean = np.broadcast_to(mean, size + mean.shape[-1:]) cov = np.broadcast_to(cov, size + cov.shape[-2:]) else: - mean, cov = broadcast_params([mean, cov], cls.ndims_params) + mean, cov = broadcast_params([mean, cov], [1, 2]) res = np.empty(mean.shape) for idx in np.ndindex(mean.shape[:-1]): @@ -973,19 +950,10 @@ class DirichletRV(RandomVariable): """ name = "dirichlet" - ndim_supp = 1 - ndims_params = [1] + signature = "(a)->(a)" dtype = "floatX" _print_name = ("Dirichlet", "\\operatorname{Dirichlet}") - def _supp_shape_from_params(self, dist_params, param_shapes=None): - return supp_shape_from_ref_param_shape( - ndim_supp=self.ndim_supp, - dist_params=dist_params, - param_shapes=param_shapes, - ref_param_idx=0, - ) - def __call__(self, alphas, size=None, **kwargs): r"""Draw samples from a dirichlet distribution. 
@@ -1047,8 +1015,7 @@ class PoissonRV(RandomVariable): """ name = "poisson" - ndim_supp = 0 - ndims_params = [0] + signature = "()->()" dtype = "int64" _print_name = ("Poisson", "\\operatorname{Poisson}") @@ -1093,8 +1060,7 @@ class GeometricRV(RandomVariable): """ name = "geometric" - ndim_supp = 0 - ndims_params = [0] + signature = "()->()" dtype = "int64" _print_name = ("Geometric", "\\operatorname{Geometric}") @@ -1136,8 +1102,7 @@ class HyperGeometricRV(RandomVariable): """ name = "hypergeometric" - ndim_supp = 0 - ndims_params = [0, 0, 0] + signature = "(),(),()->()" dtype = "int64" _print_name = ("HyperGeometric", "\\operatorname{HyperGeometric}") @@ -1185,8 +1150,7 @@ class CauchyRV(ScipyRandomVariable): """ name = "cauchy" - ndim_supp = 0 - ndims_params = [0, 0] + signature = "(),()->()" dtype = "floatX" _print_name = ("Cauchy", "\\operatorname{Cauchy}") @@ -1236,8 +1200,7 @@ class HalfCauchyRV(ScipyRandomVariable): """ name = "halfcauchy" - ndim_supp = 0 - ndims_params = [0, 0] + signature = "(),()->()" dtype = "floatX" _print_name = ("HalfCauchy", "\\operatorname{HalfCauchy}") @@ -1291,8 +1254,7 @@ class InvGammaRV(ScipyRandomVariable): """ name = "invgamma" - ndim_supp = 0 - ndims_params = [0, 0] + signature = "(),()->()" dtype = "floatX" _print_name = ("InverseGamma", "\\operatorname{InverseGamma}") @@ -1342,8 +1304,7 @@ class WaldRV(RandomVariable): """ name = "wald" - ndim_supp = 0 - ndims_params = [0, 0] + signature = "(),()->()" dtype = "floatX" _print_name_ = ("Wald", "\\operatorname{Wald}") @@ -1390,8 +1351,7 @@ class TruncExponentialRV(ScipyRandomVariable): """ name = "truncexpon" - ndim_supp = 0 - ndims_params = [0, 0, 0] + signature = "(),(),()->()" dtype = "floatX" _print_name = ("TruncatedExponential", "\\operatorname{TruncatedExponential}") @@ -1446,8 +1406,7 @@ class StudentTRV(ScipyRandomVariable): """ name = "t" - ndim_supp = 0 - ndims_params = [0, 0, 0] + signature = "(),(),()->()" dtype = "floatX" _print_name = ("StudentT", "\\operatorname{StudentT}") @@ -1506,8 +1465,7 @@ class BernoulliRV(ScipyRandomVariable): """ name = "bernoulli" - ndim_supp = 0 - ndims_params = [0] + signature = "()->()" dtype = "int64" _print_name = ("Bernoulli", "\\operatorname{Bernoulli}") @@ -1554,8 +1512,7 @@ class LaplaceRV(RandomVariable): """ name = "laplace" - ndim_supp = 0 - ndims_params = [0, 0] + signature = "(),()->()" dtype = "floatX" _print_name = ("Laplace", "\\operatorname{Laplace}") @@ -1601,8 +1558,7 @@ class BinomialRV(RandomVariable): """ name = "binomial" - ndim_supp = 0 - ndims_params = [0, 0] + signature = "(),()->()" dtype = "int64" _print_name = ("Binomial", "\\operatorname{Binomial}") @@ -1645,9 +1601,8 @@ class NegBinomialRV(ScipyRandomVariable): """ - name = "nbinom" - ndim_supp = 0 - ndims_params = [0, 0] + name = "negative_binomial" + signature = "(),()->()" dtype = "int64" _print_name = ("NegativeBinomial", "\\operatorname{NegativeBinomial}") @@ -1702,8 +1657,7 @@ class BetaBinomialRV(ScipyRandomVariable): """ name = "beta_binomial" - ndim_supp = 0 - ndims_params = [0, 0, 0] + signature = "(),(),()->()" dtype = "int64" _print_name = ("BetaBinomial", "\\operatorname{BetaBinomial}") @@ -1754,8 +1708,7 @@ class GenGammaRV(ScipyRandomVariable): """ name = "gengamma" - ndim_supp = 0 - ndims_params = [0, 0, 0] + signature = "(),(),()->()" dtype = "floatX" _print_name = ("GeneralizedGamma", "\\operatorname{GeneralizedGamma}") @@ -1817,8 +1770,7 @@ class MultinomialRV(RandomVariable): """ name = "multinomial" - ndim_supp = 1 - ndims_params = [0, 1] + signature = 
"(),(p)->(p)" dtype = "int64" _print_name = ("Multinomial", "\\operatorname{Multinomial}") @@ -1845,14 +1797,6 @@ def __call__(self, n, p, size=None, **kwargs): """ return super().__call__(n, p, size=size, **kwargs) - def _supp_shape_from_params(self, dist_params, param_shapes=None): - return supp_shape_from_ref_param_shape( - ndim_supp=self.ndim_supp, - dist_params=dist_params, - param_shapes=param_shapes, - ref_param_idx=1, - ) - @classmethod def rng_fn(cls, rng, n, p, size): if n.ndim > 0 or p.ndim > 1: @@ -1862,7 +1806,7 @@ def rng_fn(cls, rng, n, p, size): n = np.broadcast_to(n, size) p = np.broadcast_to(p, size + p.shape[-1:]) else: - n, p = broadcast_params([n, p], cls.ndims_params) + n, p = broadcast_params([n, p], [0, 1]) res = np.empty(p.shape, dtype=cls.dtype) for idx in np.ndindex(p.shape[:-1]): @@ -1892,8 +1836,7 @@ class CategoricalRV(RandomVariable): """ name = "categorical" - ndim_supp = 0 - ndims_params = [1] + signature = "(p)->()" dtype = "int64" _print_name = ("Categorical", "\\operatorname{Categorical}") @@ -1948,8 +1891,7 @@ class RandIntRV(RandomVariable): """ name = "randint" - ndim_supp = 0 - ndims_params = [0, 0] + signature = "(),()->()" dtype = "int64" _print_name = ("randint", "\\operatorname{randint}") @@ -2001,8 +1943,7 @@ class IntegersRV(RandomVariable): """ name = "integers" - ndim_supp = 0 - ndims_params = [0, 0] + signature = "(),()->()" dtype = "int64" _print_name = ("integers", "\\operatorname{integers}") @@ -2174,17 +2115,23 @@ def choice(a, size=None, replace=True, p=None, rng=None): a_ndim = a.type.ndim dtype = a.type.dtype + a_dims = [f"a{i}" for i in range(a_ndim)] + a_sig = ",".join(a_dims) + idx_dims = [f"s{i}" for i in range(core_shape_length)] + if a_ndim == 0: + p_sig = "a" + out_dims = idx_dims + else: + p_sig = a_dims[0] + out_dims = idx_dims + a_dims[1:] + out_sig = ",".join(out_dims) + if p is None: - ndims_params = [a_ndim, 1] + signature = f"({a_sig}),({core_shape_length})->({out_sig})" else: - ndims_params = [a_ndim, 1, 1] - ndim_supp = max(a_ndim - 1, 0) + core_shape_length + signature = f"({a_sig}),({p_sig}),({core_shape_length})->({out_sig})" - op = ChoiceWithoutReplacement( - ndim_supp=ndim_supp, - ndims_params=ndims_params, - dtype=dtype, - ) + op = ChoiceWithoutReplacement(signature=signature, dtype=dtype) params = (a, core_shape) if p is None else (a, p, core_shape) return op(*params, size=None, rng=rng) @@ -2247,10 +2194,12 @@ def permutation(x, **kwargs): x_dtype = x.type.dtype # PermutationRV has a signature () -> (x) if x is a scalar # and (*x) -> (*x) otherwise, with has many entries as the dimensionsality of x - ndim_supp = max(x_ndim, 1) - return PermutationRV(ndim_supp=ndim_supp, ndims_params=[x_ndim], dtype=x_dtype)( - x, **kwargs - ) + if x_ndim == 0: + signature = "()->(x)" + else: + arg_sig = ",".join(f"x{i}" for i in range(x_ndim)) + signature = f"({arg_sig})->({arg_sig})" + return PermutationRV(signature=signature, dtype=x_dtype)(x, **kwargs) __all__ = [ diff --git a/pytensor/tensor/random/op.py b/pytensor/tensor/random/op.py index 8d1ea993ee..08cdf466db 100644 --- a/pytensor/tensor/random/op.py +++ b/pytensor/tensor/random/op.py @@ -1,3 +1,4 @@ +import warnings from collections.abc import Sequence from copy import copy from typing import cast @@ -28,6 +29,7 @@ from pytensor.tensor.shape import shape_tuple from pytensor.tensor.type import TensorType, all_dtypes from pytensor.tensor.type_other import NoneConst +from pytensor.tensor.utils import _parse_gufunc_signature, safe_signature from pytensor.tensor.variable 
import TensorVariable @@ -42,7 +44,7 @@ class RandomVariable(Op): _output_type_depends_on_input_value = True - __props__ = ("name", "ndim_supp", "ndims_params", "dtype", "inplace") + __props__ = ("name", "signature", "dtype", "inplace") default_output = 1 def __init__( @@ -50,8 +52,9 @@ def __init__( name=None, ndim_supp=None, ndims_params=None, - dtype=None, + dtype: str | None = None, inplace=None, + signature: str | None = None, ): """Create a random variable `Op`. @@ -59,44 +62,63 @@ def __init__( ---------- name: str The `Op`'s display name. - ndim_supp: int - Total number of dimensions for a single draw of the random variable - (e.g. a multivariate normal draw is 1D, so ``ndim_supp = 1``). - ndims_params: list of int - Number of dimensions for each distribution parameter when the - parameters only specify a single drawn of the random variable - (e.g. a multivariate normal's mean is 1D and covariance is 2D, so - ``ndims_params = [1, 2]``). + signature: str + Numpy-like vectorized signature of the random variable. dtype: str (optional) The dtype of the sampled output. If the value ``"floatX"`` is given, then ``dtype`` is set to ``pytensor.config.floatX``. If ``None`` (the default), the `dtype` keyword must be set when `RandomVariable.make_node` is called. inplace: boolean (optional) - Determine whether or not the underlying rng state is updated - in-place or not (i.e. copied). + Determine whether the underlying rng state is mutated or copied. """ super().__init__() self.name = name or getattr(self, "name") - self.ndim_supp = ( - ndim_supp if ndim_supp is not None else getattr(self, "ndim_supp") + + ndim_supp = ( + ndim_supp if ndim_supp is not None else getattr(self, "ndim_supp", None) ) - self.ndims_params = ( - ndims_params if ndims_params is not None else getattr(self, "ndims_params") + if ndim_supp is not None: + warnings.warn( + "ndim_supp is deprecated. Provide signature instead.", FutureWarning + ) + self.ndim_supp = ndim_supp + ndims_params = ( + ndims_params + if ndims_params is not None + else getattr(self, "ndims_params", None) ) + if ndims_params is not None: + warnings.warn( + "ndims_params is deprecated. Provide signature instead.", FutureWarning + ) + if not isinstance(ndims_params, Sequence): + raise TypeError("Parameter ndims_params must be sequence type.") + self.ndims_params = tuple(ndims_params) + + self.signature = signature or getattr(self, "signature", None) + if self.signature is not None: + # Assume a single output. Several methods need to be updated to handle multiple outputs. + self.inputs_sig, [self.output_sig] = _parse_gufunc_signature(self.signature) + self.ndims_params = [len(input_sig) for input_sig in self.inputs_sig] + self.ndim_supp = len(self.output_sig) + else: + if ( + getattr(self, "ndim_supp", None) is None + or getattr(self, "ndims_params", None) is None + ): + raise ValueError("signature must be provided") + else: + self.signature = safe_signature(self.ndims_params, [self.ndim_supp]) + self.dtype = dtype or getattr(self, "dtype", None) self.inplace = ( inplace if inplace is not None else getattr(self, "inplace", False) ) - if not isinstance(self.ndims_params, Sequence): - raise TypeError("Parameter ndims_params must be sequence type.") - - self.ndims_params = tuple(self.ndims_params) - if self.inplace: self.destroy_map = {0: [0]} @@ -120,8 +142,31 @@ def _supp_shape_from_params(self, dist_params, param_shapes=None): values (not shapes) of some parameters. 
For instance, a `gaussian_random_walk(steps, size=(2,))`, might have `support_shape=(steps,)`. """ + if self.signature is not None: + # Signature could indicate fixed numerical shapes + # As per https://numpy.org/neps/nep-0020-gufunc-signature-enhancement.html + output_sig = self.output_sig + core_out_shape = { + dim: int(dim) if str.isnumeric(dim) else None for dim in self.output_sig + } + + # Try to infer missing support dims from signature of params + for param, param_sig, ndim_params in zip( + dist_params, self.inputs_sig, self.ndims_params + ): + if ndim_params == 0: + continue + for param_dim, dim in zip(param.shape[-ndim_params:], param_sig): + if dim in core_out_shape and core_out_shape[dim] is None: + core_out_shape[dim] = param_dim + + if all(dim is not None for dim in core_out_shape.values()): + # We have all we need + return [core_out_shape[dim] for dim in output_sig] + raise NotImplementedError( - "`_supp_shape_from_params` must be implemented for multivariate RVs" + "`_supp_shape_from_params` must be implemented for multivariate RVs " + "when signature is not sufficient to infer the support shape" ) def rng_fn(self, rng, *args, **kwargs) -> int | float | np.ndarray: @@ -129,7 +174,24 @@ def rng_fn(self, rng, *args, **kwargs) -> int | float | np.ndarray: return getattr(rng, self.name)(*args, **kwargs) def __str__(self): - props_str = ", ".join(f"{getattr(self, prop)}" for prop in self.__props__[1:]) + # Only show signature from core props + if signature := self.signature: + # inp, out = signature.split("->") + # extended_signature = f"[rng],[size],{inp}->[rng],{out}" + # core_props = [extended_signature] + core_props = [f'"{signature}"'] + else: + # Far back compat + core_props = [str(self.ndim_supp), str(self.ndims_params)] + + # Add any extra props that the subclass may have + extra_props = [ + str(getattr(self, prop)) + for prop in self.__props__ + if prop not in RandomVariable.__props__ + ] + + props_str = ", ".join(core_props + extra_props) return f"{self.name}_rv{{{props_str}}}" def _infer_shape( @@ -298,11 +360,11 @@ def make_node(self, rng, size, dtype, *dist_params): dtype_idx = constant(all_dtypes.index(dtype), dtype="int64") else: dtype_idx = constant(dtype, dtype="int64") - dtype = all_dtypes[dtype_idx.data] - outtype = TensorType(dtype=dtype, shape=static_shape) - out_var = outtype() + dtype = all_dtypes[dtype_idx.data] + inputs = (rng, size, dtype_idx, *dist_params) + out_var = TensorType(dtype=dtype, shape=static_shape)() outputs = (rng.type(), out_var) return Apply(self, inputs, outputs) @@ -395,9 +457,8 @@ def vectorize_random_variable( # We extend it to accommodate the new input batch dimensions. 
# Otherwise, we assume the new size already has the right values - # Need to make parameters implicit broadcasting explicit - original_dist_params = node.inputs[3:] - old_size = node.inputs[1] + original_dist_params = op.dist_params(node) + old_size = op.size_param(node) len_old_size = get_vector_length(old_size) original_expanded_dist_params = explicit_expand_dims( diff --git a/pytensor/tensor/random/rewriting/jax.py b/pytensor/tensor/random/rewriting/jax.py index 1014e07a22..d86acfbd56 100644 --- a/pytensor/tensor/random/rewriting/jax.py +++ b/pytensor/tensor/random/rewriting/jax.py @@ -1,3 +1,5 @@ +import re + from pytensor.compile import optdb from pytensor.graph.rewriting.basic import in2out, node_rewriter from pytensor.graph.rewriting.db import SequenceDB @@ -164,9 +166,9 @@ def materialize_implicit_arange_choice_without_replacement(fgraph, node): a_vector_param = arange(a_scalar_param) new_props_dict = op._props_dict().copy() - new_ndims_params = list(op.ndims_params) - new_ndims_params[0] += 1 - new_props_dict["ndims_params"] = new_ndims_params + # Signature changes from something like "(),(a),(2)->(s0, s1)" to "(a),(a),(2)->(s0, s1)" + # I.e., we substitute the first `()` by `(a)` + new_props_dict["signature"] = re.sub(r"\(\)", "(a)", op.signature, 1) new_op = type(op)(**new_props_dict) return new_op.make_node(rng, size, dtype, a_vector_param, *other_params).outputs diff --git a/pytensor/tensor/random/utils.py b/pytensor/tensor/random/utils.py index 700daf91fe..51fbf7e120 100644 --- a/pytensor/tensor/random/utils.py +++ b/pytensor/tensor/random/utils.py @@ -123,7 +123,7 @@ def broadcast_params(params, ndims_params): def explicit_expand_dims( params: Sequence[TensorVariable], - ndim_params: tuple[int], + ndim_params: Sequence[int], size_length: int = 0, ) -> list[TensorVariable]: """Introduce explicit expand_dims in RV parameters that are implicitly broadcasted together and/or by size.""" @@ -137,7 +137,7 @@ def explicit_expand_dims( # See: https://github.com/pymc-devs/pytensor/issues/568 max_batch_dims = size_length else: - max_batch_dims = max(batch_dims) + max_batch_dims = max(batch_dims, default=0) new_params = [] for new_param, batch_dim in zip(params, batch_dims): @@ -354,6 +354,11 @@ def supp_shape_from_ref_param_shape( out: tuple Representing the support shape for a `RandomVariable` with the given `dist_params`. + Notes + _____ + This helper is no longer necessary when using signatures in `RandomVariable` subclasses. + + """ if ndim_supp <= 0: raise ValueError("ndim_supp must be greater than 0") diff --git a/pytensor/tensor/utils.py b/pytensor/tensor/utils.py index 55aa7fc836..41218981a0 100644 --- a/pytensor/tensor/utils.py +++ b/pytensor/tensor/utils.py @@ -169,7 +169,8 @@ def broadcast_static_dim_lengths( _CORE_DIMENSION_LIST = f"(?:{_DIMENSION_NAME}(?:,{_DIMENSION_NAME})*)?" 
_ARGUMENT = rf"\({_CORE_DIMENSION_LIST}\)" _ARGUMENT_LIST = f"{_ARGUMENT}(?:,{_ARGUMENT})*" -_SIGNATURE = f"^{_ARGUMENT_LIST}->{_ARGUMENT_LIST}$" +# Allow no inputs +_SIGNATURE = f"^(?:{_ARGUMENT_LIST})?->{_ARGUMENT_LIST}$" def _parse_gufunc_signature( @@ -200,6 +201,8 @@ def _parse_gufunc_signature( tuple(re.findall(_DIMENSION_NAME, arg)) for arg in re.findall(_ARGUMENT, arg_list) ] + if arg_list # ignore no inputs + else [] for arg_list in signature.split("->") ) diff --git a/tests/link/jax/test_random.py b/tests/link/jax/test_random.py index c4d0dacdf7..5b3ca0c9c3 100644 --- a/tests/link/jax/test_random.py +++ b/tests/link/jax/test_random.py @@ -771,8 +771,7 @@ def test_random_unimplemented(): class NonExistentRV(RandomVariable): name = "non-existent" - ndim_supp = 0 - ndims_params = [] + signature = "->()" dtype = "floatX" def __call__(self, size=None, **kwargs): @@ -798,8 +797,7 @@ def test_random_custom_implementation(): class CustomRV(RandomVariable): name = "non-existent" - ndim_supp = 0 - ndims_params = [] + signature = "->()" dtype = "floatX" def __call__(self, size=None, **kwargs): diff --git a/tests/tensor/random/rewriting/test_basic.py b/tests/tensor/random/rewriting/test_basic.py index c51b8a1601..894caae063 100644 --- a/tests/tensor/random/rewriting/test_basic.py +++ b/tests/tensor/random/rewriting/test_basic.py @@ -74,52 +74,28 @@ def apply_local_rewrite_to_rv( return new_out, f_inputs, dist_st, f_rewritten -def test_inplace_rewrites(): - out = normal(0, 1) - out.owner.inputs[0].default_update = out.owner.outputs[0] +class TestRVExpraProps(RandomVariable): + name = "test" + signature = "()->()" + __props__ = ("name", "signature", "dtype", "inplace", "extra") + dtype = "floatX" + _print_name = ("TestExtraProps", "\\operatorname{TestExtra_props}") - assert out.owner.op.inplace is False + def __init__(self, extra, *args, **kwargs): + self.extra = extra + super().__init__(*args, **kwargs) - f = function( - [], - out, - mode="FAST_RUN", - ) - - (new_out, new_rng) = f.maker.fgraph.outputs - assert new_out.type == out.type - assert isinstance(new_out.owner.op, type(out.owner.op)) - assert new_out.owner.op.inplace is True - assert all( - np.array_equal(a.data, b.data) - for a, b in zip(new_out.owner.inputs[2:], out.owner.inputs[2:]) - ) - assert np.array_equal(new_out.owner.inputs[1].data, []) - - -def test_inplace_rewrites_extra_props(): - class Test(RandomVariable): - name = "test" - ndim_supp = 0 - ndims_params = [0] - __props__ = ("name", "ndim_supp", "ndims_params", "dtype", "inplace", "extra") - dtype = "floatX" - _print_name = ("Test", "\\operatorname{Test}") - - def __init__(self, extra, *args, **kwargs): - self.extra = extra - super().__init__(*args, **kwargs) - - def make_node(self, rng, size, dtype, sigma): - return super().make_node(rng, size, dtype, sigma) - - def rng_fn(self, rng, sigma, size): - return rng.normal(scale=sigma, size=size) + def rng_fn(self, rng, dtype, sigma, size): + return rng.normal(scale=sigma, size=size) - out = Test(extra="some value")(1) - out.owner.inputs[0].default_update = out.owner.outputs[0] - assert out.owner.op.inplace is False +@pytest.mark.parametrize("rv_op", [normal, TestRVExpraProps(extra="some value")]) +def test_inplace_rewrites(rv_op): + out = rv_op(np.e) + node = out.owner + op = node.op + node.inputs[0].default_update = node.outputs[0] + assert op.inplace is False f = function( [], @@ -129,9 +105,10 @@ def rng_fn(self, rng, sigma, size): (new_out, new_rng) = f.maker.fgraph.outputs assert new_out.type == out.type - assert 
isinstance(new_out.owner.op, type(out.owner.op)) - assert new_out.owner.op.inplace is True - assert new_out.owner.op.extra == out.owner.op.extra + new_node = new_out.owner + new_op = new_node.op + assert isinstance(new_op, type(op)) + assert new_op._props_dict() == (op._props_dict() | {"inplace": True}) assert all( np.array_equal(a.data, b.data) for a, b in zip(new_out.owner.inputs[2:], out.owner.inputs[2:]) diff --git a/tests/tensor/random/test_basic.py b/tests/tensor/random/test_basic.py index 38ccc8d270..e3083c98ec 100644 --- a/tests/tensor/random/test_basic.py +++ b/tests/tensor/random/test_basic.py @@ -1463,11 +1463,8 @@ def batched_unweighted_choice_without_replacement_tester( rng = shared(rng_ctor()) # Batched a implicit size - a_core_ndim = 2 - core_shape_len = 1 rv_op = ChoiceWithoutReplacement( - ndim_supp=max(a_core_ndim - 1, 0) + core_shape_len, - ndims_params=[a_core_ndim, core_shape_len], + signature="(a0,a1),(1)->(s0,a1)", dtype="int64", ) @@ -1483,11 +1480,8 @@ def batched_unweighted_choice_without_replacement_tester( assert np.all((draw >= i * 10) & (draw < (i + 1) * 10)) # Explicit size broadcasts beyond a - a_core_ndim = 2 - core_shape_len = 2 rv_op = ChoiceWithoutReplacement( - ndim_supp=max(a_core_ndim - 1, 0) + core_shape_len, - ndims_params=[a_core_ndim, len(core_shape)], + signature="(a0,a1),(2)->(s0,s1,a1)", dtype="int64", ) @@ -1515,12 +1509,8 @@ def batched_weighted_choice_without_replacement_tester( """ rng = shared(rng_ctor()) - # 3 ndims params indicates p is passed - a_core_ndim = 2 - core_shape_len = 1 rv_op = ChoiceWithoutReplacement( - ndim_supp=max(a_core_ndim - 1, 0) + core_shape_len, - ndims_params=[a_core_ndim, 1, 1], + signature="(a0,a1),(a0),(1)->(s0,a1)", dtype="int64", ) @@ -1540,11 +1530,8 @@ def batched_weighted_choice_without_replacement_tester( # p and a are batched # Test implicit arange - a_core_ndim = 0 - core_shape_len = 2 rv_op = ChoiceWithoutReplacement( - ndim_supp=max(a_core_ndim - 1, 0) + core_shape_len, - ndims_params=[a_core_ndim, 1, 1], + signature="(),(a),(2)->(s0,s1)", dtype="int64", ) a = 6 @@ -1566,11 +1553,8 @@ def batched_weighted_choice_without_replacement_tester( assert set(draw) == set(range(i, 6, 2)) # Size broadcasts beyond a - a_core_ndim = 2 - core_shape_len = 1 rv_op = ChoiceWithoutReplacement( - ndim_supp=max(a_core_ndim - 1, 0) + core_shape_len, - ndims_params=[a_core_ndim, 1, 1], + signature="(a0,a1),(a0),(1)->(s0,a1)", dtype="int64", ) a = np.arange(4 * 5 * 2).reshape((4, 5, 2)) diff --git a/tests/tensor/random/test_op.py b/tests/tensor/random/test_op.py index 80ee011015..46969c3a9f 100644 --- a/tests/tensor/random/test_op.py +++ b/tests/tensor/random/test_op.py @@ -23,14 +23,13 @@ def test_RandomVariable_basics(strict_test_value_flags): str_res = str( RandomVariable( "normal", - 0, - [0, 0], - "float32", - inplace=True, + signature="(),()->()", + dtype="float32", + inplace=False, ) ) - assert str_res == "normal_rv{0, (0, 0), float32, True}" + assert str_res == 'normal_rv{"(),()->()"}' # `ndims_params` should be a `Sequence` type with pytest.raises(TypeError, match="^Parameter ndims_params*"): @@ -64,9 +63,7 @@ def test_RandomVariable_basics(strict_test_value_flags): # Confirm that `inplace` works rv = RandomVariable( "normal", - 0, - [0, 0], - "normal", + signature="(),()->()", inplace=True, ) @@ -74,7 +71,7 @@ def test_RandomVariable_basics(strict_test_value_flags): assert rv.destroy_map == {0: [0]} # A no-params `RandomVariable` - rv = RandomVariable(name="test_rv", ndim_supp=0, ndims_params=()) + rv = 
RandomVariable(name="test_rv", signature="->()") with pytest.raises(TypeError): rv.make_node(rng=1) From 362ebe4f6009424da49b71e30632b3515b7dd360 Mon Sep 17 00:00:00 2001 From: Ricardo Vieira Date: Fri, 24 May 2024 12:25:56 +0200 Subject: [PATCH 09/15] Add RandomVariable Op helpers to retrieve rng, size, and dist_params from a node, for readability --- pytensor/link/jax/dispatch/random.py | 8 +++--- pytensor/link/numba/dispatch/random.py | 28 ++++++++++++--------- pytensor/tensor/random/op.py | 12 +++++++++ pytensor/tensor/random/rewriting/basic.py | 2 +- tests/tensor/random/rewriting/test_basic.py | 8 +++--- 5 files changed, 37 insertions(+), 21 deletions(-) diff --git a/pytensor/link/jax/dispatch/random.py b/pytensor/link/jax/dispatch/random.py index d07091d099..ba9769ddf2 100644 --- a/pytensor/link/jax/dispatch/random.py +++ b/pytensor/link/jax/dispatch/random.py @@ -88,7 +88,7 @@ def jax_typify_Generator(rng, **kwargs): @jax_funcify.register(ptr.RandomVariable) -def jax_funcify_RandomVariable(op, node, **kwargs): +def jax_funcify_RandomVariable(op: ptr.RandomVariable, node, **kwargs): """JAX implementation of random variables.""" rv = node.outputs[1] out_dtype = rv.type.dtype @@ -101,7 +101,7 @@ def jax_funcify_RandomVariable(op, node, **kwargs): if None in static_size: # Sometimes size can be constant folded during rewrites, # without the RandomVariable node being updated with new static types - size_param = node.inputs[1] + size_param = op.size_param(node) if isinstance(size_param, Constant): size_tuple = tuple(size_param.data) # PyTensor uses empty size to represent size = None @@ -304,11 +304,11 @@ def sample_fn(rng, size, dtype, df, loc, scale): @jax_sample_fn.register(ptr.ChoiceWithoutReplacement) -def jax_funcify_choice(op, node): +def jax_funcify_choice(op: ptr.ChoiceWithoutReplacement, node): """JAX implementation of `ChoiceRV`.""" batch_ndim = op.batch_ndim(node) - a, *p, core_shape = node.inputs[3:] + a, *p, core_shape = op.dist_params(node) a_core_ndim, *p_core_ndim, _ = op.ndims_params if batch_ndim and a_core_ndim == 0: diff --git a/pytensor/link/numba/dispatch/random.py b/pytensor/link/numba/dispatch/random.py index f0b5508652..0bc4d4a890 100644 --- a/pytensor/link/numba/dispatch/random.py +++ b/pytensor/link/numba/dispatch/random.py @@ -96,11 +96,14 @@ def make_numba_random_fn(node, np_random_func): The functions generated here add parameter broadcasting and the ``size`` argument to the Numba-supported scalar ``np.random`` functions. 
""" - if not isinstance(node.inputs[0].type, RandomStateType): + op: ptr.RandomVariable = node.op + rng_param = op.rng_param(node) + if not isinstance(rng_param.type, RandomStateType): raise TypeError("Numba does not support NumPy `Generator`s") - tuple_size = int(get_vector_length(node.inputs[1])) - size_dims = tuple_size - max(i.ndim for i in node.inputs[3:]) + tuple_size = int(get_vector_length(op.size_param(node))) + dist_params = op.dist_params(node) + size_dims = tuple_size - max(i.ndim for i in dist_params) # Make a broadcast-capable version of the Numba supported scalar sampling # function @@ -126,7 +129,7 @@ def make_numba_random_fn(node, np_random_func): ) bcast_fn_input_names = ", ".join( - [unique_names(i, force_unique=True) for i in node.inputs[3:]] + [unique_names(i, force_unique=True) for i in dist_params] ) bcast_fn_global_env = { "np_random_func": np_random_func, @@ -143,7 +146,7 @@ def {bcast_fn_name}({bcast_fn_input_names}): ) random_fn_input_names = ", ".join( - ["rng", "size", "dtype"] + [unique_names(i) for i in node.inputs[3:]] + ["rng", "size", "dtype"] + [unique_names(i) for i in dist_params] ) # Now, create a Numba JITable function that implements the `size` parameter @@ -244,7 +247,8 @@ def create_numba_random_fn( suffix_sep="_", ) - np_names = [unique_names(i, force_unique=True) for i in node.inputs[3:]] + dist_params = op.dist_params(node) + np_names = [unique_names(i, force_unique=True) for i in dist_params] np_input_names = ", ".join(np_names) np_random_fn_src = f""" @numba_vectorize @@ -300,9 +304,9 @@ def body_fn(a): @numba_funcify.register(ptr.CategoricalRV) -def numba_funcify_CategoricalRV(op, node, **kwargs): +def numba_funcify_CategoricalRV(op: ptr.CategoricalRV, node, **kwargs): out_dtype = node.outputs[1].type.numpy_dtype - size_len = int(get_vector_length(node.inputs[1])) + size_len = int(get_vector_length(op.size_param(node))) p_ndim = node.inputs[-1].ndim @numba_basic.numba_njit @@ -331,9 +335,9 @@ def categorical_rv(rng, size, dtype, p): @numba_funcify.register(ptr.DirichletRV) def numba_funcify_DirichletRV(op, node, **kwargs): out_dtype = node.outputs[1].type.numpy_dtype - alphas_ndim = node.inputs[3].type.ndim + alphas_ndim = op.dist_params(node)[0].type.ndim neg_ind_shape_len = -alphas_ndim + 1 - size_len = int(get_vector_length(node.inputs[1])) + size_len = int(get_vector_length(op.size_param(node))) if alphas_ndim > 1: @@ -400,9 +404,9 @@ def choice_without_replacement_rv(rng, size, dtype, a, core_shape): @numba_funcify.register(ptr.PermutationRV) -def numba_funcify_permutation(op, node, **kwargs): +def numba_funcify_permutation(op: ptr.PermutationRV, node, **kwargs): # PyTensor uses size=() to represent size=None - size_is_none = node.inputs[1].type.shape == (0,) + size_is_none = op.size_param(node).type.shape == (0,) batch_ndim = op.batch_ndim(node) x_batch_ndim = node.inputs[-1].type.ndim - op.ndims_params[0] diff --git a/pytensor/tensor/random/op.py b/pytensor/tensor/random/op.py index 08cdf466db..f92be396af 100644 --- a/pytensor/tensor/random/op.py +++ b/pytensor/tensor/random/op.py @@ -372,6 +372,18 @@ def make_node(self, rng, size, dtype, *dist_params): def batch_ndim(self, node: Apply) -> int: return cast(int, node.default_output().type.ndim - self.ndim_supp) + def rng_param(self, node) -> Variable: + """Return the node input corresponding to the rng""" + return node.inputs[0] + + def size_param(self, node) -> Variable: + """Return the node input corresponding to the size""" + return node.inputs[1] + + def dist_params(self, node) -> 
Sequence[Variable]: + """Return the node inpust corresponding to dist params""" + return node.inputs[3:] + def perform(self, node, inputs, outputs): rng_var_out, smpl_out = outputs diff --git a/pytensor/tensor/random/rewriting/basic.py b/pytensor/tensor/random/rewriting/basic.py index a7220f520e..5cb84b9bf3 100644 --- a/pytensor/tensor/random/rewriting/basic.py +++ b/pytensor/tensor/random/rewriting/basic.py @@ -255,7 +255,7 @@ def is_nd_advanced_idx(idx, dtype): return False # Check that indexing does not act on support dims - batch_ndims = rv.ndim - rv_op.ndim_supp + batch_ndims = rv_op.batch_ndim(rv_node) # We decompose the boolean indexes, which makes it clear whether they act on support dims or not non_bool_indices = tuple( chain.from_iterable( diff --git a/tests/tensor/random/rewriting/test_basic.py b/tests/tensor/random/rewriting/test_basic.py index 894caae063..9329554e2e 100644 --- a/tests/tensor/random/rewriting/test_basic.py +++ b/tests/tensor/random/rewriting/test_basic.py @@ -111,9 +111,9 @@ def test_inplace_rewrites(rv_op): assert new_op._props_dict() == (op._props_dict() | {"inplace": True}) assert all( np.array_equal(a.data, b.data) - for a, b in zip(new_out.owner.inputs[2:], out.owner.inputs[2:]) + for a, b in zip(new_op.dist_params(new_node), op.dist_params(node)) ) - assert np.array_equal(new_out.owner.inputs[1].data, []) + assert np.array_equal(new_op.size_param(new_node).data, op.size_param(node).data) @config.change_flags(compute_test_value="raise") @@ -400,7 +400,7 @@ def test_DimShuffle_lift(ds_order, lifted, dist_op, dist_params, size, rtol): assert new_out.owner.op == dist_op assert all( isinstance(i.owner.op, DimShuffle) - for i in new_out.owner.inputs[3:] + for i in new_out.owner.op.dist_params(new_out.owner) if i.owner ) else: @@ -793,7 +793,7 @@ def test_Subtensor_lift(indices, lifted, dist_op, dist_params, size): assert isinstance(new_out.owner.op, RandomVariable) assert all( isinstance(i.owner.op, AdvancedSubtensor | AdvancedSubtensor1 | Subtensor) - for i in new_out.owner.inputs[3:] + for i in new_out.owner.op.dist_params(new_out.owner) if i.owner ) else: From dbcdaa9f802df10d4d2840ed9496c287dafb9fc4 Mon Sep 17 00:00:00 2001 From: Ricardo Vieira Date: Fri, 10 May 2024 10:15:48 +0200 Subject: [PATCH 10/15] Remove RandomVariable dtype input --- pytensor/link/jax/dispatch/random.py | 4 +- pytensor/link/numba/dispatch/random.py | 17 ++- pytensor/tensor/random/op.py | 68 ++++++----- pytensor/tensor/random/rewriting/basic.py | 12 +- tests/tensor/random/rewriting/test_basic.py | 11 +- tests/tensor/random/test_op.py | 128 ++++++++++---------- 6 files changed, 118 insertions(+), 122 deletions(-) diff --git a/pytensor/link/jax/dispatch/random.py b/pytensor/link/jax/dispatch/random.py index ba9769ddf2..dc22a07bfe 100644 --- a/pytensor/link/jax/dispatch/random.py +++ b/pytensor/link/jax/dispatch/random.py @@ -114,7 +114,7 @@ def jax_funcify_RandomVariable(op: ptr.RandomVariable, node, **kwargs): if None in static_size: assert_size_argument_jax_compatible(node) - def sample_fn(rng, size, dtype, *parameters): + def sample_fn(rng, size, *parameters): # PyTensor uses empty size to represent size = None if jax.numpy.asarray(size).shape == (0,): size = None @@ -122,7 +122,7 @@ def sample_fn(rng, size, dtype, *parameters): else: - def sample_fn(rng, size, dtype, *parameters): + def sample_fn(rng, size, *parameters): return jax_sample_fn(op, node=node)( rng, static_size, out_dtype, *parameters ) diff --git a/pytensor/link/numba/dispatch/random.py 
b/pytensor/link/numba/dispatch/random.py index 0bc4d4a890..e2bc51da7e 100644 --- a/pytensor/link/numba/dispatch/random.py +++ b/pytensor/link/numba/dispatch/random.py @@ -123,7 +123,6 @@ def make_numba_random_fn(node, np_random_func): "size_dims", "rng", "size", - "dtype", ], suffix_sep="_", ) @@ -146,7 +145,7 @@ def {bcast_fn_name}({bcast_fn_input_names}): ) random_fn_input_names = ", ".join( - ["rng", "size", "dtype"] + [unique_names(i) for i in dist_params] + ["rng", "size"] + [unique_names(i) for i in dist_params] ) # Now, create a Numba JITable function that implements the `size` parameter @@ -243,7 +242,7 @@ def create_numba_random_fn( np_global_env["numba_vectorize"] = numba_basic.numba_vectorize unique_names = unique_name_generator( - [np_random_fn_name, *np_global_env.keys(), "rng", "size", "dtype"], + [np_random_fn_name, *np_global_env.keys(), "rng", "size"], suffix_sep="_", ) @@ -310,7 +309,7 @@ def numba_funcify_CategoricalRV(op: ptr.CategoricalRV, node, **kwargs): p_ndim = node.inputs[-1].ndim @numba_basic.numba_njit - def categorical_rv(rng, size, dtype, p): + def categorical_rv(rng, size, p): if not size_len: size_tpl = p.shape[:-1] else: @@ -342,7 +341,7 @@ def numba_funcify_DirichletRV(op, node, **kwargs): if alphas_ndim > 1: @numba_basic.numba_njit - def dirichlet_rv(rng, size, dtype, alphas): + def dirichlet_rv(rng, size, alphas): if size_len > 0: size_tpl = numba_ndarray.to_fixed_tuple(size, size_len) if ( @@ -365,7 +364,7 @@ def dirichlet_rv(rng, size, dtype, alphas): else: @numba_basic.numba_njit - def dirichlet_rv(rng, size, dtype, alphas): + def dirichlet_rv(rng, size, alphas): size = numba_ndarray.to_fixed_tuple(size, size_len) return (rng, np.random.dirichlet(alphas, size)) @@ -388,14 +387,14 @@ def numba_funcify_choice_without_replacement(op, node, **kwargs): if op.has_p_param: @numba_basic.numba_njit - def choice_without_replacement_rv(rng, size, dtype, a, p, core_shape): + def choice_without_replacement_rv(rng, size, a, p, core_shape): core_shape = numba_ndarray.to_fixed_tuple(core_shape, core_shape_len) samples = np.random.choice(a, size=core_shape, replace=False, p=p) return (rng, samples) else: @numba_basic.numba_njit - def choice_without_replacement_rv(rng, size, dtype, a, core_shape): + def choice_without_replacement_rv(rng, size, a, core_shape): core_shape = numba_ndarray.to_fixed_tuple(core_shape, core_shape_len) samples = np.random.choice(a, size=core_shape, replace=False) return (rng, samples) @@ -411,7 +410,7 @@ def numba_funcify_permutation(op: ptr.PermutationRV, node, **kwargs): x_batch_ndim = node.inputs[-1].type.ndim - op.ndims_params[0] @numba_basic.numba_njit - def permutation_rv(rng, size, dtype, x): + def permutation_rv(rng, size, x): if batch_ndim: x_core_shape = x.shape[x_batch_ndim:] if size_is_none: diff --git a/pytensor/tensor/random/op.py b/pytensor/tensor/random/op.py index f92be396af..710981ed2e 100644 --- a/pytensor/tensor/random/op.py +++ b/pytensor/tensor/random/op.py @@ -27,7 +27,7 @@ normalize_size_param, ) from pytensor.tensor.shape import shape_tuple -from pytensor.tensor.type import TensorType, all_dtypes +from pytensor.tensor.type import TensorType from pytensor.tensor.type_other import NoneConst from pytensor.tensor.utils import _parse_gufunc_signature, safe_signature from pytensor.tensor.variable import TensorVariable @@ -65,7 +65,7 @@ def __init__( signature: str Numpy-like vectorized signature of the random variable. dtype: str (optional) - The dtype of the sampled output. 
If the value ``"floatX"`` is + The default dtype of the sampled output. If the value ``"floatX"`` is given, then ``dtype`` is set to ``pytensor.config.floatX``. If ``None`` (the default), the `dtype` keyword must be set when `RandomVariable.make_node` is called. @@ -287,8 +287,8 @@ def extract_batch_shape(p, ps, n): return shape def infer_shape(self, fgraph, node, input_shapes): - _, size, _, *dist_params = node.inputs - _, size_shape, _, *param_shapes = input_shapes + _, size, *dist_params = node.inputs + _, size_shape, *param_shapes = input_shapes try: size_len = get_vector_length(size) @@ -302,14 +302,34 @@ def infer_shape(self, fgraph, node, input_shapes): return [None, list(shape)] def __call__(self, *args, size=None, name=None, rng=None, dtype=None, **kwargs): - res = super().__call__(rng, size, dtype, *args, **kwargs) + if dtype is None: + dtype = self.dtype + if dtype == "floatX": + dtype = config.floatX + + # We need to recreate the Op with the right dtype + if dtype != self.dtype: + # Check we are not switching from float to int + if self.dtype is not None: + if dtype.startswith("float") != self.dtype.startswith("float"): + raise ValueError( + f"Cannot change the dtype of a {self.name} RV from {self.dtype} to {dtype}" + ) + props = self._props_dict() + props["dtype"] = dtype + new_op = type(self)(**props) + return new_op.__call__( + *args, size=size, name=name, rng=rng, dtype=dtype, **kwargs + ) + + res = super().__call__(rng, size, *args, **kwargs) if name is not None: res.name = name return res - def make_node(self, rng, size, dtype, *dist_params): + def make_node(self, rng, size, *dist_params): """Create a random variable node. Parameters @@ -349,23 +369,10 @@ def make_node(self, rng, size, dtype, *dist_params): shape = self._infer_shape(size, dist_params) _, static_shape = infer_static_shape(shape) - dtype = self.dtype or dtype - if dtype == "floatX": - dtype = config.floatX - elif dtype is None or (isinstance(dtype, str) and dtype not in all_dtypes): - raise TypeError("dtype is unspecified") - - if isinstance(dtype, str): - dtype_idx = constant(all_dtypes.index(dtype), dtype="int64") - else: - dtype_idx = constant(dtype, dtype="int64") - - dtype = all_dtypes[dtype_idx.data] - - inputs = (rng, size, dtype_idx, *dist_params) - out_var = TensorType(dtype=dtype, shape=static_shape)() - outputs = (rng.type(), out_var) + inputs = (rng, size, *dist_params) + out_type = TensorType(dtype=self.dtype, shape=static_shape) + outputs = (rng.type(), out_type()) return Apply(self, inputs, outputs) @@ -382,14 +389,12 @@ def size_param(self, node) -> Variable: def dist_params(self, node) -> Sequence[Variable]: """Return the node inpust corresponding to dist params""" - return node.inputs[3:] + return node.inputs[2:] def perform(self, node, inputs, outputs): rng_var_out, smpl_out = outputs - rng, size, dtype, *args = inputs - - out_var = node.outputs[1] + rng, size, *args = inputs # If `size == []`, that means no size is enforced, and NumPy is trusted # to draw the appropriate number of samples, NumPy uses `size=None` to @@ -408,11 +413,8 @@ def perform(self, node, inputs, outputs): smpl_val = self.rng_fn(rng, *([*args, size])) - if ( - not isinstance(smpl_val, np.ndarray) - or str(smpl_val.dtype) != out_var.type.dtype - ): - smpl_val = _asarray(smpl_val, dtype=out_var.type.dtype) + if not isinstance(smpl_val, np.ndarray) or str(smpl_val.dtype) != self.dtype: + smpl_val = _asarray(smpl_val, dtype=self.dtype) smpl_out[0] = smpl_val @@ -463,7 +465,7 @@ class 
DefaultGeneratorMakerOp(AbstractRNGConstructor): @_vectorize_node.register(RandomVariable) def vectorize_random_variable( - op: RandomVariable, node: Apply, rng, size, dtype, *dist_params + op: RandomVariable, node: Apply, rng, size, *dist_params ) -> Apply: # If size was provided originally and a new size hasn't been provided, # We extend it to accommodate the new input batch dimensions. @@ -491,4 +493,4 @@ def vectorize_random_variable( new_size_dims = broadcasted_batch_shape[:new_ndim] size = concatenate([new_size_dims, size]) - return op.make_node(rng, size, dtype, *dist_params) + return op.make_node(rng, size, *dist_params) diff --git a/pytensor/tensor/random/rewriting/basic.py b/pytensor/tensor/random/rewriting/basic.py index 5cb84b9bf3..0da065835b 100644 --- a/pytensor/tensor/random/rewriting/basic.py +++ b/pytensor/tensor/random/rewriting/basic.py @@ -81,7 +81,7 @@ def local_rv_size_lift(fgraph, node): if not isinstance(node.op, RandomVariable): return - rng, size, dtype, *dist_params = node.inputs + rng, size, *dist_params = node.inputs dist_params = broadcast_params(dist_params, node.op.ndims_params) @@ -105,7 +105,7 @@ def local_rv_size_lift(fgraph, node): else: return - new_node = node.op.make_node(rng, None, dtype, *dist_params) + new_node = node.op.make_node(rng, None, *dist_params) if config.compute_test_value != "off": compute_test_value(new_node) @@ -141,7 +141,7 @@ def local_dimshuffle_rv_lift(fgraph, node): return False rv_op = rv_node.op - rng, size, dtype, *dist_params = rv_node.inputs + rng, size, *dist_params = rv_node.inputs rv = rv_node.default_output() # Check that Dimshuffle does not affect support dims @@ -185,7 +185,7 @@ def local_dimshuffle_rv_lift(fgraph, node): ) new_dist_params.append(param.dimshuffle(param_new_order)) - new_node = rv_op.make_node(rng, new_size, dtype, *new_dist_params) + new_node = rv_op.make_node(rng, new_size, *new_dist_params) if config.compute_test_value != "off": compute_test_value(new_node) @@ -233,7 +233,7 @@ def is_nd_advanced_idx(idx, dtype): return None rv_op = rv_node.op - rng, size, dtype, *dist_params = rv_node.inputs + rng, size, *dist_params = rv_node.inputs # Parse indices idx_list = getattr(subtensor_op, "idx_list", None) @@ -346,7 +346,7 @@ def is_nd_advanced_idx(idx, dtype): new_dist_params.append(batch_param[tuple(batch_indices)]) # Create new RV - new_node = rv_op.make_node(rng, new_size, dtype, *new_dist_params) + new_node = rv_op.make_node(rng, new_size, *new_dist_params) new_rv = new_node.default_output() copy_stack_trace(rv, new_rv) diff --git a/tests/tensor/random/rewriting/test_basic.py b/tests/tensor/random/rewriting/test_basic.py index 9329554e2e..f8f70adc10 100644 --- a/tests/tensor/random/rewriting/test_basic.py +++ b/tests/tensor/random/rewriting/test_basic.py @@ -12,6 +12,7 @@ from pytensor.tensor import constant from pytensor.tensor.elemwise import DimShuffle from pytensor.tensor.random.basic import ( + NormalRV, categorical, dirichlet, multinomial, @@ -397,7 +398,7 @@ def test_DimShuffle_lift(ds_order, lifted, dist_op, dist_params, size, rtol): ) if lifted: - assert new_out.owner.op == dist_op + assert isinstance(new_out.owner.op, type(dist_op)) assert all( isinstance(i.owner.op, DimShuffle) for i in new_out.owner.op.dist_params(new_out.owner) @@ -832,7 +833,7 @@ def test_Subtensor_lift_restrictions(): subtensor_node = fg.outputs[0].owner.inputs[1].owner.inputs[0].owner assert subtensor_node == y.owner assert isinstance(subtensor_node.op, Subtensor) - assert subtensor_node.inputs[0].owner.op == normal + 
assert isinstance(subtensor_node.inputs[0].owner.op, NormalRV) z = pt.ones(x.shape) - x[1] @@ -850,7 +851,7 @@ def test_Subtensor_lift_restrictions(): EquilibriumGraphRewriter([local_subtensor_rv_lift], max_use_ratio=100).apply(fg) rv_node = fg.outputs[0].owner.inputs[1].owner.inputs[0].owner - assert rv_node.op == normal + assert isinstance(rv_node.op, NormalRV) assert isinstance(rv_node.inputs[-1].owner.op, Subtensor) assert isinstance(rv_node.inputs[-2].owner.op, Subtensor) @@ -872,7 +873,7 @@ def test_Dimshuffle_lift_restrictions(): dimshuffle_node = fg.outputs[0].owner.inputs[1].owner assert dimshuffle_node == y.owner assert isinstance(dimshuffle_node.op, DimShuffle) - assert dimshuffle_node.inputs[0].owner.op == normal + assert isinstance(dimshuffle_node.inputs[0].owner.op, NormalRV) z = pt.ones(x.shape) - y @@ -890,7 +891,7 @@ def test_Dimshuffle_lift_restrictions(): EquilibriumGraphRewriter([local_dimshuffle_rv_lift], max_use_ratio=100).apply(fg) rv_node = fg.outputs[0].owner.inputs[1].owner - assert rv_node.op == normal + assert isinstance(rv_node.op, NormalRV) assert isinstance(rv_node.inputs[-1].owner.op, DimShuffle) assert isinstance(rv_node.inputs[-2].owner.op, DimShuffle) diff --git a/tests/tensor/random/test_op.py b/tests/tensor/random/test_op.py index 46969c3a9f..bb0b5cbf03 100644 --- a/tests/tensor/random/test_op.py +++ b/tests/tensor/random/test_op.py @@ -3,14 +3,14 @@ import pytensor.tensor as pt from pytensor import config, function -from pytensor.gradient import NullTypeGradError, grad -from pytensor.graph.replace import vectorize_node +from pytensor.graph.replace import vectorize_graph from pytensor.raise_op import Assert from pytensor.tensor.math import eq from pytensor.tensor.random import normal +from pytensor.tensor.random.basic import NormalRV from pytensor.tensor.random.op import RandomState, RandomVariable, default_rng from pytensor.tensor.shape import specify_shape -from pytensor.tensor.type import all_dtypes, iscalar, tensor +from pytensor.tensor.type import iscalar, tensor @pytest.fixture(scope="function", autouse=False) @@ -51,15 +51,6 @@ def test_RandomVariable_basics(strict_test_value_flags): inplace=True, )(0, 1, size={1, 2}) - # No dtype - with pytest.raises(TypeError, match="^dtype*"): - RandomVariable( - "normal", - 0, - [0, 0], - inplace=True, - )(0, 1) - # Confirm that `inplace` works rv = RandomVariable( "normal", @@ -80,16 +71,19 @@ def test_RandomVariable_basics(strict_test_value_flags): rv_shape = rv._infer_shape(pt.constant([]), (), []) assert rv_shape.equals(pt.constant([], dtype="int64")) - # Integer-specified `dtype` - dtype_1 = all_dtypes[1] - rv_node = rv.make_node(None, None, 1) - rv_out = rv_node.outputs[1] - rv_out.tag.test_value = 1 + # `dtype` is respected + rv = RandomVariable("normal", signature="(),()->()", dtype="int32") + with config.change_flags(compute_test_value="off"): + rv_out = rv() + assert rv_out.dtype == "int32" + rv_out = rv(dtype="int64") + assert rv_out.dtype == "int64" - assert rv_out.dtype == dtype_1 - - with pytest.raises(NullTypeGradError): - grad(rv_out, [rv_node.inputs[0]]) + with pytest.raises( + ValueError, + match="Cannot change the dtype of a normal RV from int32 to float32", + ): + assert rv(dtype="float32").dtype == "float32" def test_RandomVariable_bcast(strict_test_value_flags): @@ -238,70 +232,70 @@ def test_multivariate_rv_infer_static_shape(): assert mv_op(param1, param2, size=(10, 2)).type.shape == (10, 2, 3) -def test_vectorize_node(): +def test_vectorize(): vec = tensor(shape=(None,)) mat = 
tensor(shape=(None, None)) # Test without size - node = normal(vec).owner - new_inputs = node.inputs.copy() - new_inputs[3] = mat # mu - vect_node = vectorize_node(node, *new_inputs) - assert vect_node.op is normal - assert vect_node.inputs[3] is mat + out = normal(vec) + vect_node = vectorize_graph(out, {vec: mat}).owner + assert isinstance(vect_node.op, NormalRV) + assert vect_node.op.dist_params(vect_node)[0] is mat # Test with size, new size provided - node = normal(vec, size=(3,)).owner - new_inputs = node.inputs.copy() - new_inputs[1] = (2, 3) # size - new_inputs[3] = mat # mu - vect_node = vectorize_node(node, *new_inputs) - assert vect_node.op is normal - assert tuple(vect_node.inputs[1].eval()) == (2, 3) - assert vect_node.inputs[3] is mat + size = pt.as_tensor(np.array((3,), dtype="int64")) + out = normal(vec, size=size) + vect_node = vectorize_graph(out, {vec: mat, size: (2, 3)}).owner + assert isinstance(vect_node.op, NormalRV) + assert tuple(vect_node.op.size_param(vect_node).eval()) == (2, 3) + assert vect_node.op.dist_params(vect_node)[0] is mat # Test with size, new size not provided - node = normal(vec, size=(3,)).owner - new_inputs = node.inputs.copy() - new_inputs[3] = mat # mu - vect_node = vectorize_node(node, *new_inputs) - assert vect_node.op is normal - assert vect_node.inputs[3] is mat + out = normal(vec, size=(3,)) + vect_node = vectorize_graph(out, {vec: mat}).owner + assert isinstance(vect_node.op, NormalRV) + assert vect_node.op.dist_params(vect_node)[0] is mat assert tuple( - vect_node.inputs[1].eval({mat: np.zeros((2, 3), dtype=config.floatX)}) + vect_node.op.size_param(vect_node).eval( + {mat: np.zeros((2, 3), dtype=config.floatX)} + ) ) == (2, 3) # Test parameter broadcasting - node = normal(vec).owner - new_inputs = node.inputs.copy() - new_inputs[3] = tensor("mu", shape=(10, 5)) # mu - new_inputs[4] = tensor("sigma", shape=(10,)) # sigma - vect_node = vectorize_node(node, *new_inputs) - assert vect_node.op is normal + mu = vec + sigma = pt.as_tensor(np.array(1.0)) + out = normal(mu, sigma) + new_mu = tensor("mu", shape=(10, 5)) + new_sigma = tensor("sigma", shape=(10,)) + vect_node = vectorize_graph(out, {mu: new_mu, sigma: new_sigma}).owner + assert isinstance(vect_node.op, NormalRV) assert vect_node.default_output().type.shape == (10, 5) # Test parameter broadcasting with non-expanding size - node = normal(vec, size=(5,)).owner - new_inputs = node.inputs.copy() - new_inputs[3] = tensor("mu", shape=(10, 5)) # mu - new_inputs[4] = tensor("sigma", shape=(10,)) # sigma - vect_node = vectorize_node(node, *new_inputs) - assert vect_node.op is normal + mu = vec + sigma = pt.as_tensor(np.array(1.0)) + out = normal(mu, sigma, size=(5,)) + new_mu = tensor("mu", shape=(10, 5)) + new_sigma = tensor("sigma", shape=(10,)) + vect_node = vectorize_graph(out, {mu: new_mu, sigma: new_sigma}).owner + assert isinstance(vect_node.op, NormalRV) assert vect_node.default_output().type.shape == (10, 5) - node = normal(vec, size=(5,)).owner - new_inputs = node.inputs.copy() - new_inputs[3] = tensor("mu", shape=(1, 5)) # mu - new_inputs[4] = tensor("sigma", shape=(10,)) # sigma - vect_node = vectorize_node(node, *new_inputs) - assert vect_node.op is normal + mu = vec + sigma = pt.as_tensor(np.array(1.0)) + out = normal(mu, sigma, size=(5,)) + new_mu = tensor("mu", shape=(1, 5)) # mu + new_sigma = tensor("sigma", shape=(10,)) # sigma + vect_node = vectorize_graph(out, {mu: new_mu, sigma: new_sigma}).owner + assert isinstance(vect_node.op, NormalRV) assert 
vect_node.default_output().type.shape == (10, 5) # Test parameter broadcasting with expanding size - node = normal(vec, size=(2, 5)).owner - new_inputs = node.inputs.copy() - new_inputs[3] = tensor("mu", shape=(10, 5)) # mu - new_inputs[4] = tensor("sigma", shape=(10,)) # sigma - vect_node = vectorize_node(node, *new_inputs) - assert vect_node.op is normal + mu = vec + sigma = pt.as_tensor(np.array(1.0)) + out = normal(mu, sigma, size=(2, 5)) + new_mu = tensor("mu", shape=(1, 5)) + new_sigma = tensor("sigma", shape=(10,)) + vect_node = vectorize_graph(out, {mu: new_mu, sigma: new_sigma}).owner + assert isinstance(vect_node.op, NormalRV) assert vect_node.default_output().type.shape == (10, 2, 5) From 6ce8468a1e54b564edbd0a01c2f791c3752aa6ae Mon Sep 17 00:00:00 2001 From: Ricardo Vieira Date: Thu, 9 May 2024 21:38:46 +0200 Subject: [PATCH 11/15] Distinguish between size=None and size=() in RandomVariables --- pytensor/link/jax/dispatch/random.py | 14 ++--- pytensor/link/numba/dispatch/random.py | 49 ++++++++++------ pytensor/tensor/random/basic.py | 17 +++--- pytensor/tensor/random/op.py | 44 +++++--------- pytensor/tensor/random/rewriting/basic.py | 47 ++++++++------- pytensor/tensor/random/rewriting/jax.py | 4 +- pytensor/tensor/random/utils.py | 34 ++++++----- tests/tensor/random/rewriting/test_basic.py | 64 +++++++++++---------- tests/tensor/random/test_basic.py | 27 ++++----- tests/tensor/random/test_op.py | 15 ++++- 10 files changed, 163 insertions(+), 152 deletions(-) diff --git a/pytensor/link/jax/dispatch/random.py b/pytensor/link/jax/dispatch/random.py index dc22a07bfe..17094aec53 100644 --- a/pytensor/link/jax/dispatch/random.py +++ b/pytensor/link/jax/dispatch/random.py @@ -12,6 +12,7 @@ from pytensor.link.jax.dispatch.basic import jax_funcify, jax_typify from pytensor.link.jax.dispatch.shape import JAXShapeTuple from pytensor.tensor.shape import Shape, Shape_i +from pytensor.tensor.type_other import NoneTypeT try: @@ -93,7 +94,6 @@ def jax_funcify_RandomVariable(op: ptr.RandomVariable, node, **kwargs): rv = node.outputs[1] out_dtype = rv.type.dtype static_shape = rv.type.shape - batch_ndim = op.batch_ndim(node) # Try to pass static size directly to JAX @@ -102,11 +102,10 @@ def jax_funcify_RandomVariable(op: ptr.RandomVariable, node, **kwargs): # Sometimes size can be constant folded during rewrites, # without the RandomVariable node being updated with new static types size_param = op.size_param(node) - if isinstance(size_param, Constant): - size_tuple = tuple(size_param.data) - # PyTensor uses empty size to represent size = None - if len(size_tuple): - static_size = tuple(size_param.data) + if isinstance(size_param, Constant) and not isinstance( + size_param.type, NoneTypeT + ): + static_size = tuple(size_param.data) # If one dimension has unknown size, either the size is determined # by a `Shape` operator in which case JAX will compile, or it is @@ -115,9 +114,6 @@ def jax_funcify_RandomVariable(op: ptr.RandomVariable, node, **kwargs): assert_size_argument_jax_compatible(node) def sample_fn(rng, size, *parameters): - # PyTensor uses empty size to represent size = None - if jax.numpy.asarray(size).shape == (0,): - size = None return jax_sample_fn(op, node=node)(rng, size, out_dtype, *parameters) else: diff --git a/pytensor/link/numba/dispatch/random.py b/pytensor/link/numba/dispatch/random.py index e2bc51da7e..f35d1b3fcd 100644 --- a/pytensor/link/numba/dispatch/random.py +++ b/pytensor/link/numba/dispatch/random.py @@ -21,6 +21,7 @@ ) from pytensor.tensor.basic import 
get_vector_length from pytensor.tensor.random.type import RandomStateType +from pytensor.tensor.type_other import NoneTypeT class RandomStateNumbaType(types.Type): @@ -101,9 +102,13 @@ def make_numba_random_fn(node, np_random_func): if not isinstance(rng_param.type, RandomStateType): raise TypeError("Numba does not support NumPy `Generator`s") - tuple_size = int(get_vector_length(op.size_param(node))) + size_param = op.size_param(node) + size_len = ( + None + if isinstance(size_param.type, NoneTypeT) + else int(get_vector_length(size_param)) + ) dist_params = op.dist_params(node) - size_dims = tuple_size - max(i.ndim for i in dist_params) # Make a broadcast-capable version of the Numba supported scalar sampling # function @@ -119,7 +124,7 @@ def make_numba_random_fn(node, np_random_func): "np_random_func", "numba_vectorize", "to_fixed_tuple", - "tuple_size", + "size_len", "size_dims", "rng", "size", @@ -155,10 +160,12 @@ def {bcast_fn_name}({bcast_fn_input_names}): "out_dtype": out_dtype, } - if tuple_size > 0: + if size_len is not None: + size_dims = size_len - max(i.ndim for i in dist_params) + random_fn_body = dedent( f""" - size = to_fixed_tuple(size, tuple_size) + size = to_fixed_tuple(size, size_len) data = np.empty(size, dtype=out_dtype) for i in np.ndindex(size[:size_dims]): @@ -170,7 +177,7 @@ def {bcast_fn_name}({bcast_fn_input_names}): { "np": np, "to_fixed_tuple": numba_ndarray.to_fixed_tuple, - "tuple_size": tuple_size, + "size_len": size_len, "size_dims": size_dims, } ) @@ -305,19 +312,24 @@ def body_fn(a): @numba_funcify.register(ptr.CategoricalRV) def numba_funcify_CategoricalRV(op: ptr.CategoricalRV, node, **kwargs): out_dtype = node.outputs[1].type.numpy_dtype - size_len = int(get_vector_length(op.size_param(node))) + size_param = op.size_param(node) + size_len = ( + None + if isinstance(size_param.type, NoneTypeT) + else int(get_vector_length(size_param)) + ) p_ndim = node.inputs[-1].ndim @numba_basic.numba_njit def categorical_rv(rng, size, p): - if not size_len: + if size_len is None: size_tpl = p.shape[:-1] else: size_tpl = numba_ndarray.to_fixed_tuple(size, size_len) p = np.broadcast_to(p, size_tpl + p.shape[-1:]) # Workaround https://github.com/numba/numba/issues/8975 - if not size_len and p_ndim == 1: + if size_len is None and p_ndim == 1: unif_samples = np.asarray(np.random.uniform(0, 1)) else: unif_samples = np.random.uniform(0, 1, size_tpl) @@ -336,13 +348,20 @@ def numba_funcify_DirichletRV(op, node, **kwargs): out_dtype = node.outputs[1].type.numpy_dtype alphas_ndim = op.dist_params(node)[0].type.ndim neg_ind_shape_len = -alphas_ndim + 1 - size_len = int(get_vector_length(op.size_param(node))) + size_param = op.size_param(node) + size_len = ( + None + if isinstance(size_param.type, NoneTypeT) + else int(get_vector_length(size_param)) + ) if alphas_ndim > 1: @numba_basic.numba_njit def dirichlet_rv(rng, size, alphas): - if size_len > 0: + if size_len is None: + samples_shape = alphas.shape + else: size_tpl = numba_ndarray.to_fixed_tuple(size, size_len) if ( 0 < alphas.ndim - 1 <= len(size_tpl) @@ -350,8 +369,6 @@ def dirichlet_rv(rng, size, alphas): ): raise ValueError("Parameters shape and size do not match.") samples_shape = size_tpl + alphas.shape[-1:] - else: - samples_shape = alphas.shape res = np.empty(samples_shape, dtype=out_dtype) alphas_bcast = np.broadcast_to(alphas, samples_shape) @@ -365,7 +382,8 @@ def dirichlet_rv(rng, size, alphas): @numba_basic.numba_njit def dirichlet_rv(rng, size, alphas): - size = numba_ndarray.to_fixed_tuple(size, size_len) 
+ if size_len is not None: + size = numba_ndarray.to_fixed_tuple(size, size_len) return (rng, np.random.dirichlet(alphas, size)) return dirichlet_rv @@ -404,8 +422,7 @@ def choice_without_replacement_rv(rng, size, a, core_shape): @numba_funcify.register(ptr.PermutationRV) def numba_funcify_permutation(op: ptr.PermutationRV, node, **kwargs): - # PyTensor uses size=() to represent size=None - size_is_none = op.size_param(node).type.shape == (0,) + size_is_none = isinstance(op.size_param(node).type, NoneTypeT) batch_ndim = op.batch_ndim(node) x_batch_ndim = node.inputs[-1].type.ndim - op.ndims_params[0] diff --git a/pytensor/tensor/random/basic.py b/pytensor/tensor/random/basic.py index 3290f22510..5c5665ef1f 100644 --- a/pytensor/tensor/random/basic.py +++ b/pytensor/tensor/random/basic.py @@ -914,12 +914,11 @@ def rng_fn(cls, rng, mean, cov, size): # multivariate normals (or any other multivariate distributions), # so we need to implement that here - size = tuple(size or ()) - if size: + if size is None: + mean, cov = broadcast_params([mean, cov], [1, 2]) + else: mean = np.broadcast_to(mean, size + mean.shape[-1:]) cov = np.broadcast_to(cov, size + cov.shape[-2:]) - else: - mean, cov = broadcast_params([mean, cov], [1, 2]) res = np.empty(mean.shape) for idx in np.ndindex(mean.shape[:-1]): @@ -1800,13 +1799,11 @@ def __call__(self, n, p, size=None, **kwargs): @classmethod def rng_fn(cls, rng, n, p, size): if n.ndim > 0 or p.ndim > 1: - size = tuple(size or ()) - - if size: + if size is None: + n, p = broadcast_params([n, p], [0, 1]) + else: n = np.broadcast_to(n, size) p = np.broadcast_to(p, size + p.shape[-1:]) - else: - n, p = broadcast_params([n, p], [0, 1]) res = np.empty(p.shape, dtype=cls.dtype) for idx in np.ndindex(p.shape[:-1]): @@ -2155,7 +2152,7 @@ def _supp_shape_from_params(self, dist_params, param_shapes=None): def rng_fn(self, rng, x, size): # We don't have access to the node in rng_fn :( x_batch_ndim = x.ndim - self.ndims_params[0] - batch_ndim = max(x_batch_ndim, len(size or ())) + batch_ndim = max(x_batch_ndim, 0 if size is None else len(size)) if batch_ndim: # rng.permutation has no concept of batch dims diff --git a/pytensor/tensor/random/op.py b/pytensor/tensor/random/op.py index 710981ed2e..7b73c9de03 100644 --- a/pytensor/tensor/random/op.py +++ b/pytensor/tensor/random/op.py @@ -16,7 +16,6 @@ as_tensor_variable, concatenate, constant, - get_underlying_scalar_constant_value, get_vector_length, infer_static_shape, ) @@ -28,7 +27,7 @@ ) from pytensor.tensor.shape import shape_tuple from pytensor.tensor.type import TensorType -from pytensor.tensor.type_other import NoneConst +from pytensor.tensor.type_other import NoneConst, NoneTypeT from pytensor.tensor.utils import _parse_gufunc_signature, safe_signature from pytensor.tensor.variable import TensorVariable @@ -196,10 +195,10 @@ def __str__(self): def _infer_shape( self, - size: TensorVariable, + size: TensorVariable | Variable, dist_params: Sequence[TensorVariable], param_shapes: Sequence[tuple[Variable, ...]] | None = None, - ) -> TensorVariable | tuple[ScalarVariable, ...]: + ) -> tuple[ScalarVariable | TensorVariable, ...]: """Compute the output shape given the size and distribution parameters. 
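The `rng_fn` changes above share one convention: `size=None` means the batch shape is derived by broadcasting the parameters against each other, while an explicit `size` is used as the batch shape directly. A plain-NumPy sketch of that pattern, illustrative only, with `np.broadcast_shapes` standing in for the in-tree `broadcast_params` helper and a made-up function name:

    import numpy as np

    def batched_mvnormal_rng_fn(rng, mean, cov, size=None):
        if size is None:
            # Batch shape implied by the parameters: mean is (..., d), cov is (..., d, d)
            batch_shape = np.broadcast_shapes(mean.shape[:-1], cov.shape[:-2])
        else:
            batch_shape = tuple(size)
        mean = np.broadcast_to(mean, batch_shape + mean.shape[-1:])
        cov = np.broadcast_to(cov, batch_shape + cov.shape[-2:])
        out = np.empty(mean.shape)
        for idx in np.ndindex(batch_shape):
            out[idx] = rng.multivariate_normal(mean[idx], cov[idx])
        return out

    draws = batched_mvnormal_rng_fn(
        np.random.default_rng(0), mean=np.zeros((4, 2)), cov=np.eye(2), size=None
    )
    assert draws.shape == (4, 2)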
Parameters @@ -225,9 +224,9 @@ def _infer_shape( self._supp_shape_from_params(dist_params, param_shapes=param_shapes) ) - size_len = get_vector_length(size) + if not isinstance(size.type, NoneTypeT): + size_len = get_vector_length(size) - if size_len > 0: # Fail early when size is incompatible with parameters for i, (param, param_ndim_supp) in enumerate( zip(dist_params, self.ndims_params) @@ -281,21 +280,11 @@ def extract_batch_shape(p, ps, n): shape = batch_shape + supp_shape - if not shape: - shape = constant([], dtype="int64") - return shape def infer_shape(self, fgraph, node, input_shapes): _, size, *dist_params = node.inputs - _, size_shape, *param_shapes = input_shapes - - try: - size_len = get_vector_length(size) - except ValueError: - size_len = get_underlying_scalar_constant_value(size_shape[0]) - - size = tuple(size[n] for n in range(size_len)) + _, _, *param_shapes = input_shapes shape = self._infer_shape(size, dist_params, param_shapes=param_shapes) @@ -367,8 +356,8 @@ def make_node(self, rng, size, *dist_params): "The type of rng should be an instance of either RandomGeneratorType or RandomStateType" ) - shape = self._infer_shape(size, dist_params) - _, static_shape = infer_static_shape(shape) + inferred_shape = self._infer_shape(size, dist_params) + _, static_shape = infer_static_shape(inferred_shape) inputs = (rng, size, *dist_params) out_type = TensorType(dtype=self.dtype, shape=static_shape) @@ -396,21 +385,14 @@ def perform(self, node, inputs, outputs): rng, size, *args = inputs - # If `size == []`, that means no size is enforced, and NumPy is trusted - # to draw the appropriate number of samples, NumPy uses `size=None` to - # represent that. Otherwise, NumPy expects a tuple. - if np.size(size) == 0: - size = None - else: - size = tuple(size) - - # Draw from `rng` if `self.inplace` is `True`, and from a copy of `rng` - # otherwise. + # Draw from `rng` if `self.inplace` is `True`, and from a copy of `rng` otherwise. 
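In user-facing terms, `size=None` is now kept in the graph as a `NoneConst` input (a `NoneTypeT` constant) and reaches `perform`/`rng_fn` as a plain Python `None`, while an explicit size stays an int64 vector that `perform` converts to a tuple. A small illustrative sketch, not part of the patch:

    import pytensor.tensor as pt
    from pytensor.tensor.random import normal
    from pytensor.tensor.type_other import NoneTypeT

    x = normal(0, 1)  # size=None is represented by a NoneConst input
    assert isinstance(x.owner.op.size_param(x.owner).type, NoneTypeT)

    y = normal(pt.zeros((3,)), 1, size=(2, 3))  # explicit size stays a length-2 int64 vector
    assert y.owner.op.size_param(y.owner).type.ndim == 1
    assert y.eval().shape == (2, 3)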
if not self.inplace: rng = copy(rng) rng_var_out[0] = rng + if size is not None: + size = tuple(size) smpl_val = self.rng_fn(rng, *([*args, size])) if not isinstance(smpl_val, np.ndarray) or str(smpl_val.dtype) != self.dtype: @@ -473,7 +455,9 @@ def vectorize_random_variable( original_dist_params = op.dist_params(node) old_size = op.size_param(node) - len_old_size = get_vector_length(old_size) + len_old_size = ( + None if isinstance(old_size.type, NoneTypeT) else get_vector_length(old_size) + ) original_expanded_dist_params = explicit_expand_dims( original_dist_params, op.ndims_params, len_old_size diff --git a/pytensor/tensor/random/rewriting/basic.py b/pytensor/tensor/random/rewriting/basic.py index 0da065835b..b1960927e6 100644 --- a/pytensor/tensor/random/rewriting/basic.py +++ b/pytensor/tensor/random/rewriting/basic.py @@ -7,7 +7,7 @@ from pytensor.graph.rewriting.basic import copy_stack_trace, in2out, node_rewriter from pytensor.scalar import integer_types from pytensor.tensor import NoneConst -from pytensor.tensor.basic import constant, get_vector_length +from pytensor.tensor.basic import constant from pytensor.tensor.elemwise import DimShuffle from pytensor.tensor.extra_ops import broadcast_to from pytensor.tensor.random.op import RandomVariable @@ -20,7 +20,7 @@ as_index_variable, get_idx_list, ) -from pytensor.tensor.type_other import SliceType +from pytensor.tensor.type_other import NoneTypeT, SliceType def is_rv_used_in_graph(base_rv, node, fgraph): @@ -83,27 +83,27 @@ def local_rv_size_lift(fgraph, node): rng, size, *dist_params = node.inputs + if isinstance(size.type, NoneTypeT): + return + dist_params = broadcast_params(dist_params, node.op.ndims_params) - if get_vector_length(size) > 0: - dist_params = [ - broadcast_to( - p, - ( - tuple(size) - + ( - tuple(p.shape)[-node.op.ndims_params[i] :] - if node.op.ndims_params[i] > 0 - else () - ) + dist_params = [ + broadcast_to( + p, + ( + tuple(size) + + ( + tuple(p.shape)[-node.op.ndims_params[i] :] + if node.op.ndims_params[i] > 0 + else () ) - if node.op.ndim_supp > 0 - else size, ) - for i, p in enumerate(dist_params) - ] - else: - return + if node.op.ndim_supp > 0 + else size, + ) + for i, p in enumerate(dist_params) + ] new_node = node.op.make_node(rng, None, *dist_params) @@ -159,11 +159,10 @@ def local_dimshuffle_rv_lift(fgraph, node): batched_dims = rv.ndim - rv_op.ndim_supp batched_dims_ds_order = tuple(o for o in ds_op.new_order if o not in supp_dims) - # Make size explicit - missing_size_dims = batched_dims - get_vector_length(size) - if missing_size_dims > 0: - full_size = tuple(broadcast_params(dist_params, rv_op.ndims_params)[0].shape) - size = full_size[:missing_size_dims] + tuple(size) + if isinstance(size.type, NoneTypeT): + # Make size explicit + shape = tuple(broadcast_params(dist_params, rv_op.ndims_params)[0].shape) + size = shape[:batched_dims] # Update the size to reflect the DimShuffled dimensions new_size = [ diff --git a/pytensor/tensor/random/rewriting/jax.py b/pytensor/tensor/random/rewriting/jax.py index d86acfbd56..21ddb32af9 100644 --- a/pytensor/tensor/random/rewriting/jax.py +++ b/pytensor/tensor/random/rewriting/jax.py @@ -158,7 +158,7 @@ def materialize_implicit_arange_choice_without_replacement(fgraph, node): # No need to materialize arange return None - rng, size, dtype, a_scalar_param, *other_params = node.inputs + rng, size, a_scalar_param, *other_params = node.inputs if a_scalar_param.type.ndim > 0: # Automatic vectorization could have made this parameter batched, # there is no nice way 
to materialize a batched arange @@ -170,7 +170,7 @@ def materialize_implicit_arange_choice_without_replacement(fgraph, node): # I.e., we substitute the first `()` by `(a)` new_props_dict["signature"] = re.sub(r"\(\)", "(a)", op.signature, 1) new_op = type(op)(**new_props_dict) - return new_op.make_node(rng, size, dtype, a_vector_param, *other_params).outputs + return new_op.make_node(rng, size, a_vector_param, *other_params).outputs random_vars_opt = SequenceDB() diff --git a/pytensor/tensor/random/utils.py b/pytensor/tensor/random/utils.py index 51fbf7e120..38329bbae7 100644 --- a/pytensor/tensor/random/utils.py +++ b/pytensor/tensor/random/utils.py @@ -9,8 +9,8 @@ from pytensor.compile.sharedvalue import shared from pytensor.graph.basic import Constant, Variable from pytensor.scalar import ScalarVariable -from pytensor.tensor import get_vector_length -from pytensor.tensor.basic import as_tensor_variable, cast, constant +from pytensor.tensor import NoneConst, get_vector_length +from pytensor.tensor.basic import as_tensor_variable, cast from pytensor.tensor.extra_ops import broadcast_arrays, broadcast_to from pytensor.tensor.math import maximum from pytensor.tensor.shape import shape_padleft, specify_shape @@ -124,7 +124,7 @@ def broadcast_params(params, ndims_params): def explicit_expand_dims( params: Sequence[TensorVariable], ndim_params: Sequence[int], - size_length: int = 0, + size_length: int | None = None, ) -> list[TensorVariable]: """Introduce explicit expand_dims in RV parameters that are implicitly broadcasted together and/or by size.""" @@ -132,9 +132,7 @@ def explicit_expand_dims( param.type.ndim - ndim_param for param, ndim_param in zip(params, ndim_params) ] - if size_length: - # NOTE: PyTensor is currently treating zero-length size as size=None, which is not what Numpy does - # See: https://github.com/pymc-devs/pytensor/issues/568 + if size_length is not None: max_batch_dims = size_length else: max_batch_dims = max(batch_dims, default=0) @@ -159,30 +157,30 @@ def compute_batch_shape(params, ndims_params: Sequence[int]) -> TensorVariable: def normalize_size_param( - size: int | np.ndarray | Variable | Sequence | None, + shape: int | np.ndarray | Variable | Sequence | None, ) -> Variable: """Create an PyTensor value for a ``RandomVariable`` ``size`` parameter.""" - if size is None: - size = constant([], dtype="int64") - elif isinstance(size, int): - size = as_tensor_variable([size], ndim=1) - elif not isinstance(size, np.ndarray | Variable | Sequence): + if shape is None or NoneConst.equals(shape): + return NoneConst + elif isinstance(shape, int): + shape = as_tensor_variable([shape], ndim=1) + elif not isinstance(shape, np.ndarray | Variable | Sequence): raise TypeError( "Parameter size must be None, an integer, or a sequence with integers." ) else: - size = cast(as_tensor_variable(size, ndim=1, dtype="int64"), "int64") + shape = cast(as_tensor_variable(shape, ndim=1, dtype="int64"), "int64") - if not isinstance(size, Constant): + if not isinstance(shape, Constant): # This should help ensure that the length of non-constant `size`s # will be available after certain types of cloning (e.g. 
the kind # `Scan` performs) - size = specify_shape(size, (get_vector_length(size),)) + shape = specify_shape(shape, (get_vector_length(shape),)) - assert not any(s is None for s in size.type.shape) - assert size.dtype in int_dtypes + assert not any(s is None for s in shape.type.shape) + assert shape.dtype in int_dtypes - return size + return shape class RandomStream: diff --git a/tests/tensor/random/rewriting/test_basic.py b/tests/tensor/random/rewriting/test_basic.py index f8f70adc10..05f9fa17a4 100644 --- a/tests/tensor/random/rewriting/test_basic.py +++ b/tests/tensor/random/rewriting/test_basic.py @@ -30,6 +30,7 @@ from pytensor.tensor.rewriting.shape import ShapeFeature, ShapeOptimizer from pytensor.tensor.subtensor import AdvancedSubtensor, AdvancedSubtensor1, Subtensor from pytensor.tensor.type import iscalar, vector +from pytensor.tensor.type_other import NoneConst no_mode = Mode("py", RewriteDatabaseQuery(include=[], exclude=[])) @@ -44,20 +45,25 @@ def apply_local_rewrite_to_rv( p_pt.tag.test_value = p dist_params_pt.append(p_pt) - size_pt = [] - for s in size: - # To test DimShuffle with dropping dims we need that size dimension to be constant - if s == 1: - s_pt = constant(np.array(1, dtype="int32")) - else: - s_pt = iscalar() - s_pt.tag.test_value = s - size_pt.append(s_pt) + if size is None: + size_pt = NoneConst + else: + size_pt = [] + for s in size: + # To test DimShuffle with dropping dims we need that size dimension to be constant + if s == 1: + s_pt = constant(np.array(1, dtype="int32")) + else: + s_pt = iscalar() + s_pt.tag.test_value = s + size_pt.append(s_pt) dist_st = op_fn(dist_op(*dist_params_pt, size=size_pt, rng=rng, name=name)) f_inputs = [ - p for p in dist_params_pt + size_pt if not isinstance(p, slice | Constant) + p + for p in dist_params_pt + ([] if size is None else size_pt) + if not isinstance(p, slice | Constant) ] mode = Mode( @@ -135,7 +141,7 @@ def test_inplace_rewrites(rv_op): np.array([0.0, 1.0], dtype=config.floatX), np.array(5.0, dtype=config.floatX), ], - [], + None, ), ( normal, @@ -180,7 +186,7 @@ def test_local_rv_size_lift(dist_op, dist_params, size): rng, ) - assert pt.get_vector_length(new_out.owner.inputs[1]) == 0 + assert new_out.owner.op.size_param(new_out.owner).data is None @pytest.mark.parametrize( @@ -194,7 +200,7 @@ def test_local_rv_size_lift(dist_op, dist_params, size): np.array([0.0, -100.0], dtype=np.float64), np.array(1e-6, dtype=np.float64), ), - (), + None, 1e-7, ), ( @@ -205,7 +211,7 @@ def test_local_rv_size_lift(dist_op, dist_params, size): np.array(-10.0, dtype=np.float64), np.array(1e-6, dtype=np.float64), ), - (), + None, 1e-7, ), ( @@ -216,7 +222,7 @@ def test_local_rv_size_lift(dist_op, dist_params, size): np.array(-10.0, dtype=np.float64), np.array(1e-6, dtype=np.float64), ), - (), + None, 1e-7, ), ( @@ -227,7 +233,7 @@ def test_local_rv_size_lift(dist_op, dist_params, size): np.arange(2 * 2 * 2).reshape((2, 2, 2)).astype(config.floatX), np.array(1e-6).astype(config.floatX), ), - (), + None, 1e-3, ), ( @@ -440,7 +446,7 @@ def rand_bool_mask(shape, rng=None): np.arange(30, dtype=config.floatX).reshape(3, 5, 2), np.full((1, 5, 1), 1e-6), ), - (), + None, ), ( # `size`-only slice @@ -462,7 +468,7 @@ def rand_bool_mask(shape, rng=None): np.arange(30, dtype=config.floatX).reshape(3, 5, 2), np.full((1, 5, 1), 1e-6), ), - (), + None, ), ( # `size`-only slice @@ -484,7 +490,7 @@ def rand_bool_mask(shape, rng=None): (0.1 - 1e-5) * np.arange(4).astype(dtype=config.floatX), 0.1 * np.arange(4).astype(dtype=config.floatX), ), - 
(), + None, ), # 5 ( @@ -570,7 +576,7 @@ def rand_bool_mask(shape, rng=None): dtype=config.floatX, ), ), - (), + None, ), ( # Univariate distribution with core-vector parameters @@ -627,7 +633,7 @@ def rand_bool_mask(shape, rng=None): np.arange(30).reshape(5, 3, 2), 1e-6, ), - (), + None, ), ( # Multidimensional boolean indexing @@ -638,7 +644,7 @@ def rand_bool_mask(shape, rng=None): np.arange(30).reshape(5, 3, 2), 1e-6, ), - (), + None, ), ( # Multidimensional boolean indexing @@ -649,7 +655,7 @@ def rand_bool_mask(shape, rng=None): np.arange(30).reshape(5, 3, 2), 1e-6, ), - (), + None, ), # 20 ( @@ -661,7 +667,7 @@ def rand_bool_mask(shape, rng=None): np.arange(30).reshape(5, 3, 2), 1e-6, ), - (), + None, ), ( # Multidimensional boolean indexing @@ -687,7 +693,7 @@ def rand_bool_mask(shape, rng=None): np.arange(30).reshape(5, 3, 2), 1e-6, ), - (), + None, ), ( # Multidimensional boolean indexing, @@ -703,7 +709,7 @@ def rand_bool_mask(shape, rng=None): np.arange(30).reshape(5, 3, 2), 1e-6, ), - (), + None, ), ( # Multivariate distribution: indexing dips into core dimension @@ -714,7 +720,7 @@ def rand_bool_mask(shape, rng=None): np.array([[-1, 20], [300, -4000]], dtype=config.floatX), np.eye(2).astype(config.floatX) * 1e-6, ), - (), + None, ), # 25 ( @@ -726,7 +732,7 @@ def rand_bool_mask(shape, rng=None): np.array([[-1, 20], [300, -4000]], dtype=config.floatX), np.eye(2).astype(config.floatX) * 1e-6, ), - (), + None, ), ( # Multivariate distribution: advanced integer indexing @@ -740,7 +746,7 @@ def rand_bool_mask(shape, rng=None): ), np.eye(3, dtype=config.floatX) * 1e-6, ), - (), + None, ), ( # Multivariate distribution: dummy slice "dips" into core dimension diff --git a/tests/tensor/random/test_basic.py b/tests/tensor/random/test_basic.py index e3083c98ec..84385e5bc3 100644 --- a/tests/tensor/random/test_basic.py +++ b/tests/tensor/random/test_basic.py @@ -212,7 +212,7 @@ def test_beta_samples(a, b, size): @pytest.mark.parametrize( "M, sd, size", [ - (pt.as_tensor_variable(np.array(1.0, dtype=config.floatX)), sd_pt, ()), + (pt.as_tensor_variable(np.array(1.0, dtype=config.floatX)), sd_pt, None), ( pt.as_tensor_variable(np.array(1.0, dtype=config.floatX)), sd_pt, @@ -223,10 +223,10 @@ def test_beta_samples(a, b, size): sd_pt, (2, M_pt), ), - (pt.zeros((M_pt,)), sd_pt, ()), + (pt.zeros((M_pt,)), sd_pt, None), (pt.zeros((M_pt,)), sd_pt, (M_pt,)), (pt.zeros((M_pt,)), sd_pt, (2, M_pt)), - (pt.zeros((M_pt,)), pt.ones((M_pt,)), ()), + (pt.zeros((M_pt,)), pt.ones((M_pt,)), None), (pt.zeros((M_pt,)), pt.ones((M_pt,)), (2, M_pt)), ( create_pytensor_param( @@ -244,9 +244,10 @@ def test_beta_samples(a, b, size): ) def test_normal_infer_shape(M, sd, size): rv = normal(M, sd, size=size) - rv_shape = list(normal._infer_shape(size or (), [M, sd], None)) + size_pt = rv.owner.op.size_param(rv.owner) + rv_shape = list(normal._infer_shape(size_pt, [M, sd], None)) - all_args = (M, sd, *size) + all_args = (M, sd, *(() if size is None else size)) fn_inputs = [ i for i in graph_inputs([a for a in all_args if isinstance(a, Variable)]) @@ -525,8 +526,8 @@ def mvnormal_test_fn(mean=None, cov=None, size=None, random_state=None): mean = np.array([0.0], dtype=config.floatX) if cov is None: cov = np.array([[1.0]], dtype=config.floatX) - if size is None: - size = () + if size is not None: + size = tuple(size) return multivariate_normal.rng_fn(random_state, mean, cov, size) @@ -713,19 +714,20 @@ def test_dirichlet_rng(): @pytest.mark.parametrize( "M, size", [ - (pt.ones((M_pt,)), ()), + (pt.ones((M_pt,)), None), 
(pt.ones((M_pt,)), (M_pt + 1,)), (pt.ones((M_pt,)), (2, M_pt)), - (pt.ones((M_pt, M_pt + 1)), ()), + (pt.ones((M_pt, M_pt + 1)), None), (pt.ones((M_pt, M_pt + 1)), (M_pt + 2, M_pt)), (pt.ones((M_pt, M_pt + 1)), (2, M_pt + 2, M_pt + 3, M_pt)), ], ) def test_dirichlet_infer_shape(M, size): rv = dirichlet(M, size=size) - rv_shape = list(dirichlet._infer_shape(size or (), [M], None)) + size_pt = rv.owner.op.size_param(rv.owner) + rv_shape = list(dirichlet._infer_shape(size_pt, [M], None)) - all_args = (M, *size) + all_args = (M, *(() if size is None else size)) fn_inputs = [ i for i in graph_inputs([a for a in all_args if isinstance(a, Variable)]) @@ -1620,8 +1622,7 @@ def test_unnatural_batched_dims(batch_dims_tester): @config.change_flags(compute_test_value="off") def test_pickle(): - # This is an interesting `Op` case, because it has `None` types and a - # conditional dtype + # This is an interesting `Op` case, because it has a conditional dtype sample_a = choice(5, replace=False, size=(2, 3)) a_pkl = pickle.dumps(sample_a) diff --git a/tests/tensor/random/test_op.py b/tests/tensor/random/test_op.py index bb0b5cbf03..35e5f49c28 100644 --- a/tests/tensor/random/test_op.py +++ b/tests/tensor/random/test_op.py @@ -69,7 +69,7 @@ def test_RandomVariable_basics(strict_test_value_flags): # `RandomVariable._infer_shape` should handle no parameters rv_shape = rv._infer_shape(pt.constant([]), (), []) - assert rv_shape.equals(pt.constant([], dtype="int64")) + assert rv_shape == () # `dtype` is respected rv = RandomVariable("normal", signature="(),()->()", dtype="int32") @@ -299,3 +299,16 @@ def test_vectorize(): vect_node = vectorize_graph(out, {mu: new_mu, sigma: new_sigma}).owner assert isinstance(vect_node.op, NormalRV) assert vect_node.default_output().type.shape == (10, 2, 5) + + +def test_size_none_vs_empty(): + rv = RandomVariable( + "normal", + signature="(),()->()", + ) + assert rv([0], [1], size=None).type.shape == (1,) + + with pytest.raises( + ValueError, match="Size length is incompatible with batched dimensions" + ): + rv([0], [1], size=()) From 7af1412e834bbeca0213ffd068b11f4c8d669dc9 Mon Sep 17 00:00:00 2001 From: Ricardo Vieira Date: Fri, 24 May 2024 12:46:25 +0200 Subject: [PATCH 12/15] Add explicit expand_dims when building RandomVariable nodes --- pytensor/link/jax/dispatch/random.py | 25 ++++--------- pytensor/link/numba/dispatch/random.py | 6 --- pytensor/tensor/random/basic.py | 41 +++++++++++---------- pytensor/tensor/random/op.py | 28 +++++++------- pytensor/tensor/random/rewriting/jax.py | 10 ++++- tests/link/numba/test_random.py | 7 +++- tests/tensor/random/rewriting/test_basic.py | 12 +++++- 7 files changed, 66 insertions(+), 63 deletions(-) diff --git a/pytensor/link/jax/dispatch/random.py b/pytensor/link/jax/dispatch/random.py index 17094aec53..98b59d22b3 100644 --- a/pytensor/link/jax/dispatch/random.py +++ b/pytensor/link/jax/dispatch/random.py @@ -304,7 +304,6 @@ def jax_funcify_choice(op: ptr.ChoiceWithoutReplacement, node): """JAX implementation of `ChoiceRV`.""" batch_ndim = op.batch_ndim(node) - a, *p, core_shape = op.dist_params(node) a_core_ndim, *p_core_ndim, _ = op.ndims_params if batch_ndim and a_core_ndim == 0: @@ -313,12 +312,6 @@ def jax_funcify_choice(op: ptr.ChoiceWithoutReplacement, node): "A default JAX rewrite should have materialized the implicit arange" ) - a_batch_ndim = a.type.ndim - a_core_ndim - if op.has_p_param: - [p] = p - [p_core_ndim] = p_core_ndim - p_batch_ndim = p.type.ndim - p_core_ndim - def sample_fn(rng, size, dtype, *parameters): 
rng_key = rng["jax_state"] rng_key, sampling_key = jax.random.split(rng_key, 2) @@ -328,7 +321,7 @@ def sample_fn(rng, size, dtype, *parameters): else: a, core_shape = parameters p = None - core_shape = tuple(np.asarray(core_shape)) + core_shape = tuple(np.asarray(core_shape)[(0,) * batch_ndim]) if batch_ndim == 0: sample = jax.random.choice( @@ -338,16 +331,16 @@ def sample_fn(rng, size, dtype, *parameters): else: if size is None: if p is None: - size = a.shape[:a_batch_ndim] + size = a.shape[:batch_ndim] else: size = jax.numpy.broadcast_shapes( - a.shape[:a_batch_ndim], - p.shape[:p_batch_ndim], + a.shape[:batch_ndim], + p.shape[:batch_ndim], ) - a = jax.numpy.broadcast_to(a, size + a.shape[a_batch_ndim:]) + a = jax.numpy.broadcast_to(a, size + a.shape[batch_ndim:]) if p is not None: - p = jax.numpy.broadcast_to(p, size + p.shape[p_batch_ndim:]) + p = jax.numpy.broadcast_to(p, size + p.shape[batch_ndim:]) batch_sampling_keys = jax.random.split(sampling_key, np.prod(size)) @@ -381,7 +374,6 @@ def jax_sample_fn_permutation(op, node): """JAX implementation of `PermutationRV`.""" batch_ndim = op.batch_ndim(node) - x_batch_ndim = node.inputs[-1].type.ndim - op.ndims_params[0] def sample_fn(rng, size, dtype, *parameters): rng_key = rng["jax_state"] @@ -389,11 +381,10 @@ def sample_fn(rng, size, dtype, *parameters): (x,) = parameters if batch_ndim: # jax.random.permutation has no concept of batch dims - x_core_shape = x.shape[x_batch_ndim:] if size is None: - size = x.shape[:x_batch_ndim] + size = x.shape[:batch_ndim] else: - x = jax.numpy.broadcast_to(x, size + x_core_shape) + x = jax.numpy.broadcast_to(x, size + x.shape[batch_ndim:]) batch_sampling_keys = jax.random.split(sampling_key, np.prod(size)) raveled_batch_x = x.reshape((-1,) + x.shape[batch_ndim:]) diff --git a/pytensor/link/numba/dispatch/random.py b/pytensor/link/numba/dispatch/random.py index f35d1b3fcd..0091be99ae 100644 --- a/pytensor/link/numba/dispatch/random.py +++ b/pytensor/link/numba/dispatch/random.py @@ -347,7 +347,6 @@ def categorical_rv(rng, size, p): def numba_funcify_DirichletRV(op, node, **kwargs): out_dtype = node.outputs[1].type.numpy_dtype alphas_ndim = op.dist_params(node)[0].type.ndim - neg_ind_shape_len = -alphas_ndim + 1 size_param = op.size_param(node) size_len = ( None @@ -363,11 +362,6 @@ def dirichlet_rv(rng, size, alphas): samples_shape = alphas.shape else: size_tpl = numba_ndarray.to_fixed_tuple(size, size_len) - if ( - 0 < alphas.ndim - 1 <= len(size_tpl) - and size_tpl[neg_ind_shape_len:] != alphas.shape[:-1] - ): - raise ValueError("Parameters shape and size do not match.") samples_shape = size_tpl + alphas.shape[-1:] res = np.empty(samples_shape, dtype=out_dtype) diff --git a/pytensor/tensor/random/basic.py b/pytensor/tensor/random/basic.py index 5c5665ef1f..1702fd1e99 100644 --- a/pytensor/tensor/random/basic.py +++ b/pytensor/tensor/random/basic.py @@ -2002,6 +2002,11 @@ def _supp_shape_from_params(self, dist_params, param_shapes=None): a_shape = tuple(a.shape) if param_shapes is None else tuple(param_shapes[0]) a_batch_ndim = len(a_shape) - self.ndims_params[0] a_core_shape = a_shape[a_batch_ndim:] + core_shape_ndim = core_shape.type.ndim + if core_shape_ndim > 1: + # Batch core shapes are only valid if homogeneous or broadcasted, + # as otherwise they would imply ragged choice arrays + core_shape = core_shape[(0,) * (core_shape_ndim - 1)] return tuple(core_shape) + a_core_shape[1:] def rng_fn(self, *params): @@ -2011,15 +2016,11 @@ def rng_fn(self, *params): rng, a, core_shape, size = params p = 
None + if core_shape.ndim > 1: + core_shape = core_shape[(0,) * (core_shape.ndim - 1)] core_shape = tuple(core_shape) - # We don't have access to the node in rng_fn for easy computation of batch_ndim :( - a_batch_ndim = batch_ndim = a.ndim - self.ndims_params[0] - if p is not None: - p_batch_ndim = p.ndim - self.ndims_params[1] - batch_ndim = max(batch_ndim, p_batch_ndim) - size_ndim = 0 if size is None else len(size) - batch_ndim = max(batch_ndim, size_ndim) + batch_ndim = a.ndim - self.ndims_params[0] if batch_ndim == 0: # Numpy choice fails with size=() if a.ndim > 1 is batched @@ -2031,16 +2032,16 @@ def rng_fn(self, *params): # Numpy choice doesn't have a concept of batch dims if size is None: if p is None: - size = a.shape[:a_batch_ndim] + size = a.shape[:batch_ndim] else: size = np.broadcast_shapes( - a.shape[:a_batch_ndim], - p.shape[:p_batch_ndim], + a.shape[:batch_ndim], + p.shape[:batch_ndim], ) - a = np.broadcast_to(a, size + a.shape[a_batch_ndim:]) + a = np.broadcast_to(a, size + a.shape[batch_ndim:]) if p is not None: - p = np.broadcast_to(p, size + p.shape[p_batch_ndim:]) + p = np.broadcast_to(p, size + p.shape[batch_ndim:]) a_indexed_shape = a.shape[len(size) + 1 :] out = np.empty(size + core_shape + a_indexed_shape, dtype=a.dtype) @@ -2143,26 +2144,26 @@ class PermutationRV(RandomVariable): def _supp_shape_from_params(self, dist_params, param_shapes=None): [x] = dist_params x_shape = tuple(x.shape if param_shapes is None else param_shapes[0]) - if x.type.ndim == 0: - return (x,) + if self.ndims_params[0] == 0: + # Implicit arange, this is only valid for homogeneous arrays + # Otherwise it would imply a ragged permutation array. + return (x.ravel()[0],) else: batch_x_ndim = x.type.ndim - self.ndims_params[0] return x_shape[batch_x_ndim:] def rng_fn(self, rng, x, size): # We don't have access to the node in rng_fn :( - x_batch_ndim = x.ndim - self.ndims_params[0] - batch_ndim = max(x_batch_ndim, 0 if size is None else len(size)) + batch_ndim = x.ndim - self.ndims_params[0] if batch_ndim: # rng.permutation has no concept of batch dims - x_core_shape = x.shape[x_batch_ndim:] if size is None: - size = x.shape[:x_batch_ndim] + size = x.shape[:batch_ndim] else: - x = np.broadcast_to(x, size + x_core_shape) + x = np.broadcast_to(x, size + x.shape[batch_ndim:]) - out = np.empty(size + x_core_shape, dtype=x.dtype) + out = np.empty(size + x.shape[batch_ndim:], dtype=x.dtype) for idx in np.ndindex(size): out[idx] = rng.permutation(x[idx]) return out diff --git a/pytensor/tensor/random/op.py b/pytensor/tensor/random/op.py index 7b73c9de03..5a31448f68 100644 --- a/pytensor/tensor/random/op.py +++ b/pytensor/tensor/random/op.py @@ -9,7 +9,7 @@ from pytensor.configdefaults import config from pytensor.graph.basic import Apply, Variable, equal_computations from pytensor.graph.op import Op -from pytensor.graph.replace import _vectorize_node, vectorize_graph +from pytensor.graph.replace import _vectorize_node from pytensor.misc.safe_asarray import _asarray from pytensor.scalar import ScalarVariable from pytensor.tensor.basic import ( @@ -359,6 +359,12 @@ def make_node(self, rng, size, *dist_params): inferred_shape = self._infer_shape(size, dist_params) _, static_shape = infer_static_shape(inferred_shape) + dist_params = explicit_expand_dims( + dist_params, + self.ndims_params, + size_length=None if NoneConst.equals(size) else get_vector_length(size), + ) + inputs = (rng, size, *dist_params) out_type = TensorType(dtype=self.dtype, shape=static_shape) outputs = (rng.type(), out_type()) @@ 
-459,22 +465,14 @@ def vectorize_random_variable( None if isinstance(old_size.type, NoneTypeT) else get_vector_length(old_size) ) - original_expanded_dist_params = explicit_expand_dims( - original_dist_params, op.ndims_params, len_old_size - ) - # We call vectorize_graph to automatically handle any new explicit expand_dims - dist_params = vectorize_graph( - original_expanded_dist_params, dict(zip(original_dist_params, dist_params)) - ) - - new_ndim = dist_params[0].type.ndim - original_expanded_dist_params[0].type.ndim - - if new_ndim and len_old_size and equal_computations([old_size], [size]): + if len_old_size and equal_computations([old_size], [size]): # If the original RV had a size variable and a new one has not been provided, # we need to define a new size as the concatenation of the original size dimensions # and the novel ones implied by new broadcasted batched parameters dimensions. - broadcasted_batch_shape = compute_batch_shape(dist_params, op.ndims_params) - new_size_dims = broadcasted_batch_shape[:new_ndim] - size = concatenate([new_size_dims, size]) + new_ndim = dist_params[0].type.ndim - original_dist_params[0].type.ndim + if new_ndim >= 0: + new_size = compute_batch_shape(dist_params, ndims_params=op.ndims_params) + new_size_dims = new_size[:new_ndim] + size = concatenate([new_size_dims, size]) return op.make_node(rng, size, *dist_params) diff --git a/pytensor/tensor/random/rewriting/jax.py b/pytensor/tensor/random/rewriting/jax.py index 21ddb32af9..ef68235889 100644 --- a/pytensor/tensor/random/rewriting/jax.py +++ b/pytensor/tensor/random/rewriting/jax.py @@ -1,6 +1,7 @@ import re from pytensor.compile import optdb +from pytensor.graph import Constant from pytensor.graph.rewriting.basic import in2out, node_rewriter from pytensor.graph.rewriting.db import SequenceDB from pytensor.tensor import abs as abs_t @@ -159,12 +160,17 @@ def materialize_implicit_arange_choice_without_replacement(fgraph, node): return None rng, size, a_scalar_param, *other_params = node.inputs - if a_scalar_param.type.ndim > 0: + if not all(a_scalar_param.type.broadcastable): # Automatic vectorization could have made this parameter batched, # there is no nice way to materialize a batched arange return None - a_vector_param = arange(a_scalar_param) + # We need to try and do an eager squeeze here because arange will fail in jax + # if there is an array leading to it, even if it's constant + if isinstance(a_scalar_param, Constant): + a_scalar_param = a_scalar_param.data + a_vector_param = arange(a_scalar_param.squeeze()) + new_props_dict = op._props_dict().copy() # Signature changes from something like "(),(a),(2)->(s0, s1)" to "(a),(a),(2)->(s0, s1)" # I.e., we substitute the first `()` by `(a)` diff --git a/tests/link/numba/test_random.py b/tests/link/numba/test_random.py index e8e4fc2dbf..71639244b2 100644 --- a/tests/link/numba/test_random.py +++ b/tests/link/numba/test_random.py @@ -28,6 +28,9 @@ rng = np.random.default_rng(42849) +@pytest.mark.xfail( + reason="Most RVs are not working correctly with explicit expand_dims" +) @pytest.mark.parametrize( "rv_op, dist_args, size", [ @@ -388,6 +391,7 @@ def test_aligned_RandomVariable(rv_op, dist_args, size): ) +@pytest.mark.xfail(reason="Test is not working correctly with explicit expand_dims") @pytest.mark.parametrize( "rv_op, dist_args, base_size, cdf_name, params_conv", [ @@ -633,7 +637,7 @@ def test_CategoricalRV(dist_args, size, cm): ), ), (10, 4), - pytest.raises(ValueError, match="Parameters shape.*"), + pytest.raises(ValueError, match="operands 
could not be broadcast together"), ), ], ) @@ -658,6 +662,7 @@ def test_DirichletRV(a, size, cm): assert np.allclose(res, exp_res, atol=1e-4) +@pytest.mark.xfail(reason="RandomState is not aligned with explicit expand_dims") def test_RandomState_updates(): rng = shared(np.random.RandomState(1)) rng_new = shared(np.random.RandomState(2)) diff --git a/tests/tensor/random/rewriting/test_basic.py b/tests/tensor/random/rewriting/test_basic.py index 05f9fa17a4..e71329d349 100644 --- a/tests/tensor/random/rewriting/test_basic.py +++ b/tests/tensor/random/rewriting/test_basic.py @@ -796,13 +796,21 @@ def test_Subtensor_lift(indices, lifted, dist_op, dist_params, size): rng, ) + def is_subtensor_or_dimshuffle_subtensor(inp) -> bool: + subtensor_ops = Subtensor | AdvancedSubtensor | AdvancedSubtensor1 + if isinstance(inp.owner.op, subtensor_ops): + return True + if isinstance(inp.owner.op, DimShuffle): + return isinstance(inp.owner.inputs[0].owner.op, subtensor_ops) + return False + if lifted: assert isinstance(new_out.owner.op, RandomVariable) assert all( - isinstance(i.owner.op, AdvancedSubtensor | AdvancedSubtensor1 | Subtensor) + is_subtensor_or_dimshuffle_subtensor(i) for i in new_out.owner.op.dist_params(new_out.owner) if i.owner - ) + ), new_out.dprint(depth=3, print_type=True) else: assert isinstance( new_out.owner.op, AdvancedSubtensor | AdvancedSubtensor1 | Subtensor From 6fb9a8479ed5b7df5d44131ff69f596c22300774 Mon Sep 17 00:00:00 2001 From: Ricardo Vieira Date: Fri, 5 Apr 2024 22:13:35 +0200 Subject: [PATCH 13/15] Adapt Numba vectorize iterator for RandomVariables Co-authored-by: Jesse Grabowski <48652735+jessegrabowski@users.noreply.github.com> Co-authored-by: Adrian Seyboldt --- pytensor/link/numba/dispatch/basic.py | 10 +- pytensor/link/numba/dispatch/elemwise.py | 11 +- .../link/numba/dispatch/vectorize_codegen.py | 470 ++++++++++++------ 3 files changed, 332 insertions(+), 159 deletions(-) diff --git a/pytensor/link/numba/dispatch/basic.py b/pytensor/link/numba/dispatch/basic.py index affdc82c3e..ec1c520663 100644 --- a/pytensor/link/numba/dispatch/basic.py +++ b/pytensor/link/numba/dispatch/basic.py @@ -62,10 +62,16 @@ def numba_njit(*args, **kwargs): kwargs.setdefault("no_cpython_wrapper", True) kwargs.setdefault("no_cfunc_wrapper", True) - # Supress caching warnings + # Suppress cache warning for internal functions + # We have to add an ansi escape code for optional bold text by numba warnings.filterwarnings( "ignore", - message='Cannot cache compiled function "numba_funcified_fgraph" as it uses dynamic globals', + message=( + "(\x1b\\[1m)*" # ansi escape code for bold text + "Cannot cache compiled function " + '"(numba_funcified_fgraph|store_core_outputs)" ' + "as it uses dynamic globals" + ), category=NumbaWarning, ) diff --git a/pytensor/link/numba/dispatch/elemwise.py b/pytensor/link/numba/dispatch/elemwise.py index d73e1bf73d..e2cc183e67 100644 --- a/pytensor/link/numba/dispatch/elemwise.py +++ b/pytensor/link/numba/dispatch/elemwise.py @@ -24,6 +24,7 @@ _jit_options, _vectorized, encode_literals, + store_core_outputs, ) from pytensor.link.utils import compile_function_src, get_name_for_object from pytensor.scalar.basic import ( @@ -480,10 +481,15 @@ def numba_funcify_Elemwise(op, node, **kwargs): **kwargs, ) + nin = len(node.inputs) + nout = len(node.outputs) + core_op_fn = store_core_outputs(scalar_op_fn, nin=nin, nout=nout) + input_bc_patterns = tuple([inp.type.broadcastable for inp in node.inputs]) output_bc_patterns = tuple([out.type.broadcastable for out in 
node.outputs]) output_dtypes = tuple(out.type.dtype for out in node.outputs) inplace_pattern = tuple(op.inplace_pattern.items()) + core_output_shapes = tuple(() for _ in range(nout)) # numba doesn't support nested literals right now... input_bc_patterns_enc = encode_literals(input_bc_patterns) @@ -493,12 +499,15 @@ def numba_funcify_Elemwise(op, node, **kwargs): def elemwise_wrapper(*inputs): return _vectorized( - scalar_op_fn, + core_op_fn, input_bc_patterns_enc, output_bc_patterns_enc, output_dtypes_enc, inplace_pattern_enc, + (), # constant_inputs inputs, + core_output_shapes, # core_shapes + None, # size ) # Pure python implementation, that will be used in tests diff --git a/pytensor/link/numba/dispatch/vectorize_codegen.py b/pytensor/link/numba/dispatch/vectorize_codegen.py index 14c846c4e4..6eb6cab2c1 100644 --- a/pytensor/link/numba/dispatch/vectorize_codegen.py +++ b/pytensor/link/numba/dispatch/vectorize_codegen.py @@ -2,8 +2,9 @@ import base64 import pickle -from collections.abc import Sequence -from typing import Any +from collections.abc import Callable, Sequence +from textwrap import indent +from typing import Any, cast import numba import numpy as np @@ -11,13 +12,54 @@ from numba import TypingError, types from numba.core import cgutils from numba.core.base import BaseContext +from numba.core.types.misc import NoneType from numba.np import arrayobj +from pytensor.link.numba.dispatch import basic as numba_basic +from pytensor.link.utils import compile_function_src + def encode_literals(literals: Sequence) -> str: return base64.encodebytes(pickle.dumps(literals)).decode() +def store_core_outputs(core_op_fn: Callable, nin: int, nout: int) -> Callable: + """Create a Numba function that wraps a core function and stores its vectorized outputs. + + @njit + def store_core_outputs(i0, i1, ..., in, o0, o1, ..., on): + to0, to1, ..., ton = core_op_fn(i0, i1, ..., in) + o0[...] = to0 + o1[...] = to1 + ... + on[...] = ton + + """ + inputs = [f"i{i}" for i in range(nin)] + outputs = [f"o{i}" for i in range(nout)] + inner_outputs = [f"t{output}" for output in outputs] + + inp_signature = ", ".join(inputs) + out_signature = ", ".join(outputs) + inner_out_signature = ", ".join(inner_outputs) + store_outputs = "\n".join( + [ + f"{output}[...] 
= {inner_output}" + for output, inner_output in zip(outputs, inner_outputs) + ] + ) + func_src = f""" +def store_core_outputs({inp_signature}, {out_signature}): + {inner_out_signature} = core_op_fn({inp_signature}) +{indent(store_outputs, " " * 4)} +""" + global_env = {"core_op_fn": core_op_fn} + func = compile_function_src( + func_src, "store_core_outputs", {**globals(), **global_env} + ) + return cast(Callable, numba_basic.numba_njit(func)) + + _jit_options = { "fastmath": { "arcp", # Allow Reciprocal @@ -39,7 +81,10 @@ def _vectorized( output_bc_patterns, output_dtypes, inplace_pattern, - inputs, + constant_inputs_types, + input_types, + output_core_shape_types, + size_type, ): arg_types = [ scalar_func, @@ -47,7 +92,10 @@ def _vectorized( output_bc_patterns, output_dtypes, inplace_pattern, - inputs, + constant_inputs_types, + input_types, + output_core_shape_types, + size_type, ] if not isinstance(input_bc_patterns, types.Literal): @@ -70,34 +118,82 @@ def _vectorized( inplace_pattern = inplace_pattern.literal_value inplace_pattern = pickle.loads(base64.decodebytes(inplace_pattern.encode())) - n_outputs = len(output_bc_patterns) + batch_ndim = len(input_bc_patterns[0]) + nin = len(constant_inputs_types) + len(input_types) + nout = len(output_bc_patterns) + + if nin == 0: + raise TypingError("Empty argument list to vectorized op.") + + if nout == 0: + raise TypingError("Empty list of outputs for vectorized op.") - if not len(inputs) > 0: - raise TypingError("Empty argument list to elemwise op.") + if not all(isinstance(input, types.Array) for input in input_types): + raise TypingError("Vectorized inputs must be arrays.") - if not n_outputs > 0: - raise TypingError("Empty list of outputs for elemwise op.") + if not all( + len(pattern) == batch_ndim for pattern in input_bc_patterns + output_bc_patterns + ): + raise TypingError( + "Vectorized broadcastable patterns must have the same length." 
+ ) + + core_input_types = [] + for input_type, bc_pattern in zip(input_types, input_bc_patterns): + core_ndim = input_type.ndim - len(bc_pattern) + # TODO: Reconsider this + if core_ndim == 0: + core_input_type = input_type.dtype + else: + core_input_type = types.Array( + dtype=input_type.dtype, ndim=core_ndim, layout=input_type.layout + ) + core_input_types.append(core_input_type) - if not all(isinstance(input, types.Array) for input in inputs): - raise TypingError("Inputs to elemwise must be arrays.") - ndim = inputs[0].ndim + core_out_types = [ + types.Array(numba.from_dtype(np.dtype(dtype)), len(output_core_shape), "C") + for dtype, output_core_shape in zip(output_dtypes, output_core_shape_types) + ] - if not all(input.ndim == ndim for input in inputs): - raise TypingError("Inputs to elemwise must have the same rank.") + out_types = [ + types.Array( + numba.from_dtype(np.dtype(dtype)), batch_ndim + len(output_core_shape), "C" + ) + for dtype, output_core_shape in zip(output_dtypes, output_core_shape_types) + ] - if not all(len(pattern) == ndim for pattern in output_bc_patterns): - raise TypingError("Invalid output broadcasting pattern.") + for output_idx, input_idx in inplace_pattern: + output_type = input_types[input_idx] + core_out_types[output_idx] = types.Array( + dtype=output_type.dtype, + ndim=output_type.ndim - batch_ndim, + layout=input_type.layout, + ) + out_types[output_idx] = output_type - scalar_signature = typingctx.resolve_function_type( - scalar_func, [in_type.dtype for in_type in inputs], {} + core_signature = typingctx.resolve_function_type( + scalar_func, + [ + *constant_inputs_types, + *core_input_types, + *core_out_types, + ], + {}, ) + ret_type = types.Tuple(out_types) + + if len(output_dtypes) == 1: + ret_type = ret_type.types[0] + sig = ret_type(*arg_types) + # So we can access the constant values in codegen... 
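For concreteness, with ``nin=2`` and ``nout=1`` the source compiled by ``store_core_outputs`` above is roughly:

    def store_core_outputs(i0, i1, o0):
        to0 = core_op_fn(i0, i1)
        o0[...] = to0

where ``core_op_fn`` is the jitted core function injected through the globals passed to ``compile_function_src``.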
input_bc_patterns_val = input_bc_patterns output_bc_patterns_val = output_bc_patterns output_dtypes_val = output_dtypes inplace_pattern_val = inplace_pattern - input_types = inputs + input_types = input_types + size_is_none = isinstance(size_type, NoneType) def codegen( ctx, @@ -105,8 +201,16 @@ def codegen( sig, args, ): - [_, _, _, _, _, inputs] = args + [_, _, _, _, _, constant_inputs, inputs, output_core_shapes, size] = args + + constant_inputs = cgutils.unpack_tuple(builder, constant_inputs) inputs = cgutils.unpack_tuple(builder, inputs) + output_core_shapes = [ + cgutils.unpack_tuple(builder, shape) + for shape in cgutils.unpack_tuple(builder, output_core_shapes) + ] + size = None if size_is_none else cgutils.unpack_tuple(builder, size) + inputs = [ arrayobj.make_array(ty)(ctx, builder, val) for ty, val in zip(input_types, inputs) @@ -118,6 +222,7 @@ def codegen( builder, in_shapes, input_bc_patterns_val, + size, ) outputs, output_types = make_outputs( @@ -129,6 +234,7 @@ def codegen( inplace_pattern_val, inputs, input_types, + output_core_shapes, ) make_loop_call( @@ -136,8 +242,9 @@ def codegen( ctx, builder, scalar_func, - scalar_signature, + core_signature, iter_shape, + constant_inputs, inputs, outputs, input_bc_patterns_val, @@ -162,69 +269,94 @@ def codegen( builder, sig.return_type, [out._getvalue() for out in outputs] ) - ret_types = [ - types.Array(numba.from_dtype(np.dtype(dtype)), ndim, "C") - for dtype in output_dtypes - ] - - for output_idx, input_idx in inplace_pattern: - ret_types[output_idx] = input_types[input_idx] - - ret_type = types.Tuple(ret_types) - - if len(output_dtypes) == 1: - ret_type = ret_type.types[0] - sig = ret_type(*arg_types) - return sig, codegen def compute_itershape( ctx: BaseContext, builder: ir.IRBuilder, - in_shapes: tuple[ir.Instruction, ...], + in_shapes: list[list[ir.Instruction]], broadcast_pattern: tuple[tuple[bool, ...], ...], + size: list[ir.Instruction] | None, ): one = ir.IntType(64)(1) - ndim = len(in_shapes[0]) - shape = [None] * ndim - for i in range(ndim): - for j, (bc, in_shape) in enumerate(zip(broadcast_pattern, in_shapes)): - length = in_shape[i] - if bc[i]: - with builder.if_then( - builder.icmp_unsigned("!=", length, one), likely=False - ): - msg = ( - f"Input {j} to elemwise is expected to have shape 1 in axis {i}" - ) - ctx.call_conv.return_user_exc(builder, ValueError, (msg,)) - elif shape[i] is not None: - with builder.if_then( - builder.icmp_unsigned("!=", length, shape[i]), likely=False - ): - with builder.if_else(builder.icmp_unsigned("==", length, one)) as ( - then, - otherwise, + batch_ndim = len(broadcast_pattern[0]) + shape = [None] * batch_ndim + if size is not None: + shape = size + for i in range(batch_ndim): + for j, (bc, in_shape) in enumerate(zip(broadcast_pattern, in_shapes)): + length = in_shape[i] + if bc[i]: + with builder.if_then( + builder.icmp_unsigned("!=", length, one), likely=False + ): + msg = f"Vectorized input {j} is expected to have shape 1 in axis {i}" + ctx.call_conv.return_user_exc(builder, ValueError, (msg,)) + else: + with builder.if_then( + builder.icmp_unsigned("!=", length, shape[i]), likely=False + ): + with builder.if_else( + builder.icmp_unsigned("==", length, one) + ) as ( + then, + otherwise, + ): + with then: + msg = ( + f"Incompatible vectorized shapes for input {j} and axis {i}. " + f"Input {j} has shape 1, but is not statically " + "known to have shape 1, and thus not broadcastable." 
+ ) + ctx.call_conv.return_user_exc( + builder, ValueError, (msg,) + ) + with otherwise: + msg = f"Vectorized input {j} has an incompatible shape in axis {i}." + ctx.call_conv.return_user_exc( + builder, ValueError, (msg,) + ) + else: + # Size is implied by the broadcast pattern + for i in range(batch_ndim): + for j, (bc, in_shape) in enumerate(zip(broadcast_pattern, in_shapes)): + length = in_shape[i] + if bc[i]: + with builder.if_then( + builder.icmp_unsigned("!=", length, one), likely=False + ): + msg = f"Vectorized input {j} is expected to have shape 1 in axis {i}" + ctx.call_conv.return_user_exc(builder, ValueError, (msg,)) + elif shape[i] is not None: + with builder.if_then( + builder.icmp_unsigned("!=", length, shape[i]), likely=False ): - with then: - msg = ( - f"Incompatible shapes for input {j} and axis {i} of " - f"elemwise. Input {j} has shape 1, but is not statically " - "known to have shape 1, and thus not broadcastable." - ) - ctx.call_conv.return_user_exc(builder, ValueError, (msg,)) - with otherwise: - msg = ( - f"Input {j} to elemwise has an incompatible " - f"shape in axis {i}." - ) - ctx.call_conv.return_user_exc(builder, ValueError, (msg,)) - else: - shape[i] = length - for i in range(ndim): - if shape[i] is None: - shape[i] = one + with builder.if_else( + builder.icmp_unsigned("==", length, one) + ) as ( + then, + otherwise, + ): + with then: + msg = ( + f"Incompatible vectorized shapes for input {j} and axis {i}. " + f"Input {j} has shape 1, but is not statically " + "known to have shape 1, and thus not broadcastable." + ) + ctx.call_conv.return_user_exc( + builder, ValueError, (msg,) + ) + with otherwise: + msg = f"Vectorized input {j} has an incompatible shape in axis {i}." + ctx.call_conv.return_user_exc( + builder, ValueError, (msg,) + ) + else: + shape[i] = length + for i in range(batch_ndim): + if shape[i] is None: + shape[i] = one return shape @@ -237,27 +369,32 @@ def make_outputs( inplace: tuple[tuple[int, int], ...], inputs: tuple[Any, ...], input_types: tuple[Any, ...], -): - arrays = [] - ar_types: list[types.Array] = [] + output_core_shapes: tuple, +) -> tuple[list[ir.Value], list[types.Array]]: + output_arrays = [] + output_arry_types = [] one = ir.IntType(64)(1) inplace_dict = dict(inplace) - for i, (bc, dtype) in enumerate(zip(out_bc, dtypes)): + for i, (core_shape, bc, dtype) in enumerate( + zip(output_core_shapes, out_bc, dtypes) + ): if i in inplace_dict: - arrays.append(inputs[inplace_dict[i]]) - ar_types.append(input_types[inplace_dict[i]]) + output_arrays.append(inputs[inplace_dict[i]]) + output_arry_types.append(input_types[inplace_dict[i]]) # We need to incref once we return the inplace objects continue dtype = numba.from_dtype(np.dtype(dtype)) - arrtype = types.Array(dtype, len(iter_shape), "C") - ar_types.append(arrtype) + output_ndim = len(iter_shape) + len(core_shape) + arrtype = types.Array(dtype, output_ndim, "C") + output_arry_types.append(arrtype) # This is actually an internal numba function, I guess we could # call `numba.nd.unsafe.ndarray` instead? - shape = [ + batch_shape = [ length if not bc_dim else one for length, bc_dim in zip(iter_shape, bc) ] + shape = batch_shape + core_shape array = arrayobj._empty_nd_impl(ctx, builder, arrtype, shape) - arrays.append(array) + output_arrays.append(array) # If there is no inplace operation, we know that all output arrays # don't alias. Informing llvm can make it easier to vectorize. 
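The batch-shape resolution that ``compute_itershape`` emits as LLVM IR above can be summarized by the following pure-Python sketch (illustrative only; the helper name is made up, the error messages are simplified, and the real code builds IR through numba's ``cgutils``):

    def compute_itershape_sketch(in_shapes, bc_patterns, size=None):
        batch_ndim = len(bc_patterns[0])
        # When an explicit size is given it fixes the batch shape up front;
        # otherwise it is inferred from the non-broadcastable inputs.
        shape = list(size) if size is not None else [None] * batch_ndim
        for i in range(batch_ndim):
            for j, (bc, in_shape) in enumerate(zip(bc_patterns, in_shapes)):
                length = in_shape[i]
                if bc[i]:
                    # Statically broadcastable inputs must have run-time length 1
                    if length != 1:
                        raise ValueError(
                            f"Vectorized input {j} is expected to have shape 1 in axis {i}"
                        )
                elif shape[i] is None:
                    shape[i] = length
                elif length != shape[i]:
                    raise ValueError(
                        f"Vectorized input {j} has an incompatible shape in axis {i}"
                    )
        # Axes constrained by no input default to length 1
        return tuple(1 if s is None else s for s in shape)

    assert compute_itershape_sketch(
        [(3, 1), (1, 5)], [(False, True), (True, False)]
    ) == (3, 5)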
@@ -265,7 +402,7 @@ def make_outputs( # The first argument is the output pointer arg = builder.function.args[0] arg.add_attribute("noalias") - return arrays, ar_types + return output_arrays, output_arry_types def make_loop_call( @@ -275,6 +412,7 @@ def make_loop_call( scalar_func: Any, scalar_signature: types.FunctionType, iter_shape: tuple[ir.Instruction, ...], + constant_inputs: tuple[ir.Instruction, ...], inputs: tuple[ir.Instruction, ...], outputs: tuple[ir.Instruction, ...], input_bc: tuple[tuple[bool, ...], ...], @@ -283,18 +421,8 @@ def make_loop_call( output_types: tuple[Any, ...], ): safe = (False, False) - n_outputs = len(outputs) - - # context.printf(builder, "iter shape: " + ', '.join(["%i"] * len(iter_shape)) + "\n", *iter_shape) - # Extract shape and stride information from the array. - # For later use in the loop body to do the indexing - def extract_array(aryty, obj): - shape = cgutils.unpack_tuple(builder, obj.shape) - strides = cgutils.unpack_tuple(builder, obj.strides) - data = obj.data - layout = aryty.layout - return (data, shape, strides, layout) + n_outputs = len(outputs) # TODO I think this is better than the noalias attribute # for the input, but self_ref isn't supported in a released @@ -306,12 +434,6 @@ def extract_array(aryty, obj): # input_scope_set = mod.add_metadata([input_scope, output_scope]) # output_scope_set = mod.add_metadata([input_scope, output_scope]) - inputs = tuple(extract_array(aryty, ary) for aryty, ary in zip(input_types, inputs)) - - outputs = tuple( - extract_array(aryty, ary) for aryty, ary in zip(output_types, outputs) - ) - zero = ir.Constant(ir.IntType(64), 0) # Setup loops and initialize accumulators for outputs @@ -338,69 +460,105 @@ def extract_array(aryty, obj): # Load values from input arrays input_vals = [] - for array_info, bc in zip(inputs, input_bc): - idxs_bc = [zero if bc else idx for idx, bc in zip(idxs, bc)] - ptr = cgutils.get_item_pointer2(context, builder, *array_info, idxs_bc, *safe) - val = builder.load(ptr) - # val.set_metadata("alias.scope", input_scope_set) - # val.set_metadata("noalias", output_scope_set) + for input, input_type, bc in zip(inputs, input_types, input_bc): + core_ndim = input_type.ndim - len(bc) + + idxs_bc = [zero if bc else idx for idx, bc in zip(idxs, bc)] + [ + zero + ] * core_ndim + ptr = cgutils.get_item_pointer2( + context, + builder, + input.data, + cgutils.unpack_tuple(builder, input.shape), + cgutils.unpack_tuple(builder, input.strides), + input_type.layout, + idxs_bc, + *safe, + ) + if core_ndim == 0: + # Retrive scalar item at index + val = builder.load(ptr) + # val.set_metadata("alias.scope", input_scope_set) + # val.set_metadata("noalias", output_scope_set) + else: + # Retrieve array item at index + # This is a streamlined version of Numba's `GUArrayArg.load` + # TODO check layout arg! + core_arry_type = types.Array( + dtype=input_type.dtype, ndim=core_ndim, layout=input_type.layout + ) + core_array = context.make_array(core_arry_type)(context, builder) + core_shape = cgutils.unpack_tuple(builder, input.shape)[-core_ndim:] + core_strides = cgutils.unpack_tuple(builder, input.strides)[-core_ndim:] + itemsize = context.get_abi_sizeof(context.get_data_type(input_type.dtype)) + context.populate_array( + core_array, + # TODO whey do we need to bitcast? + data=builder.bitcast(ptr, core_array.data.type), + shape=cgutils.pack_array(builder, core_shape), + strides=cgutils.pack_array(builder, core_strides), + itemsize=context.get_constant(types.intp, itemsize), + # TODO what is meminfo about? 
+ meminfo=None, + ) + val = core_array._getvalue() + input_vals.append(val) + # Create output slices to pass to inner func + output_slices = [] + for output, output_type, bc in zip(outputs, output_types, output_bc): + core_ndim = output_type.ndim - len(bc) + size_type = output.shape.type.element # type: ignore + output_shape = cgutils.unpack_tuple(builder, output.shape) # type: ignore + output_strides = cgutils.unpack_tuple(builder, output.strides) # type: ignore + + idxs_bc = [zero if bc else idx for idx, bc in zip(idxs, bc)] + [ + zero + ] * core_ndim + ptr = cgutils.get_item_pointer2( + context, + builder, + output.data, # type:ignore + output_shape, + output_strides, + output_type.layout, + idxs_bc, + *safe, + ) + + # Retrieve array item at index + # This is a streamlined version of Numba's `GUArrayArg.load` + core_arry_type = types.Array( + dtype=output_type.dtype, ndim=core_ndim, layout=output_type.layout + ) + core_array = context.make_array(core_arry_type)(context, builder) + core_shape = output_shape[-core_ndim:] if core_ndim > 0 else [] + core_strides = output_strides[-core_ndim:] if core_ndim > 0 else [] + itemsize = context.get_abi_sizeof(context.get_data_type(output_type.dtype)) + context.populate_array( + core_array, + # TODO whey do we need to bitcast? + data=builder.bitcast(ptr, core_array.data.type), + shape=cgutils.pack_array(builder, core_shape, ty=size_type), + strides=cgutils.pack_array(builder, core_strides, ty=size_type), + itemsize=context.get_constant(types.intp, itemsize), + # TODO what is meminfo about? + meminfo=None, + ) + val = core_array._getvalue() + output_slices.append(val) + inner_codegen = context.get_function(scalar_func, scalar_signature) if isinstance(scalar_signature.args[0], types.StarArgTuple | types.StarArgUniTuple): input_vals = [context.make_tuple(builder, scalar_signature.args[0], input_vals)] - output_values = inner_codegen(builder, input_vals) - if isinstance(scalar_signature.return_type, types.Tuple | types.UniTuple): - output_values = cgutils.unpack_tuple(builder, output_values) - func_output_types = scalar_signature.return_type.types - else: - output_values = [output_values] - func_output_types = [scalar_signature.return_type] - - # Update output value or accumulators respectively - for i, ((accu, _), value) in enumerate(zip(output_accumulator, output_values)): - if accu is not None: - load = builder.load(accu) - # load.set_metadata("alias.scope", output_scope_set) - # load.set_metadata("noalias", input_scope_set) - new_value = builder.fadd(load, value) - builder.store(new_value, accu) - # TODO belongs to noalias scope - # store.set_metadata("alias.scope", output_scope_set) - # store.set_metadata("noalias", input_scope_set) - else: - idxs_bc = [zero if bc else idx for idx, bc in zip(idxs, output_bc[i])] - ptr = cgutils.get_item_pointer2(context, builder, *outputs[i], idxs_bc) - # store = builder.store(value, ptr) - value = context.cast( - builder, value, func_output_types[i], output_types[i].dtype - ) - arrayobj.store_item(context, builder, output_types[i], value, ptr) - # store.set_metadata("alias.scope", output_scope_set) - # store.set_metadata("noalias", input_scope_set) + inner_codegen(builder, [*constant_inputs, *input_vals, *output_slices]) - # Close the loops and write accumulator values to the output arrays + # Close the loops for depth, loop in enumerate(loop_stack[::-1]): - for output, (accu, accu_depth) in enumerate(output_accumulator): - if accu_depth == depth: - idxs_bc = [ - zero if bc else idx for idx, bc in zip(idxs, 
output_bc[output]) - ] - ptr = cgutils.get_item_pointer2( - context, builder, *outputs[output], idxs_bc - ) - load = builder.load(accu) - # load.set_metadata("alias.scope", output_scope_set) - # load.set_metadata("noalias", input_scope_set) - # store = builder.store(load, ptr) - load = context.cast( - builder, load, func_output_types[output], output_types[output].dtype - ) - arrayobj.store_item(context, builder, output_types[output], load, ptr) - # store.set_metadata("alias.scope", output_scope_set) - # store.set_metadata("noalias", input_scope_set) loop.__exit__(None, None, None) return From 2bc400b1e18d44350113f39cf41932521deeaded Mon Sep 17 00:00:00 2001 From: Ricardo Vieira Date: Fri, 10 May 2024 12:55:49 +0200 Subject: [PATCH 14/15] Add support for RandomVariable with Generators in Numba backend and drop support for RandomState Co-authored-by: Adrian Seyboldt Co-authored-by: Jesse Grabowski <48652735+jessegrabowski@users.noreply.github.com> --- pytensor/compile/builders.py | 8 +- pytensor/compile/mode.py | 2 +- pytensor/link/numba/dispatch/basic.py | 6 + pytensor/link/numba/dispatch/random.py | 692 +++++++++---------- pytensor/link/numba/dispatch/scan.py | 6 +- pytensor/tensor/blockwise.py | 5 + pytensor/tensor/random/basic.py | 5 +- pytensor/tensor/random/op.py | 9 + pytensor/tensor/random/rewriting/__init__.py | 3 +- pytensor/tensor/random/rewriting/numba.py | 88 +++ pytensor/tensor/rewriting/shape.py | 4 +- scripts/mypy-failing.txt | 1 - tests/link/numba/test_basic.py | 29 +- tests/link/numba/test_random.py | 470 ++++++------- tests/link/numba/test_scan.py | 6 +- tests/tensor/random/test_basic.py | 16 +- 16 files changed, 689 insertions(+), 661 deletions(-) create mode 100644 pytensor/tensor/random/rewriting/numba.py diff --git a/pytensor/compile/builders.py b/pytensor/compile/builders.py index 7d4f7e39f3..600d1a0c5f 100644 --- a/pytensor/compile/builders.py +++ b/pytensor/compile/builders.py @@ -27,7 +27,6 @@ from pytensor.graph.replace import clone_replace from pytensor.graph.rewriting.basic import in2out, node_rewriter from pytensor.graph.utils import MissingInputError -from pytensor.tensor.rewriting.shape import ShapeFeature def infer_shape(outs, inputs, input_shapes): @@ -43,6 +42,10 @@ def infer_shape(outs, inputs, input_shapes): # inside. 
We don't use the full ShapeFeature interface, but we # let it initialize itself with an empty fgraph, otherwise we will # need to do it manually + + # TODO: ShapeFeature should live elsewhere + from pytensor.tensor.rewriting.shape import ShapeFeature + for inp, inp_shp in zip(inputs, input_shapes): if inp_shp is not None and len(inp_shp) != inp.type.ndim: assert len(inp_shp) == inp.type.ndim @@ -307,6 +310,7 @@ def __init__( connection_pattern: list[list[bool]] | None = None, strict: bool = False, name: str | None = None, + destroy_map: dict[int, tuple[int, ...]] | None = None, **kwargs, ): """ @@ -464,6 +468,7 @@ def __init__( if name is not None: assert isinstance(name, str), "name must be None or string object" self.name = name + self.destroy_map = destroy_map if destroy_map is not None else {} def __eq__(self, other): # TODO: recognize a copy @@ -862,6 +867,7 @@ def make_node(self, *inputs): rop_overrides=self.rop_overrides, connection_pattern=self._connection_pattern, name=self.name, + destroy_map=self.destroy_map, **self.kwargs, ) new_inputs = ( diff --git a/pytensor/compile/mode.py b/pytensor/compile/mode.py index cf0b058814..cf8dd9e73e 100644 --- a/pytensor/compile/mode.py +++ b/pytensor/compile/mode.py @@ -463,7 +463,7 @@ def clone(self, link_kwargs=None, optimizer="", **kwargs): NUMBA = Mode( NumbaLinker(), RewriteDatabaseQuery( - include=["fast_run"], + include=["fast_run", "numba"], exclude=["cxx_only", "BlasOpt", "local_careduce_fusion"], ), ) diff --git a/pytensor/link/numba/dispatch/basic.py b/pytensor/link/numba/dispatch/basic.py index ec1c520663..4e9830d627 100644 --- a/pytensor/link/numba/dispatch/basic.py +++ b/pytensor/link/numba/dispatch/basic.py @@ -18,6 +18,7 @@ from numba.extending import box, overload from pytensor import config +from pytensor.compile import NUMBA from pytensor.compile.builders import OpFromGraph from pytensor.compile.ops import DeepCopyOp from pytensor.graph.basic import Apply @@ -440,6 +441,11 @@ def numba_funcify(op, node=None, storage_map=None, **kwargs): def numba_funcify_OpFromGraph(op, node=None, **kwargs): _ = kwargs.pop("storage_map", None) + # Apply inner rewrites + # TODO: Not sure this is the right place to do this, should we have a rewrite that + # explicitly triggers the optimization of the inner graphs of OpFromGraph? 
+ # The C-code defers it to the make_thunk phase + NUMBA.optimizer(op.fgraph) fgraph_fn = numba_njit(numba_funcify(op.fgraph, **kwargs)) if len(op.fgraph.outputs) == 1: diff --git a/pytensor/link/numba/dispatch/random.py b/pytensor/link/numba/dispatch/random.py index 0091be99ae..c7e2f24546 100644 --- a/pytensor/link/numba/dispatch/random.py +++ b/pytensor/link/numba/dispatch/random.py @@ -1,442 +1,400 @@ from collections.abc import Callable -from textwrap import dedent, indent -from typing import Any +from copy import copy +from functools import singledispatch +from textwrap import dedent +import numba import numba.np.unsafe.ndarray as numba_ndarray import numpy as np -from numba import _helperlib, types -from numba.core import cgutils -from numba.extending import NativeValue, box, models, register_model, typeof_impl, unbox -from numpy.random import RandomState +from numba import types +from numba.core.extending import overload import pytensor.tensor.random.basic as ptr -from pytensor.graph.basic import Apply +from pytensor.graph import Apply from pytensor.graph.op import Op from pytensor.link.numba.dispatch import basic as numba_basic -from pytensor.link.numba.dispatch.basic import numba_funcify, numba_typify +from pytensor.link.numba.dispatch.basic import direct_cast, numba_funcify +from pytensor.link.numba.dispatch.vectorize_codegen import ( + _jit_options, + _vectorized, + encode_literals, + store_core_outputs, +) from pytensor.link.utils import ( compile_function_src, - get_name_for_object, - unique_name_generator, ) -from pytensor.tensor.basic import get_vector_length +from pytensor.tensor import get_vector_length +from pytensor.tensor.random.op import RandomVariable, RandomVariableWithCoreShape from pytensor.tensor.random.type import RandomStateType from pytensor.tensor.type_other import NoneTypeT - - -class RandomStateNumbaType(types.Type): - def __init__(self): - super().__init__(name="RandomState") - - -random_state_numba_type = RandomStateNumbaType() - - -@typeof_impl.register(RandomState) -def typeof_index(val, c): - return random_state_numba_type - - -@register_model(RandomStateNumbaType) -class RandomStateNumbaModel(models.StructModel): - def __init__(self, dmm, fe_type): - members = [ - # TODO: We can add support for boxing and unboxing - # the attributes that describe a RandomState so that - # they can be accessed inside njit functions, if required. - ("state_key", types.Array(types.uint32, 1, "C")), - ] - models.StructModel.__init__(self, dmm, fe_type, members) - - -@unbox(RandomStateNumbaType) -def unbox_random_state(typ, obj, c): - """Convert a `RandomState` object to a native `RandomStateNumbaModel` structure. - - Note that this will create a 'fake' structure which will just get the - `RandomState` objects accepted in Numba functions but the actual information - of the Numba's random state is stored internally and can be accessed - anytime using ``numba._helperlib.rnd_get_np_state_ptr()``. +from pytensor.tensor.utils import _parse_gufunc_signature + + +@overload(copy) +def copy_NumPyRandomGenerator(rng): + def impl(rng): + # TODO: Open issue on Numba? 
+ with numba.objmode(new_rng=types.npy_rng): + new_rng = copy(rng) + + return new_rng + + return impl + + +@singledispatch +def numba_core_rv_funcify(op: Op, node: Apply) -> Callable: + """Return the core function for a random variable operation.""" + raise NotImplementedError(f"Core implementation of {op} not implemented.") + + +@numba_core_rv_funcify.register(ptr.UniformRV) +@numba_core_rv_funcify.register(ptr.TriangularRV) +@numba_core_rv_funcify.register(ptr.BetaRV) +@numba_core_rv_funcify.register(ptr.NormalRV) +@numba_core_rv_funcify.register(ptr.LogNormalRV) +@numba_core_rv_funcify.register(ptr.GammaRV) +@numba_core_rv_funcify.register(ptr.ExponentialRV) +@numba_core_rv_funcify.register(ptr.WeibullRV) +@numba_core_rv_funcify.register(ptr.LogisticRV) +@numba_core_rv_funcify.register(ptr.VonMisesRV) +@numba_core_rv_funcify.register(ptr.PoissonRV) +@numba_core_rv_funcify.register(ptr.GeometricRV) +# @numba_core_rv_funcify.register(ptr.HyperGeometricRV) # Not implemented in numba +@numba_core_rv_funcify.register(ptr.WaldRV) +@numba_core_rv_funcify.register(ptr.LaplaceRV) +@numba_core_rv_funcify.register(ptr.BinomialRV) +@numba_core_rv_funcify.register(ptr.NegBinomialRV) +@numba_core_rv_funcify.register(ptr.MultinomialRV) +@numba_core_rv_funcify.register(ptr.PermutationRV) +@numba_core_rv_funcify.register(ptr.IntegersRV) +def numba_core_rv_default(op, node): + """Create a default RV core numba function. + + @njit + def random(rng, i0, i1, ..., in): + return rng.name(i0, i1, ..., in) """ - interval = cgutils.create_struct_proxy(typ)(c.context, c.builder) - is_error = cgutils.is_not_null(c.builder, c.pyapi.err_occurred()) - return NativeValue(interval._getvalue(), is_error=is_error) - - -@box(RandomStateNumbaType) -def box_random_state(typ, val, c): - """Convert a native `RandomStateNumbaModel` structure to an `RandomState` object - using Numba's internal state array. - - Note that `RandomStateNumbaModel` is just a placeholder structure with no - inherent information about Numba internal random state, all that information - is instead retrieved from Numba using ``_helperlib.rnd_get_state()`` and a new - `RandomState` is constructed using the Numba's current internal state. - """ - pos, state_list = _helperlib.rnd_get_state(_helperlib.rnd_get_np_state_ptr()) - rng = RandomState() - rng.set_state(("MT19937", state_list, pos)) - class_obj = c.pyapi.unserialize(c.pyapi.serialize_object(rng)) - return class_obj - + name = op.name -@numba_typify.register(RandomState) -def numba_typify_RandomState(state, **kwargs): - # The numba_typify in this case is just an passthrough function - # that synchronizes Numba's internal random state with the current - # RandomState object - ints, index = state.get_state()[1:3] - ptr = _helperlib.rnd_get_np_state_ptr() - _helperlib.rnd_set_state(ptr, (index, [int(x) for x in ints])) - return state + inputs = [f"i{i}" for i in range(len(op.ndims_params))] + input_signature = ",".join(inputs) + func_src = dedent(f""" + def {name}(rng, {input_signature}): + return rng.{name}({input_signature}) + """) -def make_numba_random_fn(node, np_random_func): - """Create Numba implementations for existing Numba-supported ``np.random`` functions. + func = compile_function_src(func_src, name, {**globals()}) + return numba_basic.numba_njit(func) - The functions generated here add parameter broadcasting and the ``size`` - argument to the Numba-supported scalar ``np.random`` functions. 
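As a concrete instance of the default core template described in ``numba_core_rv_default`` above: for ``ptr.NormalRV`` (``name == "normal"``, two distribution parameters) the compiled core source is roughly:

    def normal(rng, i0, i1):
        return rng.normal(i0, i1)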
- """ - op: ptr.RandomVariable = node.op - rng_param = op.rng_param(node) - if not isinstance(rng_param.type, RandomStateType): - raise TypeError("Numba does not support NumPy `Generator`s") - - size_param = op.size_param(node) - size_len = ( - None - if isinstance(size_param.type, NoneTypeT) - else int(get_vector_length(size_param)) - ) - dist_params = op.dist_params(node) - - # Make a broadcast-capable version of the Numba supported scalar sampling - # function - bcast_fn_name = f"pytensor_random_{get_name_for_object(np_random_func)}" - - sized_fn_name = "sized_random_variable" - - unique_names = unique_name_generator( - [ - bcast_fn_name, - sized_fn_name, - "np", - "np_random_func", - "numba_vectorize", - "to_fixed_tuple", - "size_len", - "size_dims", - "rng", - "size", - ], - suffix_sep="_", - ) - bcast_fn_input_names = ", ".join( - [unique_names(i, force_unique=True) for i in dist_params] - ) - bcast_fn_global_env = { - "np_random_func": np_random_func, - "numba_vectorize": numba_basic.numba_vectorize, - } - - bcast_fn_src = f""" -@numba_vectorize -def {bcast_fn_name}({bcast_fn_input_names}): - return np_random_func({bcast_fn_input_names}) - """ - bcast_fn = compile_function_src( - bcast_fn_src, bcast_fn_name, {**globals(), **bcast_fn_global_env} - ) - - random_fn_input_names = ", ".join( - ["rng", "size"] + [unique_names(i) for i in dist_params] - ) - - # Now, create a Numba JITable function that implements the `size` parameter +@numba_core_rv_funcify.register(ptr.BernoulliRV) +def numba_core_BernoulliRV(op, node): out_dtype = node.outputs[1].type.numpy_dtype - random_fn_global_env = { - bcast_fn_name: bcast_fn, - "out_dtype": out_dtype, - } - if size_len is not None: - size_dims = size_len - max(i.ndim for i in dist_params) - - random_fn_body = dedent( - f""" - size = to_fixed_tuple(size, size_len) + @numba_basic.numba_njit() + def random(rng, p): + return ( + direct_cast(0, out_dtype) + if p < rng.uniform() + else direct_cast(1, out_dtype) + ) - data = np.empty(size, dtype=out_dtype) - for i in np.ndindex(size[:size_dims]): - data[i] = {bcast_fn_name}({bcast_fn_input_names}) + return random - """ - ) - random_fn_global_env.update( - { - "np": np, - "to_fixed_tuple": numba_ndarray.to_fixed_tuple, - "size_len": size_len, - "size_dims": size_dims, - } - ) - else: - random_fn_body = f"""data = {bcast_fn_name}({bcast_fn_input_names})""" - sized_fn_src = dedent( - f""" -def {sized_fn_name}({random_fn_input_names}): -{indent(random_fn_body, " " * 4)} - return (rng, data) - """ - ) - random_fn = compile_function_src( - sized_fn_src, sized_fn_name, {**globals(), **random_fn_global_env} - ) - random_fn = numba_basic.numba_njit(random_fn) +@numba_core_rv_funcify.register(ptr.HalfNormalRV) +def numba_core_HalfNormalRV(op, node): + @numba_basic.numba_njit + def random_fn(rng, loc, scale): + return loc + scale * np.abs(rng.standard_normal()) return random_fn -@numba_funcify.register(ptr.UniformRV) -@numba_funcify.register(ptr.TriangularRV) -@numba_funcify.register(ptr.BetaRV) -@numba_funcify.register(ptr.NormalRV) -@numba_funcify.register(ptr.LogNormalRV) -@numba_funcify.register(ptr.GammaRV) -@numba_funcify.register(ptr.ParetoRV) -@numba_funcify.register(ptr.GumbelRV) -@numba_funcify.register(ptr.ExponentialRV) -@numba_funcify.register(ptr.WeibullRV) -@numba_funcify.register(ptr.LogisticRV) -@numba_funcify.register(ptr.VonMisesRV) -@numba_funcify.register(ptr.PoissonRV) -@numba_funcify.register(ptr.GeometricRV) -@numba_funcify.register(ptr.HyperGeometricRV) -@numba_funcify.register(ptr.WaldRV) 
-@numba_funcify.register(ptr.LaplaceRV) -@numba_funcify.register(ptr.BinomialRV) -@numba_funcify.register(ptr.MultinomialRV) -@numba_funcify.register(ptr.RandIntRV) # only the first two arguments are supported -@numba_funcify.register(ptr.PermutationRV) -def numba_funcify_RandomVariable(op, node, **kwargs): - name = op.name - np_random_func = getattr(np.random, name) - - return make_numba_random_fn(node, np_random_func) +@numba_core_rv_funcify.register(ptr.CauchyRV) +def numba_core_CauchyRV(op, node): + @numba_basic.numba_njit + def random(rng, loc, scale): + return (loc + rng.standard_cauchy()) / scale + return random -def create_numba_random_fn( - op: Op, - node: Apply, - scalar_fn: Callable[[str], str], - global_env: dict[str, Any] | None = None, -) -> Callable: - """Create a vectorized function from a callable that generates the ``str`` function body. - TODO: This could/should be generalized for other simple function - construction cases that need unique-ified symbol names. - """ - np_random_fn_name = f"pytensor_random_{get_name_for_object(op.name)}" +@numba_core_rv_funcify.register(ptr.ParetoRV) +def numba_core_ParetoRV(op, node): + @numba_basic.numba_njit + def random(rng, b, scale): + # Follows scipy implementation + U = rng.random() + return np.power(1 - U, -1 / b) * scale - if global_env: - np_global_env = global_env.copy() - else: - np_global_env = {} + return random - np_global_env["np"] = np - np_global_env["numba_vectorize"] = numba_basic.numba_vectorize - unique_names = unique_name_generator( - [np_random_fn_name, *np_global_env.keys(), "rng", "size"], - suffix_sep="_", - ) +@numba_core_rv_funcify.register(ptr.CategoricalRV) +def core_CategoricalRV(op, node): + @numba_basic.numba_njit + def random_fn(rng, p): + unif_sample = rng.uniform(0, 1) + return np.searchsorted(np.cumsum(p), unif_sample) - dist_params = op.dist_params(node) - np_names = [unique_names(i, force_unique=True) for i in dist_params] - np_input_names = ", ".join(np_names) - np_random_fn_src = f""" -@numba_vectorize -def {np_random_fn_name}({np_input_names}): -{scalar_fn(*np_names)} - """ - np_random_fn = compile_function_src( - np_random_fn_src, np_random_fn_name, {**globals(), **np_global_env} - ) + return random_fn - return make_numba_random_fn(node, np_random_fn) +@numba_core_rv_funcify.register(ptr.MvNormalRV) +def core_MvNormalRV(op, node): + @numba_basic.numba_njit + def random_fn(rng, mean, cov): + chol = np.linalg.cholesky(cov) + stdnorm = rng.normal(size=cov.shape[-1]) + return np.dot(chol, stdnorm) + mean -@numba_funcify.register(ptr.NegBinomialRV) -def numba_funcify_NegBinomialRV(op, node, **kwargs): - return make_numba_random_fn(node, np.random.negative_binomial) + random_fn.handles_out = True + return random_fn -@numba_funcify.register(ptr.CauchyRV) -def numba_funcify_CauchyRV(op, node, **kwargs): - def body_fn(loc, scale): - return f" return ({loc} + np.random.standard_cauchy()) / {scale}" +@numba_core_rv_funcify.register(ptr.DirichletRV) +def core_DirichletRV(op, node): + @numba_basic.numba_njit + def random_fn(rng, alpha): + y = np.empty_like(alpha) + for i in range(len(alpha)): + y[i] = rng.gamma(alpha[i], 1.0) + return y / y.sum() - return create_numba_random_fn(op, node, body_fn) + return random_fn -@numba_funcify.register(ptr.HalfNormalRV) -def numba_funcify_HalfNormalRV(op, node, **kwargs): - def body_fn(a, b): - return f" return {a} + {b} * abs(np.random.normal(0, 1))" +@numba_core_rv_funcify.register(ptr.GumbelRV) +def core_GumbelRV(op, node): + """Code adapted from Numpy Implementation 
- return create_numba_random_fn(op, node, body_fn) + https://github.com/numpy/numpy/blob/6f6be042c6208815b15b90ba87d04159bfa25fd3/numpy/random/src/distributions/distributions.c#L502-L511 + """ + @numba_basic.numba_njit + def random_fn(rng, loc, scale): + U = 1.0 - rng.random() + if U < 1.0: + return loc - scale * np.log(-np.log(U)) + else: + return random_fn(rng, loc, scale) -@numba_funcify.register(ptr.BernoulliRV) -def numba_funcify_BernoulliRV(op, node, **kwargs): - out_dtype = node.outputs[1].type.numpy_dtype + return random_fn - def body_fn(a): - return f""" - if {a} < np.random.uniform(0, 1): - return direct_cast(0, out_dtype) - else: - return direct_cast(1, out_dtype) - """ - - return create_numba_random_fn( - op, - node, - body_fn, - {"out_dtype": out_dtype, "direct_cast": numba_basic.direct_cast}, - ) +@numba_core_rv_funcify.register(ptr.VonMisesRV) +def core_VonMisesRV(op, node): + """Code adapted from Numpy Implementation -@numba_funcify.register(ptr.CategoricalRV) -def numba_funcify_CategoricalRV(op: ptr.CategoricalRV, node, **kwargs): - out_dtype = node.outputs[1].type.numpy_dtype - size_param = op.size_param(node) - size_len = ( - None - if isinstance(size_param.type, NoneTypeT) - else int(get_vector_length(size_param)) - ) - p_ndim = node.inputs[-1].ndim + https://github.com/numpy/numpy/blob/6f6be042c6208815b15b90ba87d04159bfa25fd3/numpy/random/src/distributions/distributions.c#L855-L925 + """ @numba_basic.numba_njit - def categorical_rv(rng, size, p): - if size_len is None: - size_tpl = p.shape[:-1] - else: - size_tpl = numba_ndarray.to_fixed_tuple(size, size_len) - p = np.broadcast_to(p, size_tpl + p.shape[-1:]) - - # Workaround https://github.com/numba/numba/issues/8975 - if size_len is None and p_ndim == 1: - unif_samples = np.asarray(np.random.uniform(0, 1)) + def random_fn(rng, mu, kappa): + if np.isnan(kappa): + return np.nan + if kappa < 1e-8: + # Use a uniform for very small values of kappa + return np.pi * (2 * rng.random() - 1) else: - unif_samples = np.random.uniform(0, 1, size_tpl) - - res = np.empty(size_tpl, dtype=out_dtype) - for idx in np.ndindex(*size_tpl): - res[idx] = np.searchsorted(np.cumsum(p[idx]), unif_samples[idx]) - - return (rng, res) + # with double precision rho is zero until 1.4e-8 + if kappa < 1e-5: + # second order taylor expansion around kappa = 0 + # precise until relatively large kappas as second order is 0 + s = 1.0 / kappa + kappa + else: + if kappa <= 1e6: + # Path for 1e-5 <= kappa <= 1e6 + r = 1 + np.sqrt(1 + 4 * kappa * kappa) + rho = (r - np.sqrt(2 * r)) / (2 * kappa) + s = (1 + rho * rho) / (2 * rho) + else: + # Fallback to wrapped normal distribution for kappa > 1e6 + result = mu + np.sqrt(1.0 / kappa) * rng.standard_normal() + # Ensure result is within bounds + if result < -np.pi: + result += 2 * np.pi + if result > np.pi: + result -= 2 * np.pi + return result + + while True: + U = rng.random() + Z = np.cos(np.pi * U) + W = (1 + s * Z) / (s + Z) + Y = kappa * (s - W) + V = rng.random() + # V == 0.0 is ok here since Y >= 0 always leads + # to accept, while Y < 0 always rejects + if (Y * (2 - Y) - V >= 0) or (np.log(Y / V) + 1 - Y >= 0): + break + + U = rng.random() + + result = np.arccos(W) + if U < 0.5: + result = -result + result += mu + neg = result < 0 + mod = np.abs(result) + mod = np.mod(mod + np.pi, 2 * np.pi) - np.pi + if neg: + mod *= -1 + + return mod - return categorical_rv + return random_fn -@numba_funcify.register(ptr.DirichletRV) -def numba_funcify_DirichletRV(op, node, **kwargs): - out_dtype = 
node.outputs[1].type.numpy_dtype - alphas_ndim = op.dist_params(node)[0].type.ndim - size_param = op.size_param(node) - size_len = ( - None - if isinstance(size_param.type, NoneTypeT) - else int(get_vector_length(size_param)) - ) +@numba_core_rv_funcify.register(ptr.ChoiceWithoutReplacement) +def core_ChoiceWithoutReplacement(op: ptr.ChoiceWithoutReplacement, node): + [core_shape_len_sig] = _parse_gufunc_signature(op.signature)[0][-1] + core_shape_len = int(core_shape_len_sig) + implicit_arange = op.ndims_params[0] == 0 - if alphas_ndim > 1: + if op.has_p_param: @numba_basic.numba_njit - def dirichlet_rv(rng, size, alphas): - if size_len is None: - samples_shape = alphas.shape + def random_fn(rng, a, p, core_shape): + # Adapted from Numpy: https://github.com/numpy/numpy/blob/2a9b9134270371b43223fc848b753fceab96b4a5/numpy/random/_generator.pyx#L922-L941 + size = np.prod(core_shape) + core_shape = numba_ndarray.to_fixed_tuple(core_shape, core_shape_len) + if implicit_arange: + pop_size = a else: - size_tpl = numba_ndarray.to_fixed_tuple(size, size_len) - samples_shape = size_tpl + alphas.shape[-1:] - - res = np.empty(samples_shape, dtype=out_dtype) - alphas_bcast = np.broadcast_to(alphas, samples_shape) - - for index in np.ndindex(*samples_shape[:-1]): - res[index] = np.random.dirichlet(alphas_bcast[index]) - - return (rng, res) + pop_size = a.shape[0] + + if size > pop_size: + raise ValueError( + "Cannot take a larger sample than population without replacement" + ) + if np.count_nonzero(p > 0) < size: + raise ValueError("Fewer non-zero entries in p than size") + + p = p.copy() + n_uniq = 0 + idx = np.zeros(core_shape, dtype=np.int64) + flat_idx = idx.ravel() + while n_uniq < size: + x = rng.random((size - n_uniq,)) + # Set the probabilities of items that have already been found to 0 + p[flat_idx[:n_uniq]] = 0 + # Take new (unique) categorical draws from the remaining probabilities + cdf = np.cumsum(p) + cdf /= cdf[-1] + new = np.searchsorted(cdf, x, side="right") + + # Numba doesn't support return_index in np.unique + # _, unique_indices = np.unique(new, return_index=True) + # unique_indices.sort() + new.sort() + unique_indices = [ + idx + for idx, prev_item in enumerate(new[:-1], 1) + if new[idx] != prev_item + ] + unique_indices = np.array([0] + unique_indices) # noqa: RUF005 + + new = new[unique_indices] + flat_idx[n_uniq : n_uniq + new.size] = new + n_uniq += new.size + + if implicit_arange: + return idx + else: + # Numba doesn't support advanced indexing, so we ravel index and reshape + return a[idx.ravel()].reshape(core_shape + a.shape[1:]) else: @numba_basic.numba_njit - def dirichlet_rv(rng, size, alphas): - if size_len is not None: - size = numba_ndarray.to_fixed_tuple(size, size_len) - return (rng, np.random.dirichlet(alphas, size)) - - return dirichlet_rv - - -@numba_funcify.register(ptr.ChoiceWithoutReplacement) -def numba_funcify_choice_without_replacement(op, node, **kwargs): - batch_ndim = op.batch_ndim(node) - if batch_ndim: - # The code isn't too hard to write, but Numba doesn't support a with ndim > 1, - # and I don't want to change the batched tests for this - # We'll just raise an error for now - raise NotImplementedError( - "ChoiceWithoutReplacement with batch_ndim not supported in Numba backend" - ) + def random_fn(rng, a, core_shape): + # Until Numba supports generator.choice we use a poor implementation + # that permutates the whole arange array and takes the first `size` elements + # This is widely inefficient when size << a.shape[0] + size = np.prod(core_shape) + 
core_shape = numba_ndarray.to_fixed_tuple(core_shape, core_shape_len) + idx = rng.permutation(size)[:size] - [core_shape_len] = node.inputs[-1].type.shape + # Numba doesn't support advanced indexing so index on the flat dimension and reshape + # idx = idx.reshape(core_shape) + # if implicit_arange: + # return idx + # else: + # return a[idx] - if op.has_p_param: + if implicit_arange: + return idx.reshape(core_shape) + else: + return a[idx].reshape(core_shape + a.shape[1:]) - @numba_basic.numba_njit - def choice_without_replacement_rv(rng, size, a, p, core_shape): - core_shape = numba_ndarray.to_fixed_tuple(core_shape, core_shape_len) - samples = np.random.choice(a, size=core_shape, replace=False, p=p) - return (rng, samples) - else: + return random_fn - @numba_basic.numba_njit - def choice_without_replacement_rv(rng, size, a, core_shape): - core_shape = numba_ndarray.to_fixed_tuple(core_shape, core_shape_len) - samples = np.random.choice(a, size=core_shape, replace=False) - return (rng, samples) - return choice_without_replacement_rv +@numba_funcify.register +def numba_funcify_RandomVariable_core(op: RandomVariable, **kwargs): + raise RuntimeError( + "It is necessary to replace RandomVariable with RandomVariableWithCoreShape. " + "This is done by the default rewrites during compilation." + ) -@numba_funcify.register(ptr.PermutationRV) -def numba_funcify_permutation(op: ptr.PermutationRV, node, **kwargs): - size_is_none = isinstance(op.size_param(node).type, NoneTypeT) - batch_ndim = op.batch_ndim(node) - x_batch_ndim = node.inputs[-1].type.ndim - op.ndims_params[0] +@numba_funcify.register +def numba_funcify_RandomVariable(op: RandomVariableWithCoreShape, node, **kwargs): + core_shape = node.inputs[0] - @numba_basic.numba_njit - def permutation_rv(rng, size, x): - if batch_ndim: - x_core_shape = x.shape[x_batch_ndim:] - if size_is_none: - size = x.shape[:batch_ndim] - else: - size = numba_ndarray.to_fixed_tuple(size, batch_ndim) - x = np.broadcast_to(x, size + x_core_shape) + [rv_node] = op.fgraph.apply_nodes + rv_op: RandomVariable = rv_node.op + rng_param = rv_op.rng_param(rv_node) + if isinstance(rng_param.type, RandomStateType): + raise TypeError("Numba does not support NumPy `RandomStateType`s") + size = rv_op.size_param(rv_node) + dist_params = rv_op.dist_params(rv_node) + size_len = None if isinstance(size.type, NoneTypeT) else get_vector_length(size) + core_shape_len = get_vector_length(core_shape) + inplace = rv_op.inplace - samples = np.empty(size + x_core_shape, dtype=x.dtype) - for index in np.ndindex(size): - samples[index] = np.random.permutation(x[index]) + core_rv_fn = numba_core_rv_funcify(rv_op, rv_node) + nin = 1 + len(dist_params) # rng + params + core_op_fn = store_core_outputs(core_rv_fn, nin=nin, nout=1) - else: - samples = np.random.permutation(x) + batch_ndim = rv_op.batch_ndim(rv_node) + + # numba doesn't support nested literals right now... 
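For an end-to-end picture of what the ``numba_funcify_RandomVariable`` dispatch being defined here compiles, a rough usage sketch with the standard PyTensor user API (not introduced by this patch, shown only for illustration):

    import numpy as np
    import pytensor
    import pytensor.tensor as pt

    rng = pytensor.shared(np.random.default_rng(123))
    x = pt.random.normal(0, 1, size=(3,), rng=rng)
    fn = pytensor.function([], x, mode="NUMBA")
    draws = fn()  # three draws produced by the jitted _vectorized call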
+ input_bc_patterns = encode_literals( + tuple(input_var.type.broadcastable[:batch_ndim] for input_var in dist_params) + ) + output_bc_patterns = encode_literals( + (rv_node.outputs[1].type.broadcastable[:batch_ndim],) + ) + output_dtypes = encode_literals((rv_node.default_output().type.dtype,)) + inplace_pattern = encode_literals(()) + + def random_wrapper(core_shape, rng, size, *dist_params): + if not inplace: + rng = copy(rng) + + draws = _vectorized( + core_op_fn, + input_bc_patterns, + output_bc_patterns, + output_dtypes, + inplace_pattern, + (rng,), + dist_params, + (numba_ndarray.to_fixed_tuple(core_shape, core_shape_len),), + None if size_len is None else numba_ndarray.to_fixed_tuple(size, size_len), + ) + return rng, draws + + def random(core_shape, rng, size, *dist_params): + pass - return (rng, samples) + @overload(random, jit_options=_jit_options) + def ov_random(core_shape, rng, size, *dist_params): + return random_wrapper - return permutation_rv + return random diff --git a/pytensor/link/numba/dispatch/scan.py b/pytensor/link/numba/dispatch/scan.py index a967675727..c60c4c546f 100644 --- a/pytensor/link/numba/dispatch/scan.py +++ b/pytensor/link/numba/dispatch/scan.py @@ -58,7 +58,11 @@ def numba_funcify_Scan(op, node, **kwargs): # TODO: Not sure this is the right place to do this, should we have a rewrite that # explicitly triggers the optimization of the inner graphs of Scan? # The C-code defers it to the make_thunk phase - rewriter = op.mode_instance.excluding(*NUMBA._optimizer.exclude).optimizer + rewriter = ( + op.mode_instance.including("numba") + .excluding(*NUMBA._optimizer.exclude) + .optimizer + ) rewriter(op.fgraph) scan_inner_func = numba_basic.numba_njit(numba_funcify(op.fgraph)) diff --git a/pytensor/tensor/blockwise.py b/pytensor/tensor/blockwise.py index 0511a4ce47..d9c634b6c9 100644 --- a/pytensor/tensor/blockwise.py +++ b/pytensor/tensor/blockwise.py @@ -5,6 +5,7 @@ import numpy as np from pytensor import config +from pytensor.compile.builders import OpFromGraph from pytensor.gradient import DisconnectedType from pytensor.graph.basic import Apply, Constant from pytensor.graph.null_type import NullType @@ -377,3 +378,7 @@ def vectorize_node_fallback(op: Op, node: Apply, *bached_inputs) -> Apply: _vectorize_node.register(Blockwise, _vectorize_not_needed) + + +class OpWithCoreShape(OpFromGraph): + """Generalizes an `Op` to include core shape as an additional input.""" diff --git a/pytensor/tensor/random/basic.py b/pytensor/tensor/random/basic.py index 1702fd1e99..6b74aff6f9 100644 --- a/pytensor/tensor/random/basic.py +++ b/pytensor/tensor/random/basic.py @@ -2082,10 +2082,7 @@ def choice(a, size=None, replace=True, p=None, rng=None): # This is equivalent to the numpy implementation: # https://github.com/numpy/numpy/blob/2a9b9134270371b43223fc848b753fceab96b4a5/numpy/random/_generator.pyx#L905-L914 if p is None: - if rng is not None and isinstance(rng.type, RandomStateType): - idxs = randint(0, a_size, size=size, rng=rng) - else: - idxs = integers(0, a_size, size=size, rng=rng) + idxs = integers(0, a_size, size=size, rng=rng) else: idxs = categorical(p, size=size, rng=rng) diff --git a/pytensor/tensor/random/op.py b/pytensor/tensor/random/op.py index 5a31448f68..685983830c 100644 --- a/pytensor/tensor/random/op.py +++ b/pytensor/tensor/random/op.py @@ -19,6 +19,7 @@ get_vector_length, infer_static_shape, ) +from pytensor.tensor.blockwise import OpWithCoreShape from pytensor.tensor.random.type import RandomGeneratorType, RandomStateType, RandomType from 
pytensor.tensor.random.utils import ( compute_batch_shape, @@ -476,3 +477,11 @@ def vectorize_random_variable( size = concatenate([new_size_dims, size]) return op.make_node(rng, size, *dist_params) + + +class RandomVariableWithCoreShape(OpWithCoreShape): + """Generalizes a random variable `Op` to include a core shape parameter.""" + + def __str__(self): + [rv_node] = self.fgraph.apply_nodes + return f"[{rv_node.op!s}]" diff --git a/pytensor/tensor/random/rewriting/__init__.py b/pytensor/tensor/random/rewriting/__init__.py index 2c32c16b33..ba1c8846aa 100644 --- a/pytensor/tensor/random/rewriting/__init__.py +++ b/pytensor/tensor/random/rewriting/__init__.py @@ -4,7 +4,8 @@ # isort: off -# Register JAX specializations +# Register Numba and JAX specializations +import pytensor.tensor.random.rewriting.numba import pytensor.tensor.random.rewriting.jax # isort: on diff --git a/pytensor/tensor/random/rewriting/numba.py b/pytensor/tensor/random/rewriting/numba.py new file mode 100644 index 0000000000..fe170f4718 --- /dev/null +++ b/pytensor/tensor/random/rewriting/numba.py @@ -0,0 +1,88 @@ +from pytensor.compile import optdb +from pytensor.graph import node_rewriter +from pytensor.graph.rewriting.basic import out2in +from pytensor.tensor import as_tensor, constant +from pytensor.tensor.random.op import RandomVariable, RandomVariableWithCoreShape +from pytensor.tensor.rewriting.shape import ShapeFeature + + +@node_rewriter([RandomVariable]) +def introduce_explicit_core_shape_rv(fgraph, node): + """Introduce the core shape of a RandomVariable. + + We wrap RandomVariable graphs into a RandomVariableWithCoreShape OpFromGraph + that has an extra "non-functional" input that represents the core shape of the random variable. + This core_shape is used by the numba backend to pre-allocate the output array. + + If available, the core shape is extracted from the shape feature of the graph, + which has a higher chance of having been simplified, optimized, and constant-folded. + If missing, we fall back to the op._supp_shape_from_params method. + + This rewrite is required for the numba backend implementation of RandomVariable. + + Example + ------- + + ..
code-block:: python + + import pytensor + import pytensor.tensor as pt + + x = pt.random.dirichlet(alphas=[1, 2, 3], size=(5,)) + pytensor.dprint(x, print_type=True) + # dirichlet_rv{"(a)->(a)"}.1 [id A] + # ├─ RNG() [id B] + # ├─ [5] [id C] + # └─ ExpandDims{axis=0} [id D] + # └─ [1 2 3] [id E] + + # After the rewrite, note the new core shape input [3] [id B] + fn = pytensor.function([], x, mode="NUMBA") + pytensor.dprint(fn.maker.fgraph) + # [dirichlet_rv{"(a)->(a)"}].1 [id A] 0 + # ├─ [3] [id B] + # ├─ RNG() [id C] + # ├─ [5] [id D] + # └─ [[1 2 3]] [id E] + # Inner graphs: + # [dirichlet_rv{"(a)->(a)"}] [id A] + # ← dirichlet_rv{"(a)->(a)"}.0 [id F] + # ├─ *1- [id G] + # ├─ *2- [id H] + # └─ *3- [id I] + # ← dirichlet_rv{"(a)->(a)"}.1 [id F] + # └─ ··· + """ + op: RandomVariable = node.op # type: ignore[annotation-unchecked] + + next_rng, rv = node.outputs + shape_feature: ShapeFeature | None = getattr(fgraph, "shape_feature", None) # type: ignore[annotation-unchecked] + if shape_feature: + core_shape = [ + shape_feature.get_shape(rv, -i - 1) for i in reversed(range(op.ndim_supp)) + ] + else: + core_shape = op._supp_shape_from_params(op.dist_params(node)) + + if len(core_shape) == 0: + core_shape = constant([], dtype="int64") + else: + core_shape = as_tensor(core_shape) + + return ( + RandomVariableWithCoreShape( + [core_shape, *node.inputs], + node.outputs, + destroy_map={0: [1]} if op.inplace else None, + ) + .make_node(core_shape, *node.inputs) + .outputs + ) + + +optdb.register( + introduce_explicit_core_shape_rv.__name__, + out2in(introduce_explicit_core_shape_rv), + "numba", + position=100, +) diff --git a/pytensor/tensor/rewriting/shape.py b/pytensor/tensor/rewriting/shape.py index 17022102e2..2ec1afa930 100644 --- a/pytensor/tensor/rewriting/shape.py +++ b/pytensor/tensor/rewriting/shape.py @@ -740,13 +740,13 @@ def apply(self, fgraph): # Register it after merge1 optimization at 0. We don't want to track # the shape of merged node. -pytensor.compile.mode.optdb.register( # type: ignore +pytensor.compile.mode.optdb.register( "ShapeOpt", ShapeOptimizer(), "fast_run", "fast_compile", position=0.1 ) # Not enabled by default for now. Some crossentropy opt use the # shape_feature. They are at step 2.01. uncanonicalize is at step # 3. After it goes to 48.5 that move to the gpu. So 10 seems reasonable. 
-pytensor.compile.mode.optdb.register("UnShapeOpt", UnShapeOptimizer(), position=10) # type: ignore +pytensor.compile.mode.optdb.register("UnShapeOpt", UnShapeOptimizer(), position=10) def local_reshape_chain(op): diff --git a/scripts/mypy-failing.txt b/scripts/mypy-failing.txt index 52fa8dc502..d73c19752b 100644 --- a/scripts/mypy-failing.txt +++ b/scripts/mypy-failing.txt @@ -8,7 +8,6 @@ pytensor/graph/rewriting/basic.py pytensor/ifelse.py pytensor/link/basic.py pytensor/link/numba/dispatch/elemwise.py -pytensor/link/numba/dispatch/random.py pytensor/link/numba/dispatch/scan.py pytensor/printing.py pytensor/raise_op.py diff --git a/tests/link/numba/test_basic.py b/tests/link/numba/test_basic.py index 078f952535..20ecdc3002 100644 --- a/tests/link/numba/test_basic.py +++ b/tests/link/numba/test_basic.py @@ -29,7 +29,6 @@ from pytensor.graph.type import Type from pytensor.ifelse import ifelse from pytensor.link.numba.dispatch import basic as numba_basic -from pytensor.link.numba.dispatch import numba_typify from pytensor.link.numba.linker import NumbaLinker from pytensor.raise_op import assert_op from pytensor.scalar.basic import ScalarOp, as_scalar @@ -120,7 +119,7 @@ def perform(self, node, inputs, outputs): my_multi_out.ufunc.nin = 2 my_multi_out.ufunc.nout = 2 opts = RewriteDatabaseQuery(include=[None], exclude=["cxx_only", "BlasOpt"]) -numba_mode = Mode(NumbaLinker(), opts) +numba_mode = Mode(NumbaLinker(), opts.including("numba")) py_mode = Mode("py", opts) rng = np.random.default_rng(42849) @@ -229,6 +228,7 @@ def compare_numba_and_py( numba_mode=numba_mode, py_mode=py_mode, updates=None, + eval_obj_mode: bool = True, ) -> tuple[Callable, Any]: """Function to compare python graph output and Numba compiled output for testing equality @@ -247,6 +247,8 @@ def compare_numba_and_py( provided uses `np.testing.assert_allclose`. updates Updates to be passed to `pytensor.function`. + eval_obj_mode : bool, default True + Whether to do an isolated call in object mode. 
Used for test coverage Returns ------- @@ -283,7 +285,8 @@ def assert_fn(x, y): numba_res = pytensor_numba_fn(*inputs) # Get some coverage - eval_python_only(fn_inputs, fn_outputs, inputs, mode=numba_mode) + if eval_obj_mode: + eval_python_only(fn_inputs, fn_outputs, inputs, mode=numba_mode) if len(fn_outputs) > 1: for j, p in zip(numba_res, py_res): @@ -359,26 +362,6 @@ def test_create_numba_signature(v, expected, force_scalar): assert res == expected -@pytest.mark.parametrize( - "input, wrapper_fn, check_fn", - [ - ( - np.random.RandomState(1), - numba_typify, - lambda x, y: np.all(x.get_state()[1] == y.get_state()[1]), - ) - ], -) -def test_box_unbox(input, wrapper_fn, check_fn): - input = wrapper_fn(input) - - pass_through = numba.njit(lambda x: x) - res = pass_through(input) - - assert isinstance(res, type(input)) - assert check_fn(res, input) - - @pytest.mark.parametrize( "x, indices", [ diff --git a/tests/link/numba/test_random.py b/tests/link/numba/test_random.py index 71639244b2..b966ed2870 100644 --- a/tests/link/numba/test_random.py +++ b/tests/link/numba/test_random.py @@ -8,13 +8,13 @@ import pytensor.tensor as pt import pytensor.tensor.random.basic as ptr from pytensor import shared +from pytensor.compile.builders import OpFromGraph from pytensor.compile.function import function from pytensor.compile.sharedvalue import SharedVariable from pytensor.graph.basic import Constant from pytensor.graph.fg import FunctionGraph from tests.link.numba.test_basic import ( compare_numba_and_py, - eval_python_only, numba_mode, set_test_value, ) @@ -28,37 +28,139 @@ rng = np.random.default_rng(42849) -@pytest.mark.xfail( - reason="Most RVs are not working correctly with explicit expand_dims" -) +@pytest.mark.parametrize("mu_shape", [(), (3,), (5, 1)]) +@pytest.mark.parametrize("sigma_shape", [(), (1,), (5, 3)]) +@pytest.mark.parametrize("size_type", (None, "constant", "mutable")) +def test_random_size(mu_shape, sigma_shape, size_type): + test_value_rng = np.random.default_rng(637) + mu = test_value_rng.normal(size=mu_shape) + sigma = np.exp(test_value_rng.normal(size=sigma_shape)) + + # For testing + rng = np.random.default_rng(123) + pt_rng = shared(rng) + if size_type is None: + size = None + pt_size = None + elif size_type == "constant": + size = (5, 3) + pt_size = pt.as_tensor(size, dtype="int64") + else: + size = (5, 3) + pt_size = shared(np.array(size, dtype="int64"), shape=(2,)) + + next_rng, x = pt.random.normal(mu, sigma, rng=pt_rng, size=pt_size).owner.outputs + fn = function([], x, updates={pt_rng: next_rng}, mode="NUMBA") + + res1 = fn() + np.testing.assert_allclose( + res1, + rng.normal(mu, sigma, size=size), + ) + + res2 = fn() + np.testing.assert_allclose( + res2, + rng.normal(mu, sigma, size=size), + ) + + pt_rng.set_value(np.random.default_rng(123)) + res3 = fn() + np.testing.assert_array_equal(res1, res3) + + if size_type == "mutable" and len(mu_shape) < 2 and len(sigma_shape) < 2: + pt_size.set_value(np.array((6, 3), dtype="int64")) + res4 = fn() + assert res4.shape == (6, 3) + + +def test_rng_copy(): + rng = shared(np.random.default_rng(123)) + x = pt.random.normal(rng=rng) + + fn = function([], x, mode="NUMBA") + np.testing.assert_array_equal(fn(), fn()) + + rng.type.values_eq(rng.get_value(), np.random.default_rng(123)) + + +def test_rng_non_default_update(): + rng = shared(np.random.default_rng(1)) + rng_new = shared(np.random.default_rng(2)) + + x = pt.random.normal(size=10, rng=rng) + fn = function([], x, updates={rng: rng_new}, mode=numba_mode) + + ref = 
np.random.default_rng(1).normal(size=10) + np.testing.assert_allclose(fn(), ref) + + ref = np.random.default_rng(2).normal(size=10) + np.testing.assert_allclose(fn(), ref) + np.testing.assert_allclose(fn(), ref) + + +def test_categorical_rv(): + """This is also a smoke test for a vector input scalar output RV""" + p = np.array( + [ + [ + [1.0, 0, 0, 0], + [0.0, 1.0, 0, 0], + [0.0, 0, 1.0, 0], + ], + [ + [0, 0, 0, 1.0], + [0, 0, 0, 1.0], + [0, 0, 0, 1.0], + ], + ] + ) + x = pt.random.categorical(p=p, size=None) + updates = {x.owner.inputs[0]: x.owner.outputs[0]} + fn = function([], x, updates=updates, mode="NUMBA") + res = fn() + assert np.all(np.argmax(p, axis=-1) == res) + + # Batch size + x = pt.random.categorical(p=p, size=(3, *p.shape[:-1])) + fn = function([], x, updates=updates, mode="NUMBA") + new_res = fn() + assert new_res.shape == (3, *res.shape) + for new_res_row in new_res: + assert np.all(new_res_row == res) + + +def test_multivariate_normal(): + """This is also a smoke test for a multivariate RV""" + rng = np.random.default_rng(123) + + x = pt.random.multivariate_normal( + mean=np.zeros((3, 2)), + cov=np.eye(2), + rng=shared(rng), + ) + + fn = function([], x, mode="NUMBA") + np.testing.assert_array_equal( + fn(), + rng.multivariate_normal(np.zeros(2), np.eye(2), size=(3,)), + ) + + @pytest.mark.parametrize( "rv_op, dist_args, size", [ ( - ptr.normal, + ptr.uniform, [ - set_test_value( - pt.dvector(), - np.array([1.0, 2.0], dtype=np.float64), - ), set_test_value( pt.dscalar(), np.array(1.0, dtype=np.float64), ), - ], - pt.as_tensor([3, 2]), - ), - ( - ptr.uniform, - [ set_test_value( pt.dvector(), np.array([1.0, 2.0], dtype=np.float64), ), - set_test_value( - pt.dscalar(), - np.array(1.0, dtype=np.float64), - ), ], pt.as_tensor([3, 2]), ), @@ -94,7 +196,7 @@ ], pt.as_tensor([3, 2]), ), - pytest.param( + ( ptr.pareto, [ set_test_value( @@ -107,7 +209,6 @@ ), ], pt.as_tensor([3, 2]), - marks=pytest.mark.xfail(reason="Not implemented"), ), ( ptr.exponential, @@ -153,7 +254,7 @@ ], pt.as_tensor([3, 2]), ), - ( + pytest.param( ptr.hypergeometric, [ set_test_value( @@ -170,6 +271,7 @@ ), ], pt.as_tensor([3, 2]), + marks=pytest.mark.xfail, # Not implemented ), ( ptr.wald, @@ -262,33 +364,70 @@ None, ), ( - ptr.randint, + ptr.beta, [ set_test_value( - pt.lscalar(), - np.array(0, dtype=np.int64), + pt.dvector(), + np.array([1.0, 2.0], dtype=np.float64), ), set_test_value( - pt.lscalar(), - np.array(5, dtype=np.int64), + pt.dscalar(), + np.array(1.0, dtype=np.float64), ), ], - pt.as_tensor([3, 2]), + (2,), ), - pytest.param( - ptr.multivariate_normal, + ( + ptr._gamma, + [ + set_test_value( + pt.dvector(), + np.array([1.0, 2.0], dtype=np.float64), + ), + set_test_value( + pt.dvector(), + np.array([0.5, 3.0], dtype=np.float64), + ), + ], + (2,), + ), + ( + ptr.chisquare, + [ + set_test_value( + pt.dvector(), + np.array([1.0, 2.0], dtype=np.float64), + ) + ], + (2,), + ), + ( + ptr.negative_binomial, [ set_test_value( - pt.dmatrix(), - np.array([[1, 2], [3, 4]], dtype=np.float64), + pt.lvector(), + np.array([100, 200], dtype=np.int64), ), set_test_value( - pt.tensor(dtype="float64", shape=(1, None, None)), - np.eye(2)[None, ...], + pt.dscalar(), + np.array(0.09, dtype=np.float64), ), ], - pt.as_tensor(tuple(set_test_value(pt.lscalar(), v) for v in [4, 3, 2])), - marks=pytest.mark.xfail(reason="Not implemented"), + (2,), + ), + ( + ptr.vonmises, + [ + set_test_value( + pt.dvector(), + np.array([-0.5, 0.5], dtype=np.float64), + ), + set_test_value( + pt.dscalar(), + np.array(1.0, 
dtype=np.float64), + ), + ], + (2,), ), ( ptr.permutation, @@ -312,17 +451,21 @@ [ set_test_value(pt.dmatrix(), np.eye(3, dtype=np.float64)), set_test_value( - pt.dvector(), np.array([0.5, 0.0, 0.5], dtype=np.float64) + pt.dvector(), np.array([0.25, 0.5, 0.25], dtype=np.float64) ), ], - (), + (pt.as_tensor([2, 3])), ), - ( + pytest.param( partial(ptr.choice, replace=False), [ set_test_value(pt.dvector(), np.arange(5, dtype=np.float64)), ], pt.as_tensor([2]), + marks=pytest.mark.xfail( + AssertionError, + reason="Not aligned with NumPy implementation", + ), ), pytest.param( partial(ptr.choice, replace=False), @@ -331,28 +474,23 @@ ], pt.as_tensor([2]), marks=pytest.mark.xfail( - raises=ValueError, - reason="Numba random.choice does not support >=1D `a`", + raises=AssertionError, + reason="Not aligned with NumPy implementation", ), ), - pytest.param( + ( # p must be passed by kwarg lambda a, p, size, rng: ptr.choice( a, p=p, size=size, replace=False, rng=rng ), [ set_test_value(pt.vector(), np.arange(5, dtype=np.float64)), - # Boring p, because the variable is not truly "aligned" set_test_value( pt.dvector(), np.array([0.5, 0.0, 0.25, 0.0, 0.25], dtype=np.float64), ), ], - (), - marks=pytest.mark.xfail( - raises=Exception, # numba.TypeError - reason="Numba random.choice does not support `p` parameter", - ), + pt.as_tensor([2]), ), pytest.param( # p must be passed by kwarg @@ -361,23 +499,31 @@ ), [ set_test_value(pt.dmatrix(), np.eye(3, dtype=np.float64)), - # Boring p, because the variable is not truly "aligned" set_test_value( - pt.dvector(), np.array([0.5, 0.0, 0.5], dtype=np.float64) + pt.dvector(), np.array([0.25, 0.5, 0.25], dtype=np.float64) ), ], (), - marks=pytest.mark.xfail( - raises=ValueError, - reason="Numba random.choice does not support >=1D `a`", + ), + pytest.param( + # p must be passed by kwarg + lambda a, p, size, rng: ptr.choice( + a, p=p, size=size, replace=False, rng=rng ), + [ + set_test_value(pt.dmatrix(), np.eye(3, dtype=np.float64)), + set_test_value( + pt.dvector(), np.array([0.25, 0.5, 0.25], dtype=np.float64) + ), + ], + (pt.as_tensor([2, 1])), ), ], ids=str, ) def test_aligned_RandomVariable(rv_op, dist_args, size): """Tests for Numba samplers that are one-to-one with PyTensor's/NumPy's samplers.""" - rng = shared(np.random.RandomState(29402)) + rng = shared(np.random.default_rng(29402)) g = rv_op(*dist_args, size=size, rng=rng) g_fg = FunctionGraph(outputs=[g]) @@ -388,45 +534,13 @@ def test_aligned_RandomVariable(rv_op, dist_args, size): for i in g_fg.inputs if not isinstance(i, SharedVariable | Constant) ], + eval_obj_mode=False, # No python impl ) -@pytest.mark.xfail(reason="Test is not working correctly with explicit expand_dims") @pytest.mark.parametrize( "rv_op, dist_args, base_size, cdf_name, params_conv", [ - ( - ptr.beta, - [ - set_test_value( - pt.dvector(), - np.array([1.0, 2.0], dtype=np.float64), - ), - set_test_value( - pt.dscalar(), - np.array(1.0, dtype=np.float64), - ), - ], - (2,), - "beta", - lambda *args: args, - ), - ( - ptr._gamma, - [ - set_test_value( - pt.dvector(), - np.array([1.0, 2.0], dtype=np.float64), - ), - set_test_value( - pt.dvector(), - np.array([0.5, 3.0], dtype=np.float64), - ), - ], - (2,), - "gamma", - lambda a, b: (a, 0.0, b), - ), ( ptr.cauchy, [ @@ -443,18 +557,6 @@ def test_aligned_RandomVariable(rv_op, dist_args, size): "cauchy", lambda *args: args, ), - ( - ptr.chisquare, - [ - set_test_value( - pt.dvector(), - np.array([1.0, 2.0], dtype=np.float64), - ) - ], - (2,), - "chi2", - lambda *args: args, - ), ( 
ptr.gumbel, [ @@ -471,49 +573,11 @@ def test_aligned_RandomVariable(rv_op, dist_args, size): "gumbel_r", lambda *args: args, ), - ( - ptr.negative_binomial, - [ - set_test_value( - pt.lvector(), - np.array([100, 200], dtype=np.int64), - ), - set_test_value( - pt.dscalar(), - np.array(0.09, dtype=np.float64), - ), - ], - (2,), - "nbinom", - lambda *args: args, - ), - pytest.param( - ptr.vonmises, - [ - set_test_value( - pt.dvector(), - np.array([-0.5, 0.5], dtype=np.float64), - ), - set_test_value( - pt.dscalar(), - np.array(1.0, dtype=np.float64), - ), - ], - (2,), - "vonmises_line", - lambda mu, kappa: (kappa, mu), - marks=pytest.mark.xfail( - reason=( - "Numba's parameterization of `vonmises` does not match NumPy's." - "See https://github.com/numba/numba/issues/7886" - ) - ), - ), ], ) def test_unaligned_RandomVariable(rv_op, dist_args, base_size, cdf_name, params_conv): """Tests for Numba samplers that are not one-to-one with PyTensor's/NumPy's samplers.""" - rng = shared(np.random.RandomState(29402)) + rng = shared(np.random.default_rng(29402)) g = rv_op(*dist_args, size=(2000, *base_size), rng=rng) g_fn = function(dist_args, g, mode=numba_mode) samples = g_fn( @@ -534,78 +598,6 @@ def test_unaligned_RandomVariable(rv_op, dist_args, base_size, cdf_name, params_ assert test_res.pvalue > 0.1 -@pytest.mark.parametrize( - "dist_args, size, cm", - [ - pytest.param( - [ - set_test_value( - pt.dvector(), - np.array([100000, 1, 1], dtype=np.float64), - ), - ], - None, - contextlib.suppress(), - ), - pytest.param( - [ - set_test_value( - pt.dmatrix(), - np.array( - [[100000, 1, 1], [1, 100000, 1], [1, 1, 100000]], - dtype=np.float64, - ), - ), - ], - (10, 3), - contextlib.suppress(), - ), - pytest.param( - [ - set_test_value( - pt.dmatrix(), - np.array( - [[100000, 1, 1]], - dtype=np.float64, - ), - ), - ], - (5, 4, 3), - contextlib.suppress(), - ), - pytest.param( - [ - set_test_value( - pt.dmatrix(), - np.array( - [[100000, 1, 1], [1, 100000, 1], [1, 1, 100000]], - dtype=np.float64, - ), - ), - ], - (10, 4), - pytest.raises( - ValueError, match="objects cannot be broadcast to a single shape" - ), - ), - ], -) -def test_CategoricalRV(dist_args, size, cm): - rng = shared(np.random.RandomState(29402)) - g = ptr.categorical(*dist_args, size=size, rng=rng) - g_fg = FunctionGraph(outputs=[g]) - - with cm: - compare_numba_and_py( - g_fg, - [ - i.tag.test_value - for i in g_fg.inputs - if not isinstance(i, SharedVariable | Constant) - ], - ) - - @pytest.mark.parametrize( "a, size, cm", [ @@ -637,21 +629,21 @@ def test_CategoricalRV(dist_args, size, cm): ), ), (10, 4), - pytest.raises(ValueError, match="operands could not be broadcast together"), + pytest.raises( + ValueError, + match="Vectorized input 0 has an incompatible shape in axis 1.", + ), ), ], ) def test_DirichletRV(a, size, cm): - rng = shared(np.random.RandomState(29402)) + rng = shared(np.random.default_rng(29402)) g = ptr.dirichlet(a, size=size, rng=rng) g_fn = function([a], g, mode=numba_mode) with cm: a_val = a.tag.test_value - # For coverage purposes only... 
- eval_python_only([a], [g], [a_val]) - all_samples = [] for i in range(1000): samples = g_fn(a_val) @@ -662,48 +654,34 @@ def test_DirichletRV(a, size, cm): assert np.allclose(res, exp_res, atol=1e-4) -@pytest.mark.xfail(reason="RandomState is not aligned with explicit expand_dims") -def test_RandomState_updates(): - rng = shared(np.random.RandomState(1)) - rng_new = shared(np.random.RandomState(2)) - - x = pt.random.normal(size=10, rng=rng) - res = function([], x, updates={rng: rng_new}, mode=numba_mode)() +def test_rv_inside_ofg(): + rng_np = np.random.default_rng(562) + rng = shared(rng_np) - ref = np.random.RandomState(2).normal(size=10) - assert np.allclose(res, ref) + rng_dummy = rng.type() + next_rng_dummy, rv_dummy = ptr.normal( + 0, 1, size=(3, 2), rng=rng_dummy + ).owner.outputs + out_dummy = rv_dummy.T + next_rng, out = OpFromGraph([rng_dummy], [next_rng_dummy, out_dummy])(rng) + fn = function([], out, updates={rng: next_rng}, mode=numba_mode) -def test_random_Generator(): - rng = shared(np.random.default_rng(29402)) - g = ptr.normal(rng=rng) - g_fg = FunctionGraph(outputs=[g]) + res1, res2 = fn(), fn() + assert res1.shape == (2, 3) - with pytest.raises(TypeError): - compare_numba_and_py( - g_fg, - [ - i.tag.test_value - for i in g_fg.inputs - if not isinstance(i, SharedVariable | Constant) - ], - ) + np.testing.assert_allclose(res1, rng_np.normal(0, 1, size=(3, 2)).T) + np.testing.assert_allclose(res2, rng_np.normal(0, 1, size=(3, 2)).T) @pytest.mark.parametrize( "batch_dims_tester", [ - pytest.param( - batched_unweighted_choice_without_replacement_tester, - marks=pytest.mark.xfail(raises=NotImplementedError), - ), - pytest.param( - batched_weighted_choice_without_replacement_tester, - marks=pytest.mark.xfail(raises=NotImplementedError), - ), + batched_unweighted_choice_without_replacement_tester, + batched_weighted_choice_without_replacement_tester, batched_permutation_tester, ], ) def test_unnatural_batched_dims(batch_dims_tester): """Tests for RVs that don't have natural batch dims in Numba API.""" - batch_dims_tester(mode="NUMBA", rng_ctor=np.random.RandomState) + batch_dims_tester(mode="NUMBA") diff --git a/tests/link/numba/test_scan.py b/tests/link/numba/test_scan.py index a72a40df53..be69f55f4e 100644 --- a/tests/link/numba/test_scan.py +++ b/tests/link/numba/test_scan.py @@ -77,15 +77,13 @@ ), # nit-sot, shared input/output ( - lambda: RandomStream(seed=1930, rng_ctor=np.random.RandomState).normal( - 0, 1, name="a" - ), + lambda: RandomStream(seed=1930).normal(0, 1, name="a"), [], [{}], [], 3, [], - [np.array([-1.63408257, 0.18046406, 2.43265803])], + [np.array([0.50100236, 2.16822932, 1.36326596])], lambda op: op.info.n_shared_outs > 0, ), # mit-sot (that's also a type of sit-sot) diff --git a/tests/tensor/random/test_basic.py b/tests/tensor/random/test_basic.py index 84385e5bc3..4a26aa69f4 100644 --- a/tests/tensor/random/test_basic.py +++ b/tests/tensor/random/test_basic.py @@ -1452,9 +1452,7 @@ def test_permutation_shape(): assert tuple(permutation(np.arange(5), size=(2, 3)).shape.eval()) == (2, 3, 5) -def batched_unweighted_choice_without_replacement_tester( - mode="FAST_RUN", rng_ctor=np.random.default_rng -): +def batched_unweighted_choice_without_replacement_tester(mode="FAST_RUN"): """Test unweighted choice without replacement with batched ndims. 
This has no corresponding in numpy, but is supported for consistency within the @@ -1462,7 +1460,7 @@ def batched_unweighted_choice_without_replacement_tester( It can be triggered by manual buiding the Op or during automatic vectorization. """ - rng = shared(rng_ctor()) + rng = shared(np.random.default_rng()) # Batched a implicit size rv_op = ChoiceWithoutReplacement( @@ -1499,9 +1497,7 @@ def batched_unweighted_choice_without_replacement_tester( assert np.all((draw >= i * 10) & (draw < (i + 1) * 10)) -def batched_weighted_choice_without_replacement_tester( - mode="FAST_RUN", rng_ctor=np.random.default_rng -): +def batched_weighted_choice_without_replacement_tester(mode="FAST_RUN"): """Test weighted choice without replacement with batched ndims. This has no corresponding in numpy, but is supported for consistency within the @@ -1509,7 +1505,7 @@ def batched_weighted_choice_without_replacement_tester( It can be triggered by manual buiding the Op or during automatic vectorization. """ - rng = shared(rng_ctor()) + rng = shared(np.random.default_rng()) rv_op = ChoiceWithoutReplacement( signature="(a0,a1),(a0),(1)->(s0,a1)", @@ -1574,7 +1570,7 @@ def batched_weighted_choice_without_replacement_tester( assert np.all((draw >= i * 10 + 2) & (draw < (i + 1) * 10)) -def batched_permutation_tester(mode="FAST_RUN", rng_ctor=np.random.default_rng): +def batched_permutation_tester(mode="FAST_RUN"): """Test permutation with batched ndims. This has no corresponding in numpy, but is supported for consistency within the @@ -1583,7 +1579,7 @@ def batched_permutation_tester(mode="FAST_RUN", rng_ctor=np.random.default_rng): It can be triggered by manual buiding the Op or during automatic vectorization. """ - rng = shared(rng_ctor()) + rng = shared(np.random.default_rng()) rv_op = PermutationRV(ndim_supp=2, ndims_params=[2], dtype="int64") x = np.arange(5 * 3 * 2).reshape((5, 3, 2)) From 69111c8d72a2c23abeff717e5d79ae938b2d40eb Mon Sep 17 00:00:00 2001 From: Ricardo Vieira Date: Thu, 9 May 2024 20:58:07 +0200 Subject: [PATCH 15/15] Remove RandomState type in remaining backends --- pytensor/link/jax/dispatch/random.py | 12 +--- pytensor/link/numba/dispatch/random.py | 4 -- pytensor/tensor/random/__init__.py | 2 +- pytensor/tensor/random/basic.py | 68 +----------------- pytensor/tensor/random/op.py | 17 ++--- pytensor/tensor/random/type.py | 91 ------------------------ pytensor/tensor/random/utils.py | 10 +-- pytensor/tensor/random/var.py | 20 +++--- tests/link/jax/test_random.py | 64 ++++++----------- tests/scan/test_basic.py | 5 +- tests/tensor/random/test_basic.py | 22 ------ tests/tensor/random/test_op.py | 7 +- tests/tensor/random/test_type.py | 98 +------------------------- tests/tensor/random/test_utils.py | 12 ++-- tests/tensor/random/test_var.py | 14 ++-- 15 files changed, 52 insertions(+), 394 deletions(-) diff --git a/pytensor/link/jax/dispatch/random.py b/pytensor/link/jax/dispatch/random.py index 98b59d22b3..9a89bf1406 100644 --- a/pytensor/link/jax/dispatch/random.py +++ b/pytensor/link/jax/dispatch/random.py @@ -2,7 +2,7 @@ import jax import numpy as np -from numpy.random import Generator, RandomState +from numpy.random import Generator from numpy.random.bit_generator import ( # type: ignore[attr-defined] _coerce_to_uint32_array, ) @@ -54,15 +54,6 @@ def assert_size_argument_jax_compatible(node): raise NotImplementedError(SIZE_NOT_COMPATIBLE) -@jax_typify.register(RandomState) -def jax_typify_RandomState(state, **kwargs): - state = state.get_state(legacy=False) - state["bit_generator"] = 
numpy_bit_gens[state["bit_generator"]] - # XXX: Is this a reasonable approach? - state["jax_state"] = state["state"]["key"][0:2] - return state - - @jax_typify.register(Generator) def jax_typify_Generator(rng, **kwargs): state = rng.__getstate__() @@ -214,7 +205,6 @@ def sample_fn(rng, size, dtype, p): return sample_fn -@jax_sample_fn.register(ptr.RandIntRV) @jax_sample_fn.register(ptr.IntegersRV) @jax_sample_fn.register(ptr.UniformRV) def jax_sample_fn_uniform(op, node): diff --git a/pytensor/link/numba/dispatch/random.py b/pytensor/link/numba/dispatch/random.py index c7e2f24546..ad4269e06f 100644 --- a/pytensor/link/numba/dispatch/random.py +++ b/pytensor/link/numba/dispatch/random.py @@ -25,7 +25,6 @@ ) from pytensor.tensor import get_vector_length from pytensor.tensor.random.op import RandomVariable, RandomVariableWithCoreShape -from pytensor.tensor.random.type import RandomStateType from pytensor.tensor.type_other import NoneTypeT from pytensor.tensor.utils import _parse_gufunc_signature @@ -348,9 +347,6 @@ def numba_funcify_RandomVariable(op: RandomVariableWithCoreShape, node, **kwargs [rv_node] = op.fgraph.apply_nodes rv_op: RandomVariable = rv_node.op - rng_param = rv_op.rng_param(rv_node) - if isinstance(rng_param.type, RandomStateType): - raise TypeError("Numba does not support NumPy `RandomStateType`s") size = rv_op.size_param(rv_node) dist_params = rv_op.dist_params(rv_node) size_len = None if isinstance(size.type, NoneTypeT) else get_vector_length(size) diff --git a/pytensor/tensor/random/__init__.py b/pytensor/tensor/random/__init__.py index a1cd42f789..78994fd40c 100644 --- a/pytensor/tensor/random/__init__.py +++ b/pytensor/tensor/random/__init__.py @@ -2,5 +2,5 @@ import pytensor.tensor.random.rewriting import pytensor.tensor.random.utils from pytensor.tensor.random.basic import * -from pytensor.tensor.random.op import RandomState, default_rng +from pytensor.tensor.random.op import default_rng from pytensor.tensor.random.utils import RandomStream diff --git a/pytensor/tensor/random/basic.py b/pytensor/tensor/random/basic.py index 6b74aff6f9..4a2c47b2af 100644 --- a/pytensor/tensor/random/basic.py +++ b/pytensor/tensor/random/basic.py @@ -9,15 +9,10 @@ from pytensor.tensor.basic import as_tensor_variable from pytensor.tensor.math import sqrt from pytensor.tensor.random.op import RandomVariable -from pytensor.tensor.random.type import RandomGeneratorType, RandomStateType from pytensor.tensor.random.utils import ( broadcast_params, normalize_size_param, ) -from pytensor.tensor.random.var import ( - RandomGeneratorSharedVariable, - RandomStateSharedVariable, -) try: @@ -645,7 +640,7 @@ def __call__( @classmethod def rng_fn_scipy( cls, - rng: np.random.Generator | np.random.RandomState, + rng: np.random.Generator, loc: np.ndarray | float, scale: np.ndarray | float, size: list[int] | int | None, @@ -1880,58 +1875,6 @@ def rng_fn(cls, rng, p, size): categorical = CategoricalRV() -class RandIntRV(RandomVariable): - r"""A discrete uniform random variable. - - Only available for `RandomStateType`. Use `integers` with `RandomGeneratorType`\s. - - """ - - name = "randint" - signature = "(),()->()" - dtype = "int64" - _print_name = ("randint", "\\operatorname{randint}") - - def __call__(self, low, high=None, size=None, **kwargs): - r"""Draw samples from a discrete uniform distribution. - - Signature - --------- - - `() -> ()` - - Parameters - ---------- - low - Lower boundary of the output interval. 
All values generated will - be greater than or equal to `low`, unless `high=None`, in which case - all values generated are greater than or equal to `0` and - smaller than `low` (exclusive). - high - Upper boundary of the output interval. All values generated - will be smaller than `high` (exclusive). - size - Sample shape. If the given size is `(m, n, k)`, then `m * n * k` - independent, identically distributed samples are - returned. Default is `None`, in which case a single - sample is returned. - - """ - if high is None: - low, high = 0, low - return super().__call__(low, high, size=size, **kwargs) - - def make_node(self, rng, *args, **kwargs): - if not isinstance( - getattr(rng, "type", None), RandomStateType | RandomStateSharedVariable - ): - raise TypeError("`randint` is only available for `RandomStateType`s") - return super().make_node(rng, *args, **kwargs) - - -randint = RandIntRV() - - class IntegersRV(RandomVariable): r"""A discrete uniform random variable. @@ -1971,14 +1914,6 @@ def __call__(self, low, high=None, size=None, **kwargs): low, high = 0, low return super().__call__(low, high, size=size, **kwargs) - def make_node(self, rng, *args, **kwargs): - if not isinstance( - getattr(rng, "type", None), - RandomGeneratorType | RandomGeneratorSharedVariable, - ): - raise TypeError("`integers` is only available for `RandomGeneratorType`s") - return super().make_node(rng, *args, **kwargs) - integers = IntegersRV() @@ -2201,7 +2136,6 @@ def permutation(x, **kwargs): "permutation", "choice", "integers", - "randint", "categorical", "multinomial", "betabinom", diff --git a/pytensor/tensor/random/op.py b/pytensor/tensor/random/op.py index 685983830c..ff310d3d4b 100644 --- a/pytensor/tensor/random/op.py +++ b/pytensor/tensor/random/op.py @@ -20,7 +20,7 @@ infer_static_shape, ) from pytensor.tensor.blockwise import OpWithCoreShape -from pytensor.tensor.random.type import RandomGeneratorType, RandomStateType, RandomType +from pytensor.tensor.random.type import RandomGeneratorType, RandomType from pytensor.tensor.random.utils import ( compute_batch_shape, explicit_expand_dims, @@ -324,9 +324,8 @@ def make_node(self, rng, size, *dist_params): Parameters ---------- - rng: RandomGeneratorType or RandomStateType - Existing PyTensor `Generator` or `RandomState` object to be used. Creates a - new one, if `None`. + rng: RandomGeneratorType + Existing PyTensor `Generator` object to be used. Creates a new one, if `None`. size: int or Sequence NumPy-like size parameter. 
dtype: str @@ -354,7 +353,7 @@ def make_node(self, rng, size, *dist_params): rng = pytensor.shared(np.random.default_rng()) elif not isinstance(rng.type, RandomType): raise TypeError( - "The type of rng should be an instance of either RandomGeneratorType or RandomStateType" + "The type of rng should be an instance of RandomGeneratorType " ) inferred_shape = self._infer_shape(size, dist_params) @@ -436,14 +435,6 @@ def perform(self, node, inputs, output_storage): output_storage[0][0] = getattr(np.random, self.random_constructor)(seed=seed) -class RandomStateConstructor(AbstractRNGConstructor): - random_type = RandomStateType() - random_constructor = "RandomState" - - -RandomState = RandomStateConstructor() - - class DefaultGeneratorMakerOp(AbstractRNGConstructor): random_type = RandomGeneratorType() random_constructor = "default_rng" diff --git a/pytensor/tensor/random/type.py b/pytensor/tensor/random/type.py index 527d3f3d6b..7f2a156271 100644 --- a/pytensor/tensor/random/type.py +++ b/pytensor/tensor/random/type.py @@ -31,97 +31,6 @@ def may_share_memory(a: T, b: T): return a._bit_generator is b._bit_generator # type: ignore[attr-defined] -class RandomStateType(RandomType[np.random.RandomState]): - r"""A Type wrapper for `numpy.random.RandomState`. - - The reason this exists (and `Generic` doesn't suffice) is that - `RandomState` objects that would appear to be equal do not compare equal - with the ``==`` operator. - - This `Type` also works with a ``dict`` derived from - `RandomState.get_state(legacy=False)`, unless the ``strict`` argument to `Type.filter` - is explicitly set to ``True``. - - """ - - def __repr__(self): - return "RandomStateType" - - def filter(self, data, strict: bool = False, allow_downcast=None): - """ - XXX: This doesn't convert `data` to the same type of underlying RNG type - as `self`. It really only checks that `data` is of the appropriate type - to be a valid `RandomStateType`. - - In other words, it serves as a `Type.is_valid_value` implementation, - but, because the default `Type.is_valid_value` depends on - `Type.filter`, we need to have it here to avoid surprising circular - dependencies in sub-classes. - """ - if isinstance(data, np.random.RandomState): - return data - - if not strict and isinstance(data, dict): - gen_keys = ["bit_generator", "gauss", "has_gauss", "state"] - state_keys = ["key", "pos"] - - for key in gen_keys: - if key not in data: - raise TypeError() - - for key in state_keys: - if key not in data["state"]: - raise TypeError() - - state_key = data["state"]["key"] - if state_key.shape == (624,) and state_key.dtype == np.uint32: - # TODO: Add an option to convert to a `RandomState` instance? - return data - - raise TypeError() - - @staticmethod - def values_eq(a, b): - sa = a if isinstance(a, dict) else a.get_state(legacy=False) - sb = b if isinstance(b, dict) else b.get_state(legacy=False) - - def _eq(sa, sb): - for key in sa: - if isinstance(sa[key], dict): - if not _eq(sa[key], sb[key]): - return False - elif isinstance(sa[key], np.ndarray): - if not np.array_equal(sa[key], sb[key]): - return False - else: - if sa[key] != sb[key]: - return False - - return True - - return _eq(sa, sb) - - def __eq__(self, other): - return type(self) == type(other) - - def __hash__(self): - return hash(type(self)) - - -# Register `RandomStateType`'s C code for `ViewOp`. 
-pytensor.compile.register_view_op_c_code( - RandomStateType, - """ - Py_XDECREF(%(oname)s); - %(oname)s = %(iname)s; - Py_XINCREF(%(oname)s); - """, - 1, -) - -random_state_type = RandomStateType() - - class RandomGeneratorType(RandomType[np.random.Generator]): r"""A Type wrapper for `numpy.random.Generator`. diff --git a/pytensor/tensor/random/utils.py b/pytensor/tensor/random/utils.py index 38329bbae7..9ddedb34b1 100644 --- a/pytensor/tensor/random/utils.py +++ b/pytensor/tensor/random/utils.py @@ -209,9 +209,7 @@ def __init__( self, seed: int | None = None, namespace: ModuleType | None = None, - rng_ctor: Literal[ - np.random.RandomState, np.random.Generator - ] = np.random.default_rng, + rng_ctor: Literal[np.random.Generator] = np.random.default_rng, ): if namespace is None: from pytensor.tensor.random import basic # pylint: disable=import-self @@ -223,12 +221,6 @@ def __init__( self.default_instance_seed = seed self.state_updates = [] self.gen_seedgen = np.random.SeedSequence(seed) - - if isinstance(rng_ctor, type) and issubclass(rng_ctor, np.random.RandomState): - # The legacy state does not accept `SeedSequence`s directly - def rng_ctor(seed): - return np.random.RandomState(np.random.MT19937(seed)) - self.rng_ctor = rng_ctor def __getattr__(self, obj): diff --git a/pytensor/tensor/random/var.py b/pytensor/tensor/random/var.py index c03b3046ab..09fef393e6 100644 --- a/pytensor/tensor/random/var.py +++ b/pytensor/tensor/random/var.py @@ -3,17 +3,12 @@ import numpy as np from pytensor.compile.sharedvalue import SharedVariable, shared_constructor -from pytensor.tensor.random.type import random_generator_type, random_state_type - - -class RandomStateSharedVariable(SharedVariable): - def __str__(self): - return self.name or f"RandomStateSharedVariable({self.container!r})" +from pytensor.tensor.random.type import random_generator_type class RandomGeneratorSharedVariable(SharedVariable): def __str__(self): - return self.name or f"RandomGeneratorSharedVariable({self.container!r})" + return self.name or f"RNG({self.container!r})" @shared_constructor.register(np.random.RandomState) @@ -23,11 +18,12 @@ def randomgen_constructor( ): r"""`SharedVariable` constructor for NumPy's `Generator` and/or `RandomState`.""" if isinstance(value, np.random.RandomState): - rng_sv_type = RandomStateSharedVariable - rng_type = random_state_type - elif isinstance(value, np.random.Generator): - rng_sv_type = RandomGeneratorSharedVariable - rng_type = random_generator_type + raise TypeError( + "`np.RandomState` is no longer supported in PyTensor. Use `np.random.Generator` instead." 
+ ) + + rng_sv_type = RandomGeneratorSharedVariable + rng_type = random_generator_type if not borrow: value = copy.deepcopy(value) diff --git a/tests/link/jax/test_random.py b/tests/link/jax/test_random.py index 5b3ca0c9c3..dfbc888e30 100644 --- a/tests/link/jax/test_random.py +++ b/tests/link/jax/test_random.py @@ -49,7 +49,7 @@ def test_random_RandomStream(): assert not np.array_equal(jax_res_1, jax_res_2) -@pytest.mark.parametrize("rng_ctor", (np.random.RandomState, np.random.default_rng)) +@pytest.mark.parametrize("rng_ctor", (np.random.default_rng,)) def test_random_updates(rng_ctor): original_value = rng_ctor(seed=98) rng = shared(original_value, name="original_rng", borrow=False) @@ -299,22 +299,6 @@ def test_replaced_shared_rng_storage_ordering_equality(): "poisson", lambda *args: args, ), - ( - ptr.randint, - [ - set_test_value( - pt.lscalar(), - np.array(0, dtype=np.int64), - ), - set_test_value( # high-value necessary since test on cdf - pt.lscalar(), - np.array(1000, dtype=np.int64), - ), - ], - (), - "randint", - lambda *args: args, - ), ( ptr.integers, [ @@ -489,11 +473,7 @@ def test_random_RandomVariable(rv_op, dist_params, base_size, cdf_name, params_c The parameters passed to the op. """ - if rv_op is ptr.integers: - # Integers only accepts Generator, not RandomState - rng = shared(np.random.default_rng(29402)) - else: - rng = shared(np.random.RandomState(29402)) + rng = shared(np.random.default_rng(29403)) g = rv_op(*dist_params, size=(10000, *base_size), rng=rng) g_fn = compile_random_function(dist_params, g, mode=jax_mode) samples = g_fn( @@ -545,7 +525,7 @@ def test_size_implied_by_broadcasted_parameters(rv_fn): @pytest.mark.parametrize("size", [(), (4,)]) def test_random_bernoulli(size): - rng = shared(np.random.RandomState(123)) + rng = shared(np.random.default_rng(123)) g = pt.random.bernoulli(0.5, size=(1000, *size), rng=rng) g_fn = compile_random_function([], g, mode=jax_mode) samples = g_fn() @@ -553,7 +533,7 @@ def test_random_bernoulli(size): def test_random_mvnormal(): - rng = shared(np.random.RandomState(123)) + rng = shared(np.random.default_rng(123)) mu = np.ones(4) cov = np.eye(4) @@ -571,7 +551,7 @@ def test_random_mvnormal(): ], ) def test_random_dirichlet(parameter, size): - rng = shared(np.random.RandomState(123)) + rng = shared(np.random.default_rng(123)) g = pt.random.dirichlet(parameter, size=(1000, *size), rng=rng) g_fn = compile_random_function([], g, mode=jax_mode) samples = g_fn() @@ -598,7 +578,7 @@ def test_random_choice(): assert np.all(samples % 2 == 1) # `replace=False` and `p is None` - rng = shared(np.random.RandomState(123)) + rng = shared(np.random.default_rng(123)) g = pt.random.choice(np.arange(100), replace=False, size=(2, 49), rng=rng) g_fn = compile_random_function([], g, mode=jax_mode) samples = g_fn() @@ -607,7 +587,7 @@ def test_random_choice(): assert len(np.unique(samples)) == 98 # `replace=False` and `p is not None` - rng = shared(np.random.RandomState(123)) + rng = shared(np.random.default_rng(123)) g = pt.random.choice( 8, p=np.array([0.25, 0, 0.25, 0, 0.25, 0, 0.25, 0]), @@ -625,7 +605,7 @@ def test_random_choice(): def test_random_categorical(): - rng = shared(np.random.RandomState(123)) + rng = shared(np.random.default_rng(123)) g = pt.random.categorical(0.25 * np.ones(4), size=(10000, 4), rng=rng) g_fn = compile_random_function([], g, mode=jax_mode) samples = g_fn() @@ -642,7 +622,7 @@ def test_random_categorical(): def test_random_permutation(): array = np.arange(4) - rng = shared(np.random.RandomState(123)) + rng = 
shared(np.random.default_rng(123)) g = pt.random.permutation(array, rng=rng) g_fn = compile_random_function([], g, mode=jax_mode) permuted = g_fn() @@ -664,7 +644,7 @@ def test_unnatural_batched_dims(batch_dims_tester): def test_random_geometric(): - rng = shared(np.random.RandomState(123)) + rng = shared(np.random.default_rng(123)) p = np.array([0.3, 0.7]) g = pt.random.geometric(p, size=(10_000, 2), rng=rng) g_fn = compile_random_function([], g, mode=jax_mode) @@ -674,7 +654,7 @@ def test_random_geometric(): def test_negative_binomial(): - rng = shared(np.random.RandomState(123)) + rng = shared(np.random.default_rng(123)) n = np.array([10, 40]) p = np.array([0.3, 0.7]) g = pt.random.negative_binomial(n, p, size=(10_000, 2), rng=rng) @@ -688,7 +668,7 @@ def test_negative_binomial(): @pytest.mark.skipif(not numpyro_available, reason="Binomial dispatch requires numpyro") def test_binomial(): - rng = shared(np.random.RandomState(123)) + rng = shared(np.random.default_rng(123)) n = np.array([10, 40]) p = np.array([0.3, 0.7]) g = pt.random.binomial(n, p, size=(10_000, 2), rng=rng) @@ -702,7 +682,7 @@ def test_binomial(): not numpyro_available, reason="BetaBinomial dispatch requires numpyro" ) def test_beta_binomial(): - rng = shared(np.random.RandomState(123)) + rng = shared(np.random.default_rng(123)) n = np.array([10, 40]) a = np.array([1.5, 13]) b = np.array([0.5, 9]) @@ -721,7 +701,7 @@ def test_beta_binomial(): not numpyro_available, reason="Multinomial dispatch requires numpyro" ) def test_multinomial(): - rng = shared(np.random.RandomState(123)) + rng = shared(np.random.default_rng(123)) n = np.array([10, 40]) p = np.array([[0.3, 0.7, 0.0], [0.1, 0.4, 0.5]]) g = pt.random.multinomial(n, p, size=(10_000, 2), rng=rng) @@ -737,7 +717,7 @@ def test_multinomial(): def test_vonmises_mu_outside_circle(): # Scipy implementation does not behave as PyTensor/NumPy for mu outside the unit circle # We test that the random draws from the JAX dispatch work as expected in these cases - rng = shared(np.random.RandomState(123)) + rng = shared(np.random.default_rng(123)) mu = np.array([-30, 40]) kappa = np.array([100, 10]) g = pt.random.vonmises(mu, kappa, size=(10_000, 2), rng=rng) @@ -781,7 +761,7 @@ def rng_fn(cls, rng, size): return 0 nonexistentrv = NonExistentRV() - rng = shared(np.random.RandomState(123)) + rng = shared(np.random.default_rng(123)) out = nonexistentrv(rng=rng) fgraph = FunctionGraph([out.owner.inputs[0]], [out], clone=False) @@ -816,7 +796,7 @@ def sample_fn(rng, size, dtype, *parameters): return sample_fn nonexistentrv = CustomRV() - rng = shared(np.random.RandomState(123)) + rng = shared(np.random.default_rng(123)) out = nonexistentrv(rng=rng) fgraph = FunctionGraph([out.owner.inputs[0]], [out], clone=False) with pytest.warns( @@ -836,7 +816,7 @@ def test_random_concrete_shape(): `size` parameter satisfies either of these criteria. """ - rng = shared(np.random.RandomState(123)) + rng = shared(np.random.default_rng(123)) x_pt = pt.dmatrix() out = pt.random.normal(0, 1, size=x_pt.shape, rng=rng) jax_fn = compile_random_function([x_pt], out, mode=jax_mode) @@ -844,7 +824,7 @@ def test_random_concrete_shape(): def test_random_concrete_shape_from_param(): - rng = shared(np.random.RandomState(123)) + rng = shared(np.random.default_rng(123)) x_pt = pt.dmatrix() out = pt.random.normal(x_pt, 1, rng=rng) jax_fn = compile_random_function([x_pt], out, mode=jax_mode) @@ -863,7 +843,7 @@ def test_random_concrete_shape_subtensor(): slight improvement over their API. 
""" - rng = shared(np.random.RandomState(123)) + rng = shared(np.random.default_rng(123)) x_pt = pt.dmatrix() out = pt.random.normal(0, 1, size=x_pt.shape[1], rng=rng) jax_fn = compile_random_function([x_pt], out, mode=jax_mode) @@ -879,7 +859,7 @@ def test_random_concrete_shape_subtensor_tuple(): `jax_size_parameter_as_tuple` rewrite. """ - rng = shared(np.random.RandomState(123)) + rng = shared(np.random.default_rng(123)) x_pt = pt.dmatrix() out = pt.random.normal(0, 1, size=(x_pt.shape[0],), rng=rng) jax_fn = compile_random_function([x_pt], out, mode=jax_mode) @@ -890,7 +870,7 @@ def test_random_concrete_shape_subtensor_tuple(): reason="`size_pt` should be specified as a static argument", strict=True ) def test_random_concrete_shape_graph_input(): - rng = shared(np.random.RandomState(123)) + rng = shared(np.random.default_rng(123)) size_pt = pt.scalar() out = pt.random.normal(0, 1, size=size_pt, rng=rng) jax_fn = compile_random_function([size_pt], out, mode=jax_mode) diff --git a/tests/scan/test_basic.py b/tests/scan/test_basic.py index 63b2a53b22..343f539274 100644 --- a/tests/scan/test_basic.py +++ b/tests/scan/test_basic.py @@ -244,10 +244,7 @@ def scan_nodes_from_fct(fct): class TestScan: @pytest.mark.parametrize( "rng_type", - [ - np.random.default_rng, - np.random.RandomState, - ], + [np.random.default_rng], ) def test_inner_graph_cloning(self, rng_type): r"""Scan should remove the updates-providing special properties on `RandomType`\s.""" diff --git a/tests/tensor/random/test_basic.py b/tests/tensor/random/test_basic.py index 4a26aa69f4..7d24a49228 100644 --- a/tests/tensor/random/test_basic.py +++ b/tests/tensor/random/test_basic.py @@ -51,7 +51,6 @@ pareto, permutation, poisson, - randint, rayleigh, standard_normal, t, @@ -1355,27 +1354,6 @@ def test_categorical_basic(): categorical.rng_fn(rng, p[None], size=(3,)) -def test_randint_samples(): - with pytest.raises(TypeError): - randint(10, rng=shared(np.random.default_rng())) - - rng = np.random.RandomState(2313) - compare_sample_values(randint, 10, None, rng=rng) - compare_sample_values(randint, 0, 1, rng=rng) - compare_sample_values(randint, 0, 1, size=[3], rng=rng) - compare_sample_values(randint, [0, 1, 2], 5, rng=rng) - compare_sample_values(randint, [0, 1, 2], 5, size=[3, 3], rng=rng) - compare_sample_values(randint, [0], [5], size=[1], rng=rng) - compare_sample_values(randint, pt.as_tensor_variable([-1]), [1], size=[1], rng=rng) - compare_sample_values( - randint, - pt.as_tensor_variable([-1]), - [1], - size=pt.as_tensor_variable([1]), - rng=rng, - ) - - def test_integers_samples(): with pytest.raises(TypeError): integers(10, rng=shared(np.random.RandomState())) diff --git a/tests/tensor/random/test_op.py b/tests/tensor/random/test_op.py index 35e5f49c28..2160cb83fe 100644 --- a/tests/tensor/random/test_op.py +++ b/tests/tensor/random/test_op.py @@ -8,7 +8,7 @@ from pytensor.tensor.math import eq from pytensor.tensor.random import normal from pytensor.tensor.random.basic import NormalRV -from pytensor.tensor.random.op import RandomState, RandomVariable, default_rng +from pytensor.tensor.random.op import RandomVariable, default_rng from pytensor.tensor.shape import specify_shape from pytensor.tensor.type import iscalar, tensor @@ -159,7 +159,6 @@ def test_RandomVariable_floatX(strict_test_value_flags): @pytest.mark.parametrize( "seed, maker_op, numpy_res", [ - (3, RandomState, np.random.RandomState(3)), (3, default_rng, np.random.default_rng(3)), ], ) @@ -174,10 +173,6 @@ def 
test_random_maker_ops_no_seed(strict_test_value_flags): # Testing the initialization when seed=None # Since internal states randomly generated, # we just check the output classes - z = function(inputs=[], outputs=[RandomState()])() - aes_res = z[0] - assert isinstance(aes_res, np.random.RandomState) - z = function(inputs=[], outputs=[default_rng()])() aes_res = z[0] assert isinstance(aes_res, np.random.Generator) diff --git a/tests/tensor/random/test_type.py b/tests/tensor/random/test_type.py index 53a3a6be8b..d289862347 100644 --- a/tests/tensor/random/test_type.py +++ b/tests/tensor/random/test_type.py @@ -7,9 +7,7 @@ from pytensor.compile.ops import ViewOp from pytensor.tensor.random.type import ( RandomGeneratorType, - RandomStateType, random_generator_type, - random_state_type, ) @@ -28,101 +26,9 @@ def test_view_op_c_code(): # rng_view, # mode=Mode(optimizer=None, linker=CLinker()), # ) - assert ViewOp.c_code_and_version[RandomStateType] assert ViewOp.c_code_and_version[RandomGeneratorType] -class TestRandomStateType: - def test_pickle(self): - rng_r = random_state_type() - - rng_pkl = pickle.dumps(rng_r) - rng_unpkl = pickle.loads(rng_pkl) - - assert rng_r != rng_unpkl - assert rng_r.type == rng_unpkl.type - assert hash(rng_r.type) == hash(rng_unpkl.type) - - def test_repr(self): - assert repr(random_state_type) == "RandomStateType" - - def test_filter(self): - rng_type = random_state_type - - rng = np.random.RandomState() - assert rng_type.filter(rng) is rng - - with pytest.raises(TypeError): - rng_type.filter(1) - - rng_dict = rng.get_state(legacy=False) - - assert rng_type.is_valid_value(rng_dict) is False - assert rng_type.is_valid_value(rng_dict, strict=False) - - rng_dict["state"] = {} - - assert rng_type.is_valid_value(rng_dict, strict=False) is False - - rng_dict = {} - assert rng_type.is_valid_value(rng_dict, strict=False) is False - - def test_values_eq(self): - rng_type = random_state_type - - rng_a = np.random.RandomState(12) - rng_b = np.random.RandomState(12) - rng_c = np.random.RandomState(123) - - bg = np.random.PCG64() - rng_d = np.random.RandomState(bg) - rng_e = np.random.RandomState(bg) - - bg_2 = np.random.Philox() - rng_f = np.random.RandomState(bg_2) - rng_g = np.random.RandomState(bg_2) - - assert rng_type.values_eq(rng_a, rng_b) - assert not rng_type.values_eq(rng_a, rng_c) - - assert not rng_type.values_eq(rng_a, rng_d) - assert not rng_type.values_eq(rng_d, rng_a) - - assert not rng_type.values_eq(rng_a, rng_d) - assert rng_type.values_eq(rng_d, rng_e) - - assert rng_type.values_eq(rng_f, rng_g) - assert not rng_type.values_eq(rng_g, rng_a) - assert not rng_type.values_eq(rng_e, rng_g) - - def test_may_share_memory(self): - bg1 = np.random.MT19937() - bg2 = np.random.MT19937() - - rng_a = np.random.RandomState(bg1) - rng_b = np.random.RandomState(bg2) - - rng_var_a = shared(rng_a, borrow=True) - rng_var_b = shared(rng_b, borrow=True) - - assert ( - random_state_type.may_share_memory( - rng_var_a.get_value(borrow=True), rng_var_b.get_value(borrow=True) - ) - is False - ) - - rng_c = np.random.RandomState(bg2) - rng_var_c = shared(rng_c, borrow=True) - - assert ( - random_state_type.may_share_memory( - rng_var_b.get_value(borrow=True), rng_var_c.get_value(borrow=True) - ) - is True - ) - - class TestRandomGeneratorType: def test_pickle(self): rng_r = random_generator_type() @@ -200,7 +106,7 @@ def test_may_share_memory(self): rng_var_b = shared(rng_b, borrow=True) assert ( - random_state_type.may_share_memory( + random_generator_type.may_share_memory( 
rng_var_a.get_value(borrow=True), rng_var_b.get_value(borrow=True) ) is False @@ -210,7 +116,7 @@ def test_may_share_memory(self): rng_var_c = shared(rng_c, borrow=True) assert ( - random_state_type.may_share_memory( + random_generator_type.may_share_memory( rng_var_b.get_value(borrow=True), rng_var_c.get_value(borrow=True) ) is True diff --git a/tests/tensor/random/test_utils.py b/tests/tensor/random/test_utils.py index 28ee2b94e0..3616b2fd24 100644 --- a/tests/tensor/random/test_utils.py +++ b/tests/tensor/random/test_utils.py @@ -101,7 +101,7 @@ def test_tutorial(self): assert np.all(g() == g()) assert np.all(abs(nearly_zeros()) < 1e-5) - @pytest.mark.parametrize("rng_ctor", [np.random.RandomState, np.random.default_rng]) + @pytest.mark.parametrize("rng_ctor", [np.random.default_rng]) def test_basics(self, rng_ctor): random = RandomStream(seed=utt.fetch_seed(), rng_ctor=rng_ctor) @@ -132,7 +132,7 @@ def test_basics(self, rng_ctor): assert np.allclose(fn_val0, numpy_val0) assert np.allclose(fn_val1, numpy_val1) - @pytest.mark.parametrize("rng_ctor", [np.random.RandomState, np.random.default_rng]) + @pytest.mark.parametrize("rng_ctor", [np.random.default_rng]) def test_seed(self, rng_ctor): init_seed = 234 random = RandomStream(init_seed, rng_ctor=rng_ctor) @@ -176,7 +176,7 @@ def test_seed(self, rng_ctor): assert random_state["bit_generator"] == ref_state["bit_generator"] assert random_state["state"] == ref_state["state"] - @pytest.mark.parametrize("rng_ctor", [np.random.RandomState, np.random.default_rng]) + @pytest.mark.parametrize("rng_ctor", [np.random.default_rng]) def test_uniform(self, rng_ctor): # Test that RandomStream.uniform generates the same results as numpy # Check over two calls to see if the random state is correctly updated. @@ -195,7 +195,7 @@ def test_uniform(self, rng_ctor): assert np.allclose(fn_val0, numpy_val0) assert np.allclose(fn_val1, numpy_val1) - @pytest.mark.parametrize("rng_ctor", [np.random.RandomState, np.random.default_rng]) + @pytest.mark.parametrize("rng_ctor", [np.random.default_rng]) def test_default_updates(self, rng_ctor): # Basic case: default_updates random_a = RandomStream(utt.fetch_seed(), rng_ctor=rng_ctor) @@ -244,7 +244,7 @@ def test_default_updates(self, rng_ctor): assert np.all(fn_e_val0 == fn_a_val0) assert np.all(fn_e_val1 == fn_e_val0) - @pytest.mark.parametrize("rng_ctor", [np.random.RandomState, np.random.default_rng]) + @pytest.mark.parametrize("rng_ctor", [np.random.default_rng]) def test_multiple_rng_aliasing(self, rng_ctor): # Test that when we have multiple random number generators, we do not alias # the state_updates member. `state_updates` can be useful when attempting to @@ -257,7 +257,7 @@ def test_multiple_rng_aliasing(self, rng_ctor): assert rng1.state_updates is not rng2.state_updates assert rng1.gen_seedgen is not rng2.gen_seedgen - @pytest.mark.parametrize("rng_ctor", [np.random.RandomState, np.random.default_rng]) + @pytest.mark.parametrize("rng_ctor", [np.random.default_rng]) def test_random_state_transfer(self, rng_ctor): # Test that random state can be transferred from one pytensor graph to another. 
diff --git a/tests/tensor/random/test_var.py b/tests/tensor/random/test_var.py index 47f5dcea48..279eb67f6c 100644 --- a/tests/tensor/random/test_var.py +++ b/tests/tensor/random/test_var.py @@ -4,9 +4,7 @@ from pytensor import shared -@pytest.mark.parametrize( - "rng", [np.random.RandomState(123), np.random.default_rng(123)] -) +@pytest.mark.parametrize("rng", [np.random.default_rng(123)]) def test_GeneratorSharedVariable(rng): s_rng_default = shared(rng) s_rng_True = shared(rng, borrow=True) @@ -32,9 +30,7 @@ def test_GeneratorSharedVariable(rng): assert v == v0 == v1 -@pytest.mark.parametrize( - "rng", [np.random.RandomState(123), np.random.default_rng(123)] -) +@pytest.mark.parametrize("rng", [np.random.default_rng(123)]) def test_get_value_borrow(rng): s_rng = shared(rng) @@ -55,9 +51,7 @@ def test_get_value_borrow(rng): assert r_.standard_normal() == r_F.standard_normal() -@pytest.mark.parametrize( - "rng", [np.random.RandomState(123), np.random.default_rng(123)] -) +@pytest.mark.parametrize("rng", [np.random.default_rng(123)]) def test_get_value_internal_type(rng): s_rng = shared(rng) @@ -81,7 +75,7 @@ def test_get_value_internal_type(rng): assert r_.standard_normal() == r_F.standard_normal() -@pytest.mark.parametrize("rng_ctor", [np.random.RandomState, np.random.default_rng]) +@pytest.mark.parametrize("rng_ctor", [np.random.default_rng]) def test_set_value_borrow(rng_ctor): s_rng = shared(rng_ctor(123))
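---

For orientation beyond the raw diffs above, here is a minimal usage sketch of the pattern the updated tests rely on after this series: a `np.random.Generator`-backed shared RNG, the explicit next-RNG update, and compilation in Numba mode. The seed, size, and variable names are illustrative only and not taken from the patches.

```python
import numpy as np
import pytensor
import pytensor.tensor as pt

# Only Generator-backed RNGs are accepted now; passing a
# np.random.RandomState value to `shared` raises a TypeError.
rng = pytensor.shared(np.random.default_rng(123), name="rng")

# A RandomVariable node has two outputs: the next RNG state and the draws.
# Threading next_rng back through `updates` advances the shared RNG state
# between calls, as done in the updated tests.
next_rng, x = pt.random.normal(0, 1, size=(3,), rng=rng).owner.outputs
fn = pytensor.function([], x, updates={rng: next_rng}, mode="NUMBA")

print(fn())  # first batch of draws
print(fn())  # different draws, since the RNG state was updated
```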