Skip to content

Commit 6fdf27c

Browse files
committed
add gamma helper function to deprecate rate parameterization
1 parent bede1de commit 6fdf27c

File tree

4 files changed

+30
-13
lines changed

4 files changed

+30
-13
lines changed

pytensor/tensor/random/basic.py

Lines changed: 24 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
RandomGeneratorSharedVariable,
1717
RandomStateSharedVariable,
1818
)
19+
import warnings
1920

2021

2122
try:
@@ -443,7 +444,7 @@ class GammaRV(ScipyRandomVariable):
443444
dtype = "floatX"
444445
_print_name = ("Gamma", "\\operatorname{Gamma}")
445446

446-
def __call__(self, alpha, beta, size=None, **kwargs):
447+
def __call__(self, shape, scale, size=None, **kwargs):
447448
r"""Draw samples from a gamma distribution.
448449
449450
Signature
@@ -455,23 +456,39 @@ def __call__(self, alpha, beta, size=None, **kwargs):
455456
----------
456457
shape
457458
The shape :math:`\alpha` of the gamma distribution. Must be positive.
458-
rate
459-
The rate :math:`\beta` of the gamma distribution. Must be positive.
459+
scale
460+
The scale :math:`1/\beta` of the gamma distribution. Must be positive.
460461
size
461462
Sample shape. If the given size is, e.g. `(m, n, k)` then `m * n * k`
462463
independent, identically distributed random variables are
463464
returned. Default is `None` in which case a single random variable
464465
is returned.
465466
466467
"""
467-
return super().__call__(alpha, beta, size=size, **kwargs)
468+
return super().__call__(shape, scale, size=size, **kwargs)
468469

469470
@classmethod
470-
def rng_fn_scipy(cls, rng, alpha, beta, size):
471-
return stats.gamma.rvs(alpha, scale=1.0 / beta, size=size, random_state=rng)
471+
def rng_fn_scipy(cls, rng, shape, scale, size):
472+
return stats.gamma.rvs(shape, scale=scale, size=size, random_state=rng)
473+
474+
475+
_gamma = GammaRV()
472476

473477

474-
gamma = GammaRV()
478+
def gamma(shape, scale=None, rate=None, **kwargs):
479+
# TODO: Remove helper when rate is deprecated
480+
if rate is not None and scale is not None:
481+
raise ValueError("Can't specify both rate and scale")
482+
elif rate is None and scale is None:
483+
raise ValueError("Must specify scale")
484+
elif rate is not None:
485+
warnings.warn(
486+
"Gamma rate argument is deprecated and will stop working, use scale instead",
487+
FutureWarning,
488+
)
489+
scale = 1.0 / rate
490+
491+
return _gamma(shape, scale, **kwargs)
475492

476493

477494
class ChiSquareRV(RandomVariable):

tests/link/jax/test_random.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -147,8 +147,8 @@ def test_random_updates_input_storage_order():
147147
np.array([1.0, 2.0], dtype=np.float64),
148148
),
149149
set_test_value(
150-
at.dscalar(),
151-
np.array(1.0, dtype=np.float64),
150+
at.dvector(),
151+
np.array([0.5, 3.0], dtype=np.float64),
152152
),
153153
],
154154
(2,),

tests/link/numba/test_random.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -323,8 +323,8 @@ def test_aligned_RandomVariable(rv_op, dist_args, size):
323323
np.array([1.0, 2.0], dtype=np.float64),
324324
),
325325
set_test_value(
326-
at.dscalar(),
327-
np.array(1.0, dtype=np.float64),
326+
at.dvector(),
327+
np.array([0.5, 3.0], dtype=np.float64),
328328
),
329329
],
330330
(2,),

tests/tensor/random/test_basic.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -353,8 +353,8 @@ def test_lognormal_samples(mean, sigma, size):
353353
def test_gamma_samples(a, b, size):
354354
gamma_test_fn = fixed_scipy_rvs("gamma")
355355

356-
def test_fn(shape, rate, **kwargs):
357-
return gamma_test_fn(shape, scale=1.0 / rate, **kwargs)
356+
def test_fn(shape, scale, **kwargs):
357+
return gamma_test_fn(shape, scale=scale, **kwargs)
358358

359359
compare_sample_values(
360360
gamma,

0 commit comments

Comments
 (0)