
Do not invert Gamma rate and fix JAX implementation of Gamma and Pareto #460

Merged · 1 commit · Sep 29, 2023
11 changes: 5 additions & 6 deletions pytensor/link/jax/dispatch/random.py
@@ -216,21 +216,20 @@ def sample_fn(rng, size, dtype, *parameters):
 
 @jax_sample_fn.register(aer.ParetoRV)
 @jax_sample_fn.register(aer.GammaRV)
-def jax_sample_fn_shape_rate(op):
-    """JAX implementation of random variables in the shape-rate family.
+def jax_sample_fn_shape_scale(op):
+    """JAX implementation of random variables in the shape-scale family.
 
     JAX only implements the standard version of random variables in the
-    shape-rate family. We thus need to rescale the results manually.
+    shape-scale family. We thus need to rescale the results manually.
 
     """
     name = op.name
     jax_op = getattr(jax.random, name)
 
-    def sample_fn(rng, size, dtype, *parameters):
+    def sample_fn(rng, size, dtype, shape, scale):
         rng_key = rng["jax_state"]
         rng_key, sampling_key = jax.random.split(rng_key, 2)
-        (shape, rate) = parameters
-        sample = jax_op(sampling_key, shape, size, dtype) / rate
+        sample = jax_op(sampling_key, shape, size, dtype) * scale
         rng["jax_state"] = rng_key
         return (rng, sample)
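Not part of the diff: `jax.random.gamma` and `jax.random.pareto` only draw the standard, unit-scale variants, and the second parameter stored in the graph is a scale (see the `basic.py` changes below), so the dispatcher must multiply by it rather than divide as the old code did. A minimal standalone sketch of the same rescaling trick (values are illustrative):

```python
# Standalone sketch, not part of the PR: JAX's gamma sampler is unit-scale,
# so a shape-scale draw is the standard draw multiplied by the scale.
import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(0)
shape, scale = 2.0, 3.0

standard = jax.random.gamma(key, shape, (100_000,))
samples = standard * scale

print(jnp.mean(samples))  # roughly shape * scale = 6.0
```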
31 changes: 22 additions & 9 deletions pytensor/tensor/random/basic.py
@@ -1,4 +1,5 @@
 import abc
+import warnings
 from typing import List, Optional, Union
 
 import numpy as np
@@ -419,7 +420,7 @@ def __call__(self, mean=0.0, sigma=1.0, size=None, **kwargs):
 lognormal = LogNormalRV()
 
 
-class GammaRV(ScipyRandomVariable):
+class GammaRV(RandomVariable):
     r"""A gamma continuous random variable.
 
     The probability density function for `gamma` in terms of the shape parameter
@@ -443,7 +444,7 @@ class GammaRV(ScipyRandomVariable):
     dtype = "floatX"
     _print_name = ("Gamma", "\\operatorname{Gamma}")
 
-    def __call__(self, shape, rate, size=None, **kwargs):
+    def __call__(self, shape, scale, size=None, **kwargs):
         r"""Draw samples from a gamma distribution.
 
         Signature
@@ -455,23 +456,35 @@ def __call__(self, shape, rate, size=None, **kwargs):
         ----------
         shape
             The shape :math:`\alpha` of the gamma distribution. Must be positive.
-        rate
-            The rate :math:`\beta` of the gamma distribution. Must be positive.
+        scale
+            The scale :math:`1/\beta` of the gamma distribution. Must be positive.
         size
             Sample shape. If the given size is, e.g. `(m, n, k)` then `m * n * k`
             independent, identically distributed random variables are
             returned. Default is `None` in which case a single random variable
             is returned.
 
         """
-        return super().__call__(shape, 1.0 / rate, size=size, **kwargs)
-
-    @classmethod
-    def rng_fn_scipy(cls, rng, shape, scale, size):
-        return stats.gamma.rvs(shape, scale=scale, size=size, random_state=rng)
+        return super().__call__(shape, scale, size=size, **kwargs)
 
 
-gamma = GammaRV()
+_gamma = GammaRV()
+
+
+def gamma(shape, rate=None, scale=None, **kwargs):
+    # TODO: Remove helper when rate is deprecated
+    if rate is not None and scale is not None:
+        raise ValueError("Cannot specify both rate and scale")
+    elif rate is None and scale is None:
+        raise ValueError("Must specify scale")
+    elif rate is not None:
+        warnings.warn(
+            "Gamma rate argument is deprecated and will stop working, use scale instead",
+            FutureWarning,
+        )
+        scale = 1.0 / rate
+
+    return _gamma(shape, scale, **kwargs)
 
 
 class ChiSquareRV(RandomVariable):
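Not part of the diff: with `GammaRV` subclassing `RandomVariable` directly, the SciPy-based `rng_fn_scipy` override is dropped; NumPy's own `Generator.gamma` already takes shape and scale, so no inversion is needed. A usage sketch of the `gamma()` deprecation wrapper above (the alias is arbitrary; the import path is the module the wrapper is defined in):

```python
# Usage sketch for the gamma() wrapper above, not part of the diff.
import pytensor.tensor.random.basic as ptb

x = ptb.gamma(2.0, scale=3.0, size=(5,))  # preferred shape-scale call
y = ptb.gamma(2.0, rate=0.5, size=(5,))   # FutureWarning; converted to scale = 1 / 0.5 = 2.0

# ptb.gamma(2.0)                       # ValueError: "Must specify scale"
# ptb.gamma(2.0, rate=2.0, scale=0.5)  # ValueError: "Cannot specify both rate and scale"
```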
10 changes: 5 additions & 5 deletions pytensor/tensor/random/rewriting/jax.py
@@ -15,9 +15,9 @@
     LogNormalRV,
     NegBinomialRV,
     WaldRV,
+    _gamma,
     beta,
     binomial,
-    gamma,
     normal,
     poisson,
     uniform,
@@ -92,29 +92,29 @@ def geometric_from_uniform(fgraph, node):
 @node_rewriter([NegBinomialRV])
 def negative_binomial_from_gamma_poisson(fgraph, node):
     rng, *other_inputs, n, p = node.inputs
-    next_rng, g = gamma.make_node(rng, *other_inputs, n, p / (1 - p)).outputs
+    next_rng, g = _gamma.make_node(rng, *other_inputs, n, (1 - p) / p).outputs
     next_rng, p = poisson.make_node(next_rng, *other_inputs, g).outputs
     return [next_rng, p]
 
 
 @node_rewriter([InvGammaRV])
 def inverse_gamma_from_gamma(fgraph, node):
     *other_inputs, shape, scale = node.inputs
-    next_rng, g = gamma.make_node(*other_inputs, shape, scale).outputs
+    next_rng, g = _gamma.make_node(*other_inputs, shape, 1 / scale).outputs
     return [next_rng, reciprocal(g)]
 
 
 @node_rewriter([ChiSquareRV])
 def chi_square_from_gamma(fgraph, node):
     *other_inputs, df = node.inputs
-    next_rng, g = gamma.make_node(*other_inputs, df / 2, 1 / 2).outputs
+    next_rng, g = _gamma.make_node(*other_inputs, df / 2, 2).outputs
     return [next_rng, g]
 
 
 @node_rewriter([GenGammaRV])
 def generalized_gamma_from_gamma(fgraph, node):
     *other_inputs, alpha, p, lambd = node.inputs
-    next_rng, g = gamma.make_node(*other_inputs, alpha / p, ones_like(lambd)).outputs
+    next_rng, g = _gamma.make_node(*other_inputs, alpha / p, ones_like(lambd)).outputs
     g = (g ** reciprocal(p)) * lambd
     return [next_rng, cast(g, dtype=node.default_output().dtype)]
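Not part of the diff: the rewrites above express other distributions through the gamma, so each gamma argument flips from a rate to the equivalent scale. These are standard identities; a loose NumPy sanity check (seeds, sample sizes, and tolerances are illustrative):

```python
# Sanity checks, not part of the PR: the shape-scale parametrizations used by
# the rewrites above, verified against NumPy's samplers by comparing means.
import numpy as np

rng = np.random.default_rng(0)
n_draws = 200_000

# ChiSquare(df) == Gamma(df / 2, scale=2)
df = 5.0
chi2 = rng.chisquare(df, n_draws)
g = rng.gamma(df / 2, 2.0, n_draws)
assert abs(chi2.mean() - g.mean()) < 0.05  # both means are ~df

# NegativeBinomial(n, p) is a Poisson mixed over Gamma(n, scale=(1 - p) / p)
n, p = 3.0, 0.4
nb_mix = rng.poisson(rng.gamma(n, (1 - p) / p, n_draws))
nb = rng.negative_binomial(n, p, n_draws)
assert abs(nb_mix.mean() - nb.mean()) < 0.1  # both means are ~n * (1 - p) / p

# InverseGamma(shape, scale) is the reciprocal of Gamma(shape, scale=1 / scale)
shape, scale = 4.0, 2.0
inv_g = 1.0 / rng.gamma(shape, 1.0 / scale, n_draws)
assert abs(inv_g.mean() - scale / (shape - 1)) < 0.05  # mean = scale / (shape - 1)
```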
16 changes: 10 additions & 6 deletions tests/link/jax/test_random.py
@@ -4,7 +4,7 @@
 
 import pytensor
 import pytensor.tensor as at
-import pytensor.tensor.random as aer
+import pytensor.tensor.random.basic as aer
 from pytensor.compile.function import function
 from pytensor.compile.sharedvalue import SharedVariable, shared
 from pytensor.graph.basic import Constant
@@ -140,15 +140,15 @@ def test_random_updates_input_storage_order():
             lambda *args: (0, args[0]),
         ),
         (
-            aer.gamma,
+            aer._gamma,
             [
                 set_test_value(
                     at.dvector(),
                     np.array([1.0, 2.0], dtype=np.float64),
                 ),
                 set_test_value(
-                    at.dscalar(),
-                    np.array(1.0, dtype=np.float64),
+                    at.dvector(),
+                    np.array([0.5, 3.0], dtype=np.float64),
                 ),
             ],
             (2,),
@@ -235,11 +235,15 @@ def test_random_updates_input_storage_order():
                 set_test_value(
                     at.dvector(),
                     np.array([1.0, 2.0], dtype=np.float64),
-                )
+                ),
+                set_test_value(
+                    at.dvector(),
+                    np.array([2.0, 10.0], dtype=np.float64),
+                ),
             ],
             (2,),
             "pareto",
-            lambda *args: args,
+            lambda shape, scale: (shape, 0.0, scale),
         ),
         (
             aer.poisson,
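Not part of the diff: assuming the last tuple element converts the PyTensor parameters into the arguments of the named SciPy reference distribution (here `"pareto"`), the new lambda maps `(shape, scale)` to `scipy.stats.pareto`'s `(b, loc, scale)` signature with `loc=0`:

```python
# Not part of the PR: scipy.stats.pareto takes (b, loc, scale), which is why the
# PyTensor (shape, scale) pair is mapped to (shape, 0.0, scale) above.
from scipy import stats

b, scale = 2.0, 10.0
dist = stats.pareto(b, 0.0, scale)
print(dist.support())  # (10.0, inf): the support starts at the scale parameter
print(dist.mean())     # b * scale / (b - 1) = 20.0 for b=2, scale=10
```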
10 changes: 7 additions & 3 deletions tests/link/numba/test_random.py
@@ -92,6 +92,10 @@
                 at.dvector(),
                 np.array([1.0, 2.0], dtype=np.float64),
             ),
+            set_test_value(
+                at.dvector(),
+                np.array([2.0, 10.0], dtype=np.float64),
+            ),
         ],
         at.as_tensor([3, 2]),
         marks=pytest.mark.xfail(reason="Not implemented"),
@@ -316,15 +320,15 @@ def test_aligned_RandomVariable(rv_op, dist_args, size):
             lambda *args: args,
         ),
         (
-            aer.gamma,
+            aer._gamma,
             [
                 set_test_value(
                     at.dvector(),
                     np.array([1.0, 2.0], dtype=np.float64),
                 ),
                 set_test_value(
-                    at.dscalar(),
-                    np.array(1.0, dtype=np.float64),
+                    at.dvector(),
+                    np.array([0.5, 3.0], dtype=np.float64),
                 ),
             ],
             (2,),
42 changes: 30 additions & 12 deletions tests/tensor/random/test_basic.py
@@ -17,6 +17,7 @@
 from pytensor.graph.replace import clone_replace
 from pytensor.graph.rewriting.db import RewriteDatabaseQuery
 from pytensor.tensor.random.basic import (
+    _gamma,
     bernoulli,
     beta,
     betabinom,
@@ -351,20 +352,31 @@ def test_lognormal_samples(mean, sigma, size):
     ],
 )
 def test_gamma_samples(a, b, size):
-    gamma_test_fn = fixed_scipy_rvs("gamma")
-
-    def test_fn(shape, rate, **kwargs):
-        return gamma_test_fn(shape, scale=1.0 / rate, **kwargs)
-
     compare_sample_values(
-        gamma,
+        _gamma,
         a,
         b,
         size=size,
-        test_fn=test_fn,
     )
 
 
+def test_gamma_deprecation_wrapper_fn():
+    out = gamma(5.0, scale=0.5, size=(5,))
+    assert out.type.shape == (5,)
+    assert out.owner.inputs[-1].eval() == 0.5
+
+    with pytest.warns(FutureWarning, match="Gamma rate argument is deprecated"):
+        out = gamma([5.0, 10.0], 2.0, size=None)
+    assert out.type.shape == (2,)
+    assert out.owner.inputs[-1].eval() == 0.5
+
+    with pytest.raises(ValueError, match="Must specify scale"):
+        gamma(5.0)
+
+    with pytest.raises(ValueError, match="Cannot specify both rate and scale"):
+        gamma(5.0, rate=2.0, scale=0.5)
+
+
 @pytest.mark.parametrize(
     "df, size",
     [
@@ -470,18 +482,24 @@ def test_vonmises_samples(mu, kappa, size):
 
 
 @pytest.mark.parametrize(
-    "alpha, size",
+    "alpha, scale, size",
     [
-        (np.array(0.5, dtype=config.floatX), None),
-        (np.array(0.5, dtype=config.floatX), []),
+        (np.array(0.5, dtype=config.floatX), np.array(3.0, dtype=config.floatX), None),
+        (np.array(0.5, dtype=config.floatX), np.array(5.0, dtype=config.floatX), []),
         (
             np.full((1, 2), 0.5, dtype=config.floatX),
+            np.array([0.5, 1.0], dtype=config.floatX),
             None,
         ),
     ],
 )
-def test_pareto_samples(alpha, size):
-    compare_sample_values(pareto, alpha, size=size, test_fn=fixed_scipy_rvs("pareto"))
+def test_pareto_samples(alpha, scale, size):
+    pareto_test_fn = fixed_scipy_rvs("pareto")
+
+    def test_fn(shape, scale, **kwargs):
+        return pareto_test_fn(shape, scale=scale, **kwargs)
+
+    compare_sample_values(pareto, alpha, scale, size=size, test_fn=test_fn)
 
 
 def mvnormal_test_fn(mean=None, cov=None, size=None, random_state=None):
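Not part of the diff: the updated tests pass scale values straight through; both NumPy and SciPy use the shape-scale convention for the gamma, and SciPy's classical Pareto is scale-parametrized too. A quick NumPy illustration of both conventions (values are arbitrary):

```python
# Not part of the PR: the sampling conventions the updated tests rely on.
import numpy as np

rng = np.random.default_rng(0)

# Generator.gamma(shape, scale) is already shape-scale: mean = shape * scale.
draws = rng.gamma(2.0, 3.0, size=200_000)
print(draws.mean())  # ~6.0

# Generator.pareto draws the Lomax (shifted) form; the classical Pareto with
# scale m (support x >= m) is obtained as (rng.pareto(b, size) + 1) * m.
b, m = 2.0, 10.0
p_draws = (rng.pareto(b, 200_000) + 1) * m
print(p_draws.min())  # >= 10.0
```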