Skip to content

Commit d944333

Browse files
committed
Add Gamma moment
1 parent cbff818 commit d944333

File tree

2 files changed

+28
-0
lines changed

2 files changed

+28
-0
lines changed

pymc/distributions/continuous.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2226,6 +2226,13 @@ def get_alpha_beta(cls, alpha=None, beta=None, mu=None, sigma=None):
22262226

22272227
return alpha, beta
22282228

2229+
def get_moment(rv, size, alpha, inv_beta):
2230+
# The Aesara `GammaRV` `Op` inverts the `beta` parameter itself
2231+
mean = alpha * inv_beta
2232+
if not rv_size_is_none(size):
2233+
mean = at.full(size, mean)
2234+
return mean
2235+
22292236
def logcdf(value, alpha, inv_beta):
22302237
"""
22312238
Compute the log of the cumulative distribution function for Gamma distribution

pymc/tests/test_distributions_moments.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
Beta,
77
Cauchy,
88
Exponential,
9+
Gamma,
910
HalfCauchy,
1011
HalfNormal,
1112
Kumaraswamy,
@@ -282,3 +283,23 @@ def test_halfcauchy_moment(beta, size, expected):
282283
with Model() as model:
283284
HalfCauchy("x", beta=beta, size=size)
284285
assert_moment_is_expected(model, expected)
286+
287+
288+
@pytest.mark.parametrize(
289+
"alpha, beta, size, expected",
290+
[
291+
(1, 1, None, 1),
292+
(1, 1, 5, np.full(5, 1)),
293+
(np.arange(1, 6), 1, None, np.arange(1, 6)),
294+
(
295+
np.arange(1, 6),
296+
2 * np.arange(1, 6),
297+
(2, 5),
298+
np.full((2, 5), 0.5),
299+
),
300+
],
301+
)
302+
def test_gamma_moment(alpha, beta, size, expected):
303+
with Model() as model:
304+
Gamma("x", alpha=alpha, beta=beta, size=size)
305+
assert_moment_is_expected(model, expected)

0 commit comments

Comments
 (0)