We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent a708732 commit bede1deCopy full SHA for bede1de
pytensor/tensor/random/basic.py
@@ -443,7 +443,7 @@ class GammaRV(ScipyRandomVariable):
443
dtype = "floatX"
444
_print_name = ("Gamma", "\\operatorname{Gamma}")
445
446
- def __call__(self, shape, rate, size=None, **kwargs):
+ def __call__(self, alpha, beta, size=None, **kwargs):
447
r"""Draw samples from a gamma distribution.
448
449
Signature
@@ -464,11 +464,11 @@ def __call__(self, shape, rate, size=None, **kwargs):
464
is returned.
465
466
"""
467
- return super().__call__(shape, 1.0 / rate, size=size, **kwargs)
+ return super().__call__(alpha, beta, size=size, **kwargs)
468
469
@classmethod
470
- def rng_fn_scipy(cls, rng, shape, scale, size):
471
- return stats.gamma.rvs(shape, scale=scale, size=size, random_state=rng)
+ def rng_fn_scipy(cls, rng, alpha, beta, size):
+ return stats.gamma.rvs(alpha, scale=1.0 / beta, size=size, random_state=rng)
472
473
474
gamma = GammaRV()
0 commit comments