Skip to content

Commit 158cde1

Browse files
farhanreynaldoricardoV94
authored andcommitted
change default parameter to b
1 parent ad5604c commit 158cde1

File tree

2 files changed

+13
-15
lines changed

2 files changed

+13
-15
lines changed

pymc3/distributions/continuous.py

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -3465,8 +3465,8 @@ class RiceRV(RandomVariable):
34653465
_print_name = ("Rice", "\\operatorname{Rice}")
34663466

34673467
@classmethod
3468-
def rng_fn(cls, rng, nu, sigma, size=None):
3469-
return stats.rice.rvs(b=nu / sigma, scale=sigma, size=size, random_state=rng)
3468+
def rng_fn(cls, rng, b, sigma, size=None):
3469+
return stats.rice.rvs(b=b, scale=sigma, size=size, random_state=rng)
34703470

34713471

34723472
rice = RiceRV()
@@ -3539,10 +3539,10 @@ def dist(cls, nu=None, sigma=None, b=None, sd=None, *args, **kwargs):
35393539
sigma = sd
35403540

35413541
nu, b, sigma = cls.get_nu_b(nu, b, sigma)
3542-
nu = at.as_tensor_variable(floatX(nu))
3542+
b = at.as_tensor_variable(floatX(b))
35433543
sigma = at.as_tensor_variable(floatX(sigma))
35443544

3545-
return super().dist([nu, sigma], *args, **kwargs)
3545+
return super().dist([b, sigma], *args, **kwargs)
35463546

35473547
@classmethod
35483548
def get_nu_b(cls, nu, b, sigma):
@@ -3556,7 +3556,7 @@ def get_nu_b(cls, nu, b, sigma):
35563556
return nu, b, sigma
35573557
raise ValueError("Rice distribution must specify either nu" " or b.")
35583558

3559-
def logp(value, nu, sigma):
3559+
def logp(value, b, sigma):
35603560
"""
35613561
Calculate log-probability of Rice distribution at specified value.
35623562
@@ -3570,12 +3570,10 @@ def logp(value, nu, sigma):
35703570
-------
35713571
TensorVariable
35723572
"""
3573-
b = nu / sigma
35743573
x = value / sigma
35753574
return bound(
35763575
at.log(x * at.exp((-(x - b) * (x - b)) / 2) * i0e(x * b) / sigma),
35773576
sigma >= 0,
3578-
nu >= 0,
35793577
value > 0,
35803578
)
35813579

pymc3/tests/test_distributions_random.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -530,10 +530,10 @@ class TestSkewNormalTau(BaseTestDistribution):
530530

531531
class TestRice(BaseTestDistribution):
532532
pymc_dist = pm.Rice
533-
nu = sigma = 2
534-
pymc_dist_params = {"nu": nu, "sigma": sigma}
535-
expected_rv_op_params = {"nu": nu, "sigma": sigma}
536-
reference_dist_params = {"b": nu / sigma, "scale": sigma}
533+
b, sigma = 1, 2
534+
pymc_dist_params = {"b": b, "sigma": sigma}
535+
expected_rv_op_params = {"b": b, "sigma": sigma}
536+
reference_dist_params = {"b": b, "scale": sigma}
537537
reference_dist = seeded_scipy_distribution_builder("rice")
538538
tests_to_run = [
539539
"check_pymc_params_match_rv_op",
@@ -542,11 +542,11 @@ class TestRice(BaseTestDistribution):
542542
]
543543

544544

545-
class TestRiceB(BaseTestDistribution):
545+
class TestRiceNu(BaseTestDistribution):
546546
pymc_dist = pm.Rice
547-
b, sigma = 1, 2
548-
pymc_dist_params = {"b": b, "sigma": sigma}
549-
expected_rv_op_params = {"nu": b * sigma, "sigma": sigma}
547+
nu = sigma = 2
548+
pymc_dist_params = {"nu": nu, "sigma": sigma}
549+
expected_rv_op_params = {"b": nu / sigma, "sigma": sigma}
550550
tests_to_run = ["check_pymc_params_match_rv_op"]
551551

552552

0 commit comments

Comments
 (0)