
Commit 3169197

tvwenger authored and ricardoV94 committed
Reparameterize GammaRV so beta is not inverted at each call
Also fix wrong JAX implementation of Gamma and Pareto RVs
1 parent 39aa123 commit 3169197
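
For context, Gamma(shape=α, rate=β) and Gamma(shape=α, scale=1/β) describe the same distribution; this commit switches GammaRV's second parameter from rate to scale so 1/β no longer has to be recomputed inside the graph on every call. A quick NumPy/SciPy sanity check of the equivalence (illustrative only, not part of the diff):

import numpy as np
from scipy import stats

alpha, beta = 2.0, 4.0  # shape and rate
rng = np.random.default_rng(0)

# NumPy's gamma is shape-scale parameterized, so pass scale = 1 / rate.
draws = rng.gamma(shape=alpha, scale=1.0 / beta, size=100_000)

# Both parameterizations give mean alpha / beta.
assert np.isclose(draws.mean(), alpha / beta, rtol=1e-2)
assert np.isclose(stats.gamma(alpha, scale=1.0 / beta).mean(), alpha / beta)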

File tree

6 files changed: +79, -41 lines

pytensor/link/jax/dispatch/random.py

Lines changed: 5 additions & 6 deletions
@@ -216,21 +216,20 @@ def sample_fn(rng, size, dtype, *parameters):
 
 @jax_sample_fn.register(aer.ParetoRV)
 @jax_sample_fn.register(aer.GammaRV)
-def jax_sample_fn_shape_rate(op):
-    """JAX implementation of random variables in the shape-rate family.
+def jax_sample_fn_shape_scale(op):
+    """JAX implementation of random variables in the shape-scale family.
 
     JAX only implements the standard version of random variables in the
-    shape-rate family. We thus need to rescale the results manually.
+    shape-scale family. We thus need to rescale the results manually.
 
     """
     name = op.name
     jax_op = getattr(jax.random, name)
 
-    def sample_fn(rng, size, dtype, *parameters):
+    def sample_fn(rng, size, dtype, shape, scale):
         rng_key = rng["jax_state"]
         rng_key, sampling_key = jax.random.split(rng_key, 2)
-        (shape, rate) = parameters
-        sample = jax_op(sampling_key, shape, size, dtype) / rate
+        sample = jax_op(sampling_key, shape, size, dtype) * scale
         rng["jax_state"] = rng_key
         return (rng, sample)
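
The fix rests on jax.random.gamma (and jax.random.pareto) drawing from the standard, scale-1 distribution, so a draw is rescaled by multiplying with the scale; dividing by the second parameter was only correct when that parameter was a rate. A minimal sketch of the rescaling idea (standalone usage, not the dispatch code itself):

import jax

key = jax.random.PRNGKey(0)
shape, scale = 2.0, 3.0

# jax.random.gamma samples Gamma(shape, scale=1); multiply to rescale.
draws = jax.random.gamma(key, shape, (100_000,)) * scale

# Mean of Gamma(shape, scale) is shape * scale (~6.0 here).
print(draws.mean())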

pytensor/tensor/random/basic.py

Lines changed: 22 additions & 9 deletions
@@ -1,4 +1,5 @@
 import abc
+import warnings
 from typing import List, Optional, Union
 
 import numpy as np
@@ -419,7 +420,7 @@ def __call__(self, mean=0.0, sigma=1.0, size=None, **kwargs):
 lognormal = LogNormalRV()
 
 
-class GammaRV(ScipyRandomVariable):
+class GammaRV(RandomVariable):
     r"""A gamma continuous random variable.
 
     The probability density function for `gamma` in terms of the shape parameter
@@ -443,7 +444,7 @@ class GammaRV(ScipyRandomVariable):
     dtype = "floatX"
     _print_name = ("Gamma", "\\operatorname{Gamma}")
 
-    def __call__(self, shape, rate, size=None, **kwargs):
+    def __call__(self, shape, scale, size=None, **kwargs):
         r"""Draw samples from a gamma distribution.
 
         Signature
@@ -455,23 +456,35 @@ def __call__(self, shape, rate, size=None, **kwargs):
         ----------
         shape
             The shape :math:`\alpha` of the gamma distribution. Must be positive.
-        rate
-            The rate :math:`\beta` of the gamma distribution. Must be positive.
+        scale
+            The scale :math:`1/\beta` of the gamma distribution. Must be positive.
         size
             Sample shape. If the given size is, e.g. `(m, n, k)` then `m * n * k`
            independent, identically distributed random variables are
            returned. Default is `None` in which case a single random variable
            is returned.
 
        """
-        return super().__call__(shape, 1.0 / rate, size=size, **kwargs)
+        return super().__call__(shape, scale, size=size, **kwargs)
 
-    @classmethod
-    def rng_fn_scipy(cls, rng, shape, scale, size):
-        return stats.gamma.rvs(shape, scale=scale, size=size, random_state=rng)
 
+_gamma = GammaRV()
+
+
+def gamma(shape, rate=None, scale=None, **kwargs):
+    # TODO: Remove helper when rate is deprecated
+    if rate is not None and scale is not None:
+        raise ValueError("Cannot specify both rate and scale")
+    elif rate is None and scale is None:
+        raise ValueError("Must specify scale")
+    elif rate is not None:
+        warnings.warn(
+            "Gamma rate argument is deprecated and will stop working, use scale instead",
+            FutureWarning,
+        )
+        scale = 1.0 / rate
 
-gamma = GammaRV()
+    return _gamma(shape, scale, **kwargs)
 
 
 class ChiSquareRV(RandomVariable):
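
With the module-level wrapper, existing rate-based callers keep working but receive a FutureWarning. A hypothetical usage sketch (assuming the import path shown in this diff):

from pytensor.tensor.random.basic import gamma

x = gamma(2.0, scale=3.0, size=(4,))  # preferred spelling
y = gamma(2.0, rate=0.5, size=(4,))   # warns; internally converted to scale = 2.0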

pytensor/tensor/random/rewriting/jax.py

Lines changed: 5 additions & 5 deletions
@@ -15,9 +15,9 @@
     LogNormalRV,
     NegBinomialRV,
     WaldRV,
+    _gamma,
     beta,
     binomial,
-    gamma,
     normal,
     poisson,
     uniform,
@@ -92,29 +92,29 @@ def geometric_from_uniform(fgraph, node):
 @node_rewriter([NegBinomialRV])
 def negative_binomial_from_gamma_poisson(fgraph, node):
     rng, *other_inputs, n, p = node.inputs
-    next_rng, g = gamma.make_node(rng, *other_inputs, n, p / (1 - p)).outputs
+    next_rng, g = _gamma.make_node(rng, *other_inputs, n, (1 - p) / p).outputs
     next_rng, p = poisson.make_node(next_rng, *other_inputs, g).outputs
     return [next_rng, p]
 
 
 @node_rewriter([InvGammaRV])
 def inverse_gamma_from_gamma(fgraph, node):
     *other_inputs, shape, scale = node.inputs
-    next_rng, g = gamma.make_node(*other_inputs, shape, scale).outputs
+    next_rng, g = _gamma.make_node(*other_inputs, shape, 1 / scale).outputs
     return [next_rng, reciprocal(g)]
 
 
 @node_rewriter([ChiSquareRV])
 def chi_square_from_gamma(fgraph, node):
     *other_inputs, df = node.inputs
-    next_rng, g = gamma.make_node(*other_inputs, df / 2, 1 / 2).outputs
+    next_rng, g = _gamma.make_node(*other_inputs, df / 2, 2).outputs
     return [next_rng, g]
 
 
 @node_rewriter([GenGammaRV])
 def generalized_gamma_from_gamma(fgraph, node):
     *other_inputs, alpha, p, lambd = node.inputs
-    next_rng, g = gamma.make_node(*other_inputs, alpha / p, ones_like(lambd)).outputs
+    next_rng, g = _gamma.make_node(*other_inputs, alpha / p, ones_like(lambd)).outputs
     g = (g ** reciprocal(p)) * lambd
     return [next_rng, cast(g, dtype=node.default_output().dtype)]
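
The constants in these rewrites change because _gamma now takes a scale: rate p/(1 - p) becomes scale (1 - p)/p, rate `scale` becomes scale 1/scale, and rate 1/2 becomes scale 2. A SciPy check of the chi-square and inverse-gamma identities (illustrative, not repo code):

import numpy as np
from scipy import stats

# chi2(df) has the same density as Gamma(df / 2, scale=2).
df = 5.0
x = np.linspace(0.1, 20.0, 50)
assert np.allclose(stats.chi2(df).pdf(x), stats.gamma(df / 2, scale=2.0).pdf(x))

# If G ~ Gamma(shape, scale=1/s), then 1/G ~ InvGamma(shape, scale=s).
shape, s = 3.0, 1.5
g = np.random.default_rng(0).gamma(shape, 1.0 / s, size=200_000)
assert np.isclose((1.0 / g).mean(), stats.invgamma(shape, scale=s).mean(), rtol=1e-2)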

tests/link/jax/test_random.py

Lines changed: 10 additions & 6 deletions
@@ -4,7 +4,7 @@
 
 import pytensor
 import pytensor.tensor as at
-import pytensor.tensor.random as aer
+import pytensor.tensor.random.basic as aer
 from pytensor.compile.function import function
 from pytensor.compile.sharedvalue import SharedVariable, shared
 from pytensor.graph.basic import Constant
@@ -140,15 +140,15 @@ def test_random_updates_input_storage_order():
             lambda *args: (0, args[0]),
         ),
         (
-            aer.gamma,
+            aer._gamma,
             [
                 set_test_value(
                     at.dvector(),
                     np.array([1.0, 2.0], dtype=np.float64),
                 ),
                 set_test_value(
-                    at.dscalar(),
-                    np.array(1.0, dtype=np.float64),
+                    at.dvector(),
+                    np.array([0.5, 3.0], dtype=np.float64),
                 ),
             ],
             (2,),
@@ -235,11 +235,15 @@ def test_random_updates_input_storage_order():
                 set_test_value(
                     at.dvector(),
                     np.array([1.0, 2.0], dtype=np.float64),
-                )
+                ),
+                set_test_value(
+                    at.dvector(),
+                    np.array([2.0, 10.0], dtype=np.float64),
+                ),
             ],
             (2,),
             "pareto",
-            lambda *args: args,
+            lambda shape, scale: (shape, 0.0, scale),
        ),
        (
            aer.poisson,
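
The updated lambda appears to map PyTensor's (shape, scale) pair onto SciPy's positional (b, loc, scale) arguments for stats.pareto, pinning loc to 0.0; a hedged illustration of that argument order:

from scipy import stats

# scipy.stats.pareto.rvs(b, loc=0, scale=1, ...): the test tuple
# (shape, 0.0, scale) fills these three positions.
print(stats.pareto.rvs(2.0, 0.0, 10.0, size=3, random_state=0))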

tests/link/numba/test_random.py

Lines changed: 7 additions & 3 deletions
@@ -92,6 +92,10 @@
                 at.dvector(),
                 np.array([1.0, 2.0], dtype=np.float64),
             ),
+            set_test_value(
+                at.dvector(),
+                np.array([2.0, 10.0], dtype=np.float64),
+            ),
         ],
         at.as_tensor([3, 2]),
         marks=pytest.mark.xfail(reason="Not implemented"),
@@ -316,15 +320,15 @@ def test_aligned_RandomVariable(rv_op, dist_args, size):
         lambda *args: args,
     ),
     (
-        aer.gamma,
+        aer._gamma,
        [
            set_test_value(
                at.dvector(),
                np.array([1.0, 2.0], dtype=np.float64),
            ),
            set_test_value(
-                at.dscalar(),
-                np.array(1.0, dtype=np.float64),
+                at.dvector(),
+                np.array([0.5, 3.0], dtype=np.float64),
            ),
        ],
        (2,),

tests/tensor/random/test_basic.py

Lines changed: 30 additions & 12 deletions
@@ -17,6 +17,7 @@
 from pytensor.graph.replace import clone_replace
 from pytensor.graph.rewriting.db import RewriteDatabaseQuery
 from pytensor.tensor.random.basic import (
+    _gamma,
     bernoulli,
     beta,
     betabinom,
@@ -351,20 +352,31 @@ def test_lognormal_samples(mean, sigma, size):
     ],
 )
 def test_gamma_samples(a, b, size):
-    gamma_test_fn = fixed_scipy_rvs("gamma")
-
-    def test_fn(shape, rate, **kwargs):
-        return gamma_test_fn(shape, scale=1.0 / rate, **kwargs)
-
     compare_sample_values(
-        gamma,
+        _gamma,
         a,
         b,
         size=size,
-        test_fn=test_fn,
     )
 
 
+def test_gamma_deprecation_wrapper_fn():
+    out = gamma(5.0, scale=0.5, size=(5,))
+    assert out.type.shape == (5,)
+    assert out.owner.inputs[-1].eval() == 0.5
+
+    with pytest.warns(FutureWarning, match="Gamma rate argument is deprecated"):
+        out = gamma([5.0, 10.0], 2.0, size=None)
+    assert out.type.shape == (2,)
+    assert out.owner.inputs[-1].eval() == 0.5
+
+    with pytest.raises(ValueError, match="Must specify scale"):
+        gamma(5.0)
+
+    with pytest.raises(ValueError, match="Cannot specify both rate and scale"):
+        gamma(5.0, rate=2.0, scale=0.5)
+
+
 @pytest.mark.parametrize(
     "df, size",
     [
@@ -470,18 +482,24 @@ def test_vonmises_samples(mu, kappa, size):
 
 
 @pytest.mark.parametrize(
-    "alpha, size",
+    "alpha, scale, size",
     [
-        (np.array(0.5, dtype=config.floatX), None),
-        (np.array(0.5, dtype=config.floatX), []),
+        (np.array(0.5, dtype=config.floatX), np.array(3.0, dtype=config.floatX), None),
+        (np.array(0.5, dtype=config.floatX), np.array(5.0, dtype=config.floatX), []),
         (
             np.full((1, 2), 0.5, dtype=config.floatX),
+            np.array([0.5, 1.0], dtype=config.floatX),
             None,
         ),
     ],
 )
-def test_pareto_samples(alpha, size):
-    compare_sample_values(pareto, alpha, size=size, test_fn=fixed_scipy_rvs("pareto"))
+def test_pareto_samples(alpha, scale, size):
+    pareto_test_fn = fixed_scipy_rvs("pareto")
+
+    def test_fn(shape, scale, **kwargs):
+        return pareto_test_fn(shape, scale=scale, **kwargs)
+
+    compare_sample_values(pareto, alpha, scale, size=size, test_fn=test_fn)
 
 
 def mvnormal_test_fn(mean=None, cov=None, size=None, random_state=None):
