Skip to content

Commit 5730c03

Browse files
tvwenger authored and ricardoV94 committed
Reparameterize GammaRV so beta is not inverted at each call
Also fix wrong JAX implementation of Gamma and Pareto RVs
1 parent 39aa123 commit 5730c03

File tree

5 files changed

+55
-31
lines changed

5 files changed

+55
-31
lines changed

pytensor/link/jax/dispatch/random.py

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -214,23 +214,22 @@ def sample_fn(rng, size, dtype, *parameters):
214214
return sample_fn
215215

216216

217-
@jax_sample_fn.register(aer.ParetoRV)
218217
@jax_sample_fn.register(aer.GammaRV)
219-
def jax_sample_fn_shape_rate(op):
220-
"""JAX implementation of random variables in the shape-rate family.
218+
@jax_sample_fn.register(aer.ParetoRV)
219+
def jax_sample_fn_shape_scale(op):
220+
"""JAX implementation of random variables in the shape-scale family.
221221
222222
JAX only implements the standard version of random variables in the
223-
shape-rate family. We thus need to rescale the results manually.
223+
shape-scale family. We thus need to rescale the results manually.
224224
225225
"""
226226
name = op.name
227227
jax_op = getattr(jax.random, name)
228228

229-
def sample_fn(rng, size, dtype, *parameters):
229+
def sample_fn(rng, size, dtype, shape, scale):
230230
rng_key = rng["jax_state"]
231231
rng_key, sampling_key = jax.random.split(rng_key, 2)
232-
(shape, rate) = parameters
233-
sample = jax_op(sampling_key, shape, size, dtype) / rate
232+
sample = jax_op(sampling_key, shape, size, dtype) * scale
234233
rng["jax_state"] = rng_key
235234
return (rng, sample)
236235

pytensor/tensor/random/basic.py

Lines changed: 24 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import abc
2+
import warnings
23
from typing import List, Optional, Union
34

45
import numpy as np
@@ -419,7 +420,7 @@ def __call__(self, mean=0.0, sigma=1.0, size=None, **kwargs):
419420
lognormal = LogNormalRV()
420421

421422

422-
class GammaRV(ScipyRandomVariable):
423+
class GammaRV(RandomVariable):
423424
r"""A gamma continuous random variable.
424425
425426
The probability density function for `gamma` in terms of the shape parameter
@@ -443,7 +444,7 @@ class GammaRV(ScipyRandomVariable):
443444
dtype = "floatX"
444445
_print_name = ("Gamma", "\\operatorname{Gamma}")
445446

446-
def __call__(self, shape, rate, size=None, **kwargs):
447+
def __call__(self, shape, scale, size=None, **kwargs):
447448
r"""Draw samples from a gamma distribution.
448449
449450
Signature
@@ -455,23 +456,39 @@ def __call__(self, shape, rate, size=None, **kwargs):
455456
----------
456457
shape
457458
The shape :math:`\alpha` of the gamma distribution. Must be positive.
458-
rate
459-
The rate :math:`\beta` of the gamma distribution. Must be positive.
459+
scale
460+
The scale :math:`1/\beta` of the gamma distribution. Must be positive.
460461
size
461462
Sample shape. If the given size is, e.g. `(m, n, k)` then `m * n * k`
462463
independent, identically distributed random variables are
463464
returned. Default is `None` in which case a single random variable
464465
is returned.
465466
466467
"""
467-
return super().__call__(shape, 1.0 / rate, size=size, **kwargs)
468+
return super().__call__(shape, scale, size=size, **kwargs)
468469

469470
@classmethod
470471
def rng_fn_scipy(cls, rng, shape, scale, size):
471-
return stats.gamma.rvs(shape, scale=scale, size=size, random_state=rng)
472+
return rng.gamma(shape, scale=scale, size=size)
473+
472474

475+
_gamma = GammaRV()
476+
477+
478+
def gamma(shape, scale=None, rate=None, **kwargs):
479+
# TODO: Remove helper when rate is deprecated
480+
if rate is not None and scale is not None:
481+
raise ValueError("Cannot specify both rate and scale")
482+
elif rate is None and scale is None:
483+
raise ValueError("Must specify scale")
484+
elif rate is not None:
485+
warnings.warn(
486+
"Gamma rate argument is deprecated and will stop working, use scale instead",
487+
FutureWarning,
488+
)
489+
scale = 1.0 / rate
473490

