From a18ccda90a71cea46cf8dd8310c996a24764e388 Mon Sep 17 00:00:00 2001 From: Trey Wenger Date: Wed, 27 Sep 2023 11:07:22 -0500 Subject: [PATCH] Reparameterize GammaRV so beta is not inverted at each call Also fix wrong JAX implementation of Gamma and Pareto RVs --- pytensor/link/jax/dispatch/random.py | 11 +++---- pytensor/tensor/random/basic.py | 31 ++++++++++++------ pytensor/tensor/random/rewriting/jax.py | 10 +++--- tests/link/jax/test_random.py | 16 ++++++---- tests/link/numba/test_random.py | 10 ++++-- tests/tensor/random/test_basic.py | 42 ++++++++++++++++++------- 6 files changed, 79 insertions(+), 41 deletions(-) diff --git a/pytensor/link/jax/dispatch/random.py b/pytensor/link/jax/dispatch/random.py index 0981234db0..05d8957b6b 100644 --- a/pytensor/link/jax/dispatch/random.py +++ b/pytensor/link/jax/dispatch/random.py @@ -216,21 +216,20 @@ def sample_fn(rng, size, dtype, *parameters): @jax_sample_fn.register(aer.ParetoRV) @jax_sample_fn.register(aer.GammaRV) -def jax_sample_fn_shape_rate(op): - """JAX implementation of random variables in the shape-rate family. +def jax_sample_fn_shape_scale(op): + """JAX implementation of random variables in the shape-scale family. JAX only implements the standard version of random variables in the - shape-rate family. We thus need to rescale the results manually. + shape-scale family. We thus need to rescale the results manually. """ name = op.name jax_op = getattr(jax.random, name) - def sample_fn(rng, size, dtype, *parameters): + def sample_fn(rng, size, dtype, shape, scale): rng_key = rng["jax_state"] rng_key, sampling_key = jax.random.split(rng_key, 2) - (shape, rate) = parameters - sample = jax_op(sampling_key, shape, size, dtype) / rate + sample = jax_op(sampling_key, shape, size, dtype) * scale rng["jax_state"] = rng_key return (rng, sample) diff --git a/pytensor/tensor/random/basic.py b/pytensor/tensor/random/basic.py index 96c7913336..129bbf6af7 100644 --- a/pytensor/tensor/random/basic.py +++ b/pytensor/tensor/random/basic.py @@ -1,4 +1,5 @@ import abc +import warnings from typing import List, Optional, Union import numpy as np @@ -419,7 +420,7 @@ def __call__(self, mean=0.0, sigma=1.0, size=None, **kwargs): lognormal = LogNormalRV() -class GammaRV(ScipyRandomVariable): +class GammaRV(RandomVariable): r"""A gamma continuous random variable. The probability density function for `gamma` in terms of the shape parameter @@ -443,7 +444,7 @@ class GammaRV(ScipyRandomVariable): dtype = "floatX" _print_name = ("Gamma", "\\operatorname{Gamma}") - def __call__(self, shape, rate, size=None, **kwargs): + def __call__(self, shape, scale, size=None, **kwargs): r"""Draw samples from a gamma distribution. Signature @@ -455,8 +456,8 @@ def __call__(self, shape, rate, size=None, **kwargs): ---------- shape The shape :math:`\alpha` of the gamma distribution. Must be positive. - rate - The rate :math:`\beta` of the gamma distribution. Must be positive. + scale + The scale :math:`1/\beta` of the gamma distribution. Must be positive. size Sample shape. If the given size is, e.g. `(m, n, k)` then `m * n * k` independent, identically distributed random variables are @@ -464,14 +465,26 @@ def __call__(self, shape, rate, size=None, **kwargs): is returned. 
""" - return super().__call__(shape, 1.0 / rate, size=size, **kwargs) + return super().__call__(shape, scale, size=size, **kwargs) - @classmethod - def rng_fn_scipy(cls, rng, shape, scale, size): - return stats.gamma.rvs(shape, scale=scale, size=size, random_state=rng) +_gamma = GammaRV() + + +def gamma(shape, rate=None, scale=None, **kwargs): + # TODO: Remove helper when rate is deprecated + if rate is not None and scale is not None: + raise ValueError("Cannot specify both rate and scale") + elif rate is None and scale is None: + raise ValueError("Must specify scale") + elif rate is not None: + warnings.warn( + "Gamma rate argument is deprecated and will stop working, use scale instead", + FutureWarning, + ) + scale = 1.0 / rate -gamma = GammaRV() + return _gamma(shape, scale, **kwargs) class ChiSquareRV(RandomVariable): diff --git a/pytensor/tensor/random/rewriting/jax.py b/pytensor/tensor/random/rewriting/jax.py index 5ec2d30ea0..c96fad5242 100644 --- a/pytensor/tensor/random/rewriting/jax.py +++ b/pytensor/tensor/random/rewriting/jax.py @@ -15,9 +15,9 @@ LogNormalRV, NegBinomialRV, WaldRV, + _gamma, beta, binomial, - gamma, normal, poisson, uniform, @@ -92,7 +92,7 @@ def geometric_from_uniform(fgraph, node): @node_rewriter([NegBinomialRV]) def negative_binomial_from_gamma_poisson(fgraph, node): rng, *other_inputs, n, p = node.inputs - next_rng, g = gamma.make_node(rng, *other_inputs, n, p / (1 - p)).outputs + next_rng, g = _gamma.make_node(rng, *other_inputs, n, (1 - p) / p).outputs next_rng, p = poisson.make_node(next_rng, *other_inputs, g).outputs return [next_rng, p] @@ -100,21 +100,21 @@ def negative_binomial_from_gamma_poisson(fgraph, node): @node_rewriter([InvGammaRV]) def inverse_gamma_from_gamma(fgraph, node): *other_inputs, shape, scale = node.inputs - next_rng, g = gamma.make_node(*other_inputs, shape, scale).outputs + next_rng, g = _gamma.make_node(*other_inputs, shape, 1 / scale).outputs return [next_rng, reciprocal(g)] @node_rewriter([ChiSquareRV]) def chi_square_from_gamma(fgraph, node): *other_inputs, df = node.inputs - next_rng, g = gamma.make_node(*other_inputs, df / 2, 1 / 2).outputs + next_rng, g = _gamma.make_node(*other_inputs, df / 2, 2).outputs return [next_rng, g] @node_rewriter([GenGammaRV]) def generalized_gamma_from_gamma(fgraph, node): *other_inputs, alpha, p, lambd = node.inputs - next_rng, g = gamma.make_node(*other_inputs, alpha / p, ones_like(lambd)).outputs + next_rng, g = _gamma.make_node(*other_inputs, alpha / p, ones_like(lambd)).outputs g = (g ** reciprocal(p)) * lambd return [next_rng, cast(g, dtype=node.default_output().dtype)] diff --git a/tests/link/jax/test_random.py b/tests/link/jax/test_random.py index 54e4e09307..b43469c182 100644 --- a/tests/link/jax/test_random.py +++ b/tests/link/jax/test_random.py @@ -4,7 +4,7 @@ import pytensor import pytensor.tensor as at -import pytensor.tensor.random as aer +import pytensor.tensor.random.basic as aer from pytensor.compile.function import function from pytensor.compile.sharedvalue import SharedVariable, shared from pytensor.graph.basic import Constant @@ -140,15 +140,15 @@ def test_random_updates_input_storage_order(): lambda *args: (0, args[0]), ), ( - aer.gamma, + aer._gamma, [ set_test_value( at.dvector(), np.array([1.0, 2.0], dtype=np.float64), ), set_test_value( - at.dscalar(), - np.array(1.0, dtype=np.float64), + at.dvector(), + np.array([0.5, 3.0], dtype=np.float64), ), ], (2,), @@ -235,11 +235,15 @@ def test_random_updates_input_storage_order(): set_test_value( at.dvector(), np.array([1.0, 
2.0], dtype=np.float64), - ) + ), + set_test_value( + at.dvector(), + np.array([2.0, 10.0], dtype=np.float64), + ), ], (2,), "pareto", - lambda *args: args, + lambda shape, scale: (shape, 0.0, scale), ), ( aer.poisson, diff --git a/tests/link/numba/test_random.py b/tests/link/numba/test_random.py index f0ddf3525f..de6fa5ea6f 100644 --- a/tests/link/numba/test_random.py +++ b/tests/link/numba/test_random.py @@ -92,6 +92,10 @@ at.dvector(), np.array([1.0, 2.0], dtype=np.float64), ), + set_test_value( + at.dvector(), + np.array([2.0, 10.0], dtype=np.float64), + ), ], at.as_tensor([3, 2]), marks=pytest.mark.xfail(reason="Not implemented"), @@ -316,15 +320,15 @@ def test_aligned_RandomVariable(rv_op, dist_args, size): lambda *args: args, ), ( - aer.gamma, + aer._gamma, [ set_test_value( at.dvector(), np.array([1.0, 2.0], dtype=np.float64), ), set_test_value( - at.dscalar(), - np.array(1.0, dtype=np.float64), + at.dvector(), + np.array([0.5, 3.0], dtype=np.float64), ), ], (2,), diff --git a/tests/tensor/random/test_basic.py b/tests/tensor/random/test_basic.py index 4032b9a673..adc4177085 100644 --- a/tests/tensor/random/test_basic.py +++ b/tests/tensor/random/test_basic.py @@ -17,6 +17,7 @@ from pytensor.graph.replace import clone_replace from pytensor.graph.rewriting.db import RewriteDatabaseQuery from pytensor.tensor.random.basic import ( + _gamma, bernoulli, beta, betabinom, @@ -351,20 +352,31 @@ def test_lognormal_samples(mean, sigma, size): ], ) def test_gamma_samples(a, b, size): - gamma_test_fn = fixed_scipy_rvs("gamma") - - def test_fn(shape, rate, **kwargs): - return gamma_test_fn(shape, scale=1.0 / rate, **kwargs) - compare_sample_values( - gamma, + _gamma, a, b, size=size, - test_fn=test_fn, ) +def test_gamma_deprecation_wrapper_fn(): + out = gamma(5.0, scale=0.5, size=(5,)) + assert out.type.shape == (5,) + assert out.owner.inputs[-1].eval() == 0.5 + + with pytest.warns(FutureWarning, match="Gamma rate argument is deprecated"): + out = gamma([5.0, 10.0], 2.0, size=None) + assert out.type.shape == (2,) + assert out.owner.inputs[-1].eval() == 0.5 + + with pytest.raises(ValueError, match="Must specify scale"): + gamma(5.0) + + with pytest.raises(ValueError, match="Cannot specify both rate and scale"): + gamma(5.0, rate=2.0, scale=0.5) + + @pytest.mark.parametrize( "df, size", [ @@ -470,18 +482,24 @@ def test_vonmises_samples(mu, kappa, size): @pytest.mark.parametrize( - "alpha, size", + "alpha, scale, size", [ - (np.array(0.5, dtype=config.floatX), None), - (np.array(0.5, dtype=config.floatX), []), + (np.array(0.5, dtype=config.floatX), np.array(3.0, dtype=config.floatX), None), + (np.array(0.5, dtype=config.floatX), np.array(5.0, dtype=config.floatX), []), ( np.full((1, 2), 0.5, dtype=config.floatX), + np.array([0.5, 1.0], dtype=config.floatX), None, ), ], ) -def test_pareto_samples(alpha, size): - compare_sample_values(pareto, alpha, size=size, test_fn=fixed_scipy_rvs("pareto")) +def test_pareto_samples(alpha, scale, size): + pareto_test_fn = fixed_scipy_rvs("pareto") + + def test_fn(shape, scale, **kwargs): + return pareto_test_fn(shape, scale=scale, **kwargs) + + compare_sample_values(pareto, alpha, scale, size=size, test_fn=test_fn) def mvnormal_test_fn(mean=None, cov=None, size=None, random_state=None):
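
Usage sketch (editor's illustration, not part of the patch): the snippet below assumes a PyTensor build that already contains the changes above, and it mirrors the assertions in test_gamma_deprecation_wrapper_fn. It shows the new scale parameterization and the deprecated rate path that the thin `gamma` helper still accepts.

import warnings

from pytensor.tensor.random.basic import gamma

# Preferred spelling after this patch: pass the scale (1/beta) directly,
# so nothing has to be inverted on every call.
x = gamma(2.0, scale=0.5, size=(3,))
assert x.type.shape == (3,)
assert x.owner.inputs[-1].eval() == 0.5  # the scale is stored on the node as-is

# The old rate argument still works, but it now emits a FutureWarning and is
# converted to scale = 1 / rate once, inside the helper.
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    y = gamma(2.0, rate=2.0, size=(3,))
assert any(issubclass(w.category, FutureWarning) for w in caught)
assert y.owner.inputs[-1].eval() == 0.5  # 1 / rate

The point of the reparameterization is visible in the last input of each node: the underlying _gamma RV now receives the scale unchanged, while the single 1/rate inversion for legacy callers happens once in the `gamma` helper rather than on every __call__.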