Skip to content

Commit b55be35

Browse files
committed
Ignore nan warnings in _interpolated_argcdf
The NaNs are irrelevant because the other result is returned.
1 parent 552ecaf commit b55be35

File tree

1 file changed

+15
-4
lines changed

1 file changed

+15
-4
lines changed

pymc/distributions/continuous.py

Lines changed: 15 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3653,12 +3653,23 @@ def _interpolated_argcdf(p, pdf, cdf, x):
36533653
index = np.searchsorted(cdf, p) - 1
36543654
slope = (pdf[index + 1] - pdf[index]) / (x[index + 1] - x[index])
36553655

3656-
return x[index] + np.where(
3657-
np.abs(slope) <= 1e-8,
3658-
np.where(np.abs(pdf[index]) <= 1e-8, np.zeros(index.shape), (p - cdf[index]) / pdf[index]),
3659-
(-pdf[index] + np.sqrt(pdf[index] ** 2 + 2 * slope * (p - cdf[index]))) / slope,
3656+
# First term (constant) of the Taylor expansion around slope = 0
3657+
small_slopes = np.where(
3658+
np.abs(pdf[index]) <= 1e-8, np.zeros(index.shape), (p - cdf[index]) / pdf[index]
36603659
)
36613660

3661+
# This warning happens when we divide by slope = 0: we can ignore it
3662+
# because the other result will be returned
3663+
with warnings.catch_warnings():
3664+
warnings.filterwarnings(
3665+
"ignore", ".*invalid value encountered in true_divide.*", RuntimeWarning
3666+
)
3667+
large_slopes = (
3668+
-pdf[index] + np.sqrt(pdf[index] ** 2 + 2 * slope * (p - cdf[index]))
3669+
) / slope
3670+
3671+
return x[index] + np.where(np.abs(slope) <= 1e-8, small_slopes, large_slopes)
3672+
36623673

36633674
class InterpolatedRV(RandomVariable):
36643675
name = "interpolated"

0 commit comments

Comments
 (0)