474-
gamma = GammaRV()
491+
return _gamma(shape, scale, **kwargs)
475492

476493

477494
class ChiSquareRV(RandomVariable):

tests/link/jax/test_random.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -147,8 +147,8 @@ def test_random_updates_input_storage_order():
147147
np.array([1.0, 2.0], dtype=np.float64),
148148
),
149149
set_test_value(
150-
at.dscalar(),
151-
np.array(1.0, dtype=np.float64),
150+
at.dvector(),
151+
np.array([0.5, 3.0], dtype=np.float64),
152152
),
153153
],
154154
(2,),
@@ -235,11 +235,15 @@ def test_random_updates_input_storage_order():
235235
set_test_value(
236236
at.dvector(),
237237
np.array([1.0, 2.0], dtype=np.float64),
238-
)
238+
),
239+
set_test_value(
240+
at.dvector(),
241+
np.array([2.0, 10.0], dtype=np.float64),
242+
),
239243
],
240244
(2,),
241245
"pareto",
242-
lambda *args: args,
246+
lambda shape, scale: (shape, 0.0, scale),
243247
),
244248
(
245249
aer.poisson,

tests/link/numba/test_random.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -92,6 +92,10 @@
9292
at.dvector(),
9393
np.array([1.0, 2.0], dtype=np.float64),
9494
),
95+
set_test_value(
96+
at.dvector(),
97+
np.array([2.0, 10.0], dtype=np.float64),
98+
),
9599
],
96100
at.as_tensor([3, 2]),
97101
marks=pytest.mark.xfail(reason="Not implemented"),
@@ -323,8 +327,8 @@ def test_aligned_RandomVariable(rv_op, dist_args, size):
323327
np.array([1.0, 2.0], dtype=np.float64),
324328
),
325329
set_test_value(
326-
at.dscalar(),
327-
np.array(1.0, dtype=np.float64),
330+
at.dvector(),
331+
np.array([0.5, 3.0], dtype=np.float64),
328332
),
329333
],
330334
(2,),

tests/tensor/random/test_basic.py

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -351,17 +351,11 @@ def test_lognormal_samples(mean, sigma, size):
351351
],
352352
)
353353
def test_gamma_samples(a, b, size):
354-
gamma_test_fn = fixed_scipy_rvs("gamma")
355-
356-
def test_fn(shape, rate, **kwargs):
357-
return gamma_test_fn(shape, scale=1.0 / rate, **kwargs)
358-
359354
compare_sample_values(
360355
gamma,
361356
a,
362357
b,
363358
size=size,
364-
test_fn=test_fn,
365359
)
366360

367361

@@ -470,18 +464,24 @@ def test_vonmises_samples(mu, kappa, size):
470464

471465

472466
@pytest.mark.parametrize(
473-
"alpha, size",
467+
"alpha, scale, size",
474468
[
475-
(np.array(0.5, dtype=config.floatX), None),
476-
(np.array(0.5, dtype=config.floatX), []),
469+
(np.array(0.5, dtype=config.floatX), np.array(3.0, dtype=config.floatX), None),
470+
(np.array(0.5, dtype=config.floatX), np.array(5.0, dtype=config.floatX), []),
477471
(
478472
np.full((1, 2), 0.5, dtype=config.floatX),
473+
np.array([0.5, 1.0], dtype=config.floatX),
479474
None,
480475
),
481476
],
482477
)
483-
def test_pareto_samples(alpha, size):
484-
compare_sample_values(pareto, alpha, size=size, test_fn=fixed_scipy_rvs("pareto"))
478+
def test_pareto_samples(alpha, scale, size):
479+
pareto_test_fn = fixed_scipy_rvs("pareto")
480+
481+
def test_fn(shape, scale, **kwargs):
482+
return pareto_test_fn(shape, scale=scale, **kwargs)
483+
484+
compare_sample_values(pareto, alpha, scale, size=size, test_fn=test_fn)
485485

486486

487487
def mvnormal_test_fn(mean=None, cov=None, size=None, random_state=None):

0 commit comments

Comments (0)