Skip to content

Commit 32cd3e5

Browse files
committed
revert GammaRV param names
1 parent bede1de commit 32cd3e5

File tree

1 file changed

+4
-4
lines changed

1 file changed

+4
-4
lines changed

pytensor/tensor/random/basic.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -443,7 +443,7 @@ class GammaRV(ScipyRandomVariable):
443443
dtype = "floatX"
444444
_print_name = ("Gamma", "\\operatorname{Gamma}")
445445

446-
def __call__(self, alpha, beta, size=None, **kwargs):
446+
def __call__(self, shape, rate, size=None, **kwargs):
447447
r"""Draw samples from a gamma distribution.
448448
449449
Signature
@@ -464,11 +464,11 @@ def __call__(self, alpha, beta, size=None, **kwargs):
464464
is returned.
465465
466466
"""
467-
return super().__call__(alpha, beta, size=size, **kwargs)
467+
return super().__call__(shape, rate, size=size, **kwargs)
468468

469469
@classmethod
470-
def rng_fn_scipy(cls, rng, alpha, beta, size):
471-
return stats.gamma.rvs(alpha, scale=1.0 / beta, size=size, random_state=rng)
470+
def rng_fn_scipy(cls, rng, shape, rate, size):
471+
return stats.gamma.rvs(shape, scale=1.0 / rate, size=size, random_state=rng)
472472

473473

474474
gamma = GammaRV()

0 commit comments

Comments
 (0)