Skip to content

Update Gamma Distribution to support new pytensor GammaRV reparameterization #6934

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Oct 2, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion conda-envs/environment-dev.yml
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ dependencies:
- numpy>=1.15.0
- pandas>=0.24.0
- pip
- pytensor>=2.16.1,<2.17
- pytensor>=2.17.0,<2.18
- python-graphviz
- networkx
- scipy>=1.4.1
Expand Down
2 changes: 1 addition & 1 deletion conda-envs/environment-docs.yml
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ dependencies:
- numpy>=1.15.0
- pandas>=0.24.0
- pip
- pytensor>=2.16.1,<2.17
- pytensor>=2.17.0,<2.18
- python-graphviz
- scipy>=1.4.1
- typing-extensions>=3.7.4
Expand Down
2 changes: 1 addition & 1 deletion conda-envs/environment-test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ dependencies:
- numpy>=1.15.0
- pandas>=0.24.0
- pip
- pytensor>=2.16.1,<2.17
- pytensor>=2.17.0,<2.18
- python-graphviz
- networkx
- scipy>=1.4.1
Expand Down
2 changes: 1 addition & 1 deletion conda-envs/windows-environment-dev.yml
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ dependencies:
- numpy>=1.15.0
- pandas>=0.24.0
- pip
- pytensor>=2.16.1,<2.17
- pytensor>=2.17.0,<2.18
- python-graphviz
- networkx
- scipy>=1.4.1
Expand Down
2 changes: 1 addition & 1 deletion conda-envs/windows-environment-test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ dependencies:
- numpy>=1.15.0
- pandas>=0.24.0
- pip
- pytensor>=2.16.1,<2.17
- pytensor>=2.17.0,<2.18
- python-graphviz
- networkx
- scipy>=1.4.1
Expand Down
2 changes: 1 addition & 1 deletion pymc/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -234,7 +234,7 @@ def determine_coords(
for dim in dims:
dim_name = dim
# str is applied because dim entries may be None
coords[str(dim_name)] = value[dim].to_numpy()
coords[str(dim_name)] = cast(xr.DataArray, value[dim]).to_numpy()

if isinstance(value, np.ndarray) and dims is not None:
if len(dims) != value.ndim:
Expand Down
25 changes: 12 additions & 13 deletions pymc/distributions/continuous.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,10 +36,10 @@
from pytensor.tensor.math import tanh
from pytensor.tensor.random.basic import (
BetaRV,
_gamma,
cauchy,
chisquare,
exponential,
gamma,
gumbel,
halfcauchy,
halfnormal,
Expand Down Expand Up @@ -2201,16 +2201,17 @@ class Gamma(PositiveContinuous):
sigma : tensor_like of float, optional
Alternative scale parameter (sigma > 0).
"""
rv_op = gamma
# gamma is temporarily a deprecation wrapper in PyTensor
rv_op = _gamma

@classmethod
def dist(cls, alpha=None, beta=None, mu=None, sigma=None, **kwargs):
alpha, beta = cls.get_alpha_beta(alpha, beta, mu, sigma)
alpha = pt.as_tensor_variable(floatX(alpha))
beta = pt.as_tensor_variable(floatX(beta))

# The PyTensor `GammaRV` `Op` will invert the `beta` parameter itself
return super().dist([alpha, beta], **kwargs)
# PyTensor gamma op is parametrized in terms of scale (1/beta)
scale = pt.reciprocal(beta)
return super().dist([alpha, scale], **kwargs)

@classmethod
def get_alpha_beta(cls, alpha=None, beta=None, mu=None, sigma=None):
Expand All @@ -2232,15 +2233,14 @@ def get_alpha_beta(cls, alpha=None, beta=None, mu=None, sigma=None):

return alpha, beta

def moment(rv, size, alpha, inv_beta):
# The PyTensor `GammaRV` `Op` inverts the `beta` parameter itself
mean = alpha * inv_beta
def moment(rv, size, alpha, scale):
mean = alpha * scale
if not rv_size_is_none(size):
mean = pt.full(size, mean)
return mean

def logp(value, alpha, inv_beta):
beta = pt.reciprocal(inv_beta)
def logp(value, alpha, scale):
beta = pt.reciprocal(scale)
res = -pt.gammaln(alpha) + logpow(beta, alpha) - beta * value + logpow(value, alpha - 1)
res = pt.switch(pt.ge(value, 0.0), res, -np.inf)
return check_parameters(
Expand All @@ -2250,14 +2250,13 @@ def logp(value, alpha, inv_beta):
msg="alpha > 0, beta > 0",
)

def logcdf(value, alpha, inv_beta):
beta = pt.reciprocal(inv_beta)
def logcdf(value, alpha, scale):
beta = pt.reciprocal(scale)
res = pt.switch(
pt.lt(value, 0),
-np.inf,
pt.log(pt.gammainc(alpha, beta * value)),
)

return check_parameters(res, 0 < alpha, 0 < beta, msg="alpha > 0, beta > 0")


Expand Down
2 changes: 1 addition & 1 deletion requirements-dev.txt
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ numpydoc
pandas>=0.24.0
polyagamma
pre-commit>=2.8.0
pytensor>=2.16.1,<2.17
pytensor>=2.17.0,<2.18
pytest-cov>=2.5
pytest>=3.0
scipy>=1.4.1
Expand Down
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,6 @@ cloudpickle
fastprogress>=0.2.0
numpy>=1.15.0
pandas>=0.24.0
pytensor>=2.16.1,<2.17
pytensor>=2.17.0,<2.18
scipy>=1.4.1
typing-extensions>=3.7.4
2 changes: 1 addition & 1 deletion tests/distributions/test_continuous.py
Original file line number Diff line number Diff line change
Expand Up @@ -2199,7 +2199,7 @@ class TestHalfCauchy(BaseTestDistributionRandom):
class TestGamma(BaseTestDistributionRandom):
pymc_dist = pm.Gamma
pymc_dist_params = {"alpha": 2.0, "beta": 5.0}
expected_rv_op_params = {"alpha": 2.0, "beta": 1 / 5.0}
expected_rv_op_params = {"shape": 2.0, "scale": 1 / 5.0}
reference_dist_params = {"shape": 2.0, "scale": 1 / 5.0}
reference_dist = seeded_numpy_distribution_builder("gamma")
checks_to_run = [
Expand Down
32 changes: 31 additions & 1 deletion tests/distributions/test_truncated.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
from pytensor.tensor.random.basic import GeometricRV, NormalRV

from pymc import Censored, Model, draw, find_MAP
from pymc.distributions.continuous import Exponential, TruncatedNormalRV
from pymc.distributions.continuous import Exponential, Gamma, TruncatedNormalRV
from pymc.distributions.shape_utils import change_dist_size
from pymc.distributions.transforms import _default_transform
from pymc.distributions.truncated import Truncated, TruncatedRV, _truncated
Expand Down Expand Up @@ -392,3 +392,33 @@ def test_truncated_inference():
map = find_MAP(progressbar=False)

assert np.isclose(map["lam"], lam_true, atol=0.1)


def test_truncated_gamma():
# Regression test for https://github.com/pymc-devs/pymc/issues/6931
alpha = 3.0
beta = 3.0
upper = 2.5
x = np.linspace(0.0, upper + 0.5, 100)

gamma_scipy = scipy.stats.gamma(a=alpha, scale=1.0 / beta)
logp_scipy = gamma_scipy.logpdf(x) - gamma_scipy.logcdf(upper)
logp_scipy[x > upper] = -np.inf

gamma_trunc_pymc = Truncated.dist(
Gamma.dist(alpha=alpha, beta=beta),
upper=upper,
)
logp_pymc = logp(gamma_trunc_pymc, x).eval()
np.testing.assert_allclose(
logp_pymc,
logp_scipy,
)

# Changing the size used to invert the beta Gamma parameter again
resized_gamma_trunc_pymc = change_dist_size(gamma_trunc_pymc, new_size=x.shape)
logp_resized_pymc = logp(resized_gamma_trunc_pymc, x).eval()
np.testing.assert_allclose(
logp_resized_pymc,
logp_scipy,
)
Loading