Skip to content

Commit 1ed4475

Browse files
committed
Allow logcdf inference in CustomDist
1 parent e2eb26d commit 1ed4475

File tree

2 files changed

+21
-3
lines changed

2 files changed

+21
-3
lines changed

pymc/distributions/distribution.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -684,9 +684,11 @@ def rv_op(
684684
def custom_dist_logp(op, values, size, *params, **kwargs):
685685
return logp(values[0], *params[: len(dist_params)])
686686

687-
@_logcdf.register(rv_type)
688-
def custom_dist_logcdf(op, value, size, *params, **kwargs):
689-
return logcdf(value, *params[: len(dist_params)])
687+
if logcdf is not None:
688+
689+
@_logcdf.register(rv_type)
690+
def custom_dist_logcdf(op, value, size, *params, **kwargs):
691+
return logcdf(value, *params[: len(dist_params)])
690692

691693
@_moment.register(rv_type)
692694
def custom_dist_get_moment(op, rv, size, *params):

tests/distributions/test_distribution.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -404,6 +404,22 @@ def custom_dist(mu, sigma, size):
404404
ip = m.initial_point()
405405
np.testing.assert_allclose(m.compile_logp()(ip), ref_m.compile_logp()(ip))
406406

407+
def test_logcdf_inference(self):
408+
def custom_dist(mu, sigma, size):
409+
return pt.exp(pm.Normal.dist(mu, sigma, size=size))
410+
411+
mu = 1
412+
sigma = 1.25
413+
test_value = 0.9
414+
415+
custom_lognormal = CustomDist.dist(mu, sigma, dist=custom_dist)
416+
ref_lognormal = LogNormal.dist(mu, sigma)
417+
418+
np.testing.assert_allclose(
419+
pm.logcdf(custom_lognormal, test_value).eval(),
420+
pm.logcdf(ref_lognormal, test_value).eval(),
421+
)
422+
407423
def test_random_multiple_rngs(self):
408424
def custom_dist(p, sigma, size):
409425
idx = pm.Bernoulli.dist(p=p)

0 commit comments

Comments
 (0)