Description
Edit: I came across issues in the `model.dlogp` (and sampling) of models using `pm.Bound` and `TruncatedNormal` while trying to sketch a generic `pm.Truncated` class. These problems are described in the messages at the end of the thread.
I was playing around to see if I could implement a generic `Truncated` class, similar to `Bound`, but taking the extra normalization term into account. Everything is almost the same as in `Bound`, except for the `logp` method and the fact that observed values are allowed:
```python
class _Truncated(Distribution):
    ...

    def _normalization(self):
        if self.lower is not None and self.upper is not None:
            lcdf_upper = self._wrapped.logcdf(self.upper)
            lcdf_lower = self._wrapped.logcdf(self.lower)
            return logdiffexp(lcdf_upper, lcdf_lower)
        if self.lower is not None:
            return log1mexp(-self._wrapped.logcdf(self.lower))
        if self.upper is not None:
            return self._wrapped.logcdf(self.upper)
        return 0

    def logp(self, value):
        logp = self._wrapped.logp(value) - self._normalization()
        bounds = []
        if self.lower is not None:
            bounds.append(value >= self.lower)
        if self.upper is not None:
            bounds.append(value <= self.upper)
        if len(bounds) > 0:
            return bound(logp, *bounds)
        else:
            return logp
```
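(Spelling out what the code above computes, for the doubly truncated case: `_normalization` is the log of the truncation probability, so `logp` implements

```math
\log p_{[a,b]}(x) = \log p(x) - \log\bigl(F(b) - F(a)\bigr), \qquad a \le x \le b,
```

where `p` and `F` are the density and CDF of the wrapped distribution.)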
You can check all the changes in my fork: master...ricardoV94:truncated
Everything seems to work fine, but when I actually try to sample, something is definitely off. You can find my notebook here: https://gist.github.com/ricardoV94/269f07b016a5136f52a1e0238d0ec4e6

First, a manual implementation using a `Potential`:
```python
import numpy as np
import pymc3 as pm
import arviz as az

# create data
np.random.seed(451)
x = np.random.exponential(3, size=5000)
minx = 1
maxx = 20
obs = x[np.where(~((x < minx) | (x > maxx)))]  # remove values outside range

with pm.Model() as manual_model:
    λ = pm.Exponential("λ", lam=1/5)  # prior exponential with mean of 5
    x = pm.Exponential('x', lam=1/λ, observed=obs)  # obs exponential with mean of λ
    exp_dist = pm.Exponential.dist(lam=1/λ)  # not part of the model, just used to get the logcdf
    norm_term = pm.Potential("norm_term", -pm.math.logdiffexp(exp_dist.logcdf(maxx), exp_dist.logcdf(minx)) * x.size)
    trace_manual = pm.sample(2000, tune=1000, return_inferencedata=True)

az.summary(trace_manual)
```
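As a sanity check on the magnitude of that correction (this snippet is mine, not from the notebook; plain scipy, using the same exponential with mean 5 that appears in the test point further down):

```python
# Per-observation log-normalization for an Exponential(mean=5) truncated to [1, 20]
import numpy as np
from scipy import stats

dist = stats.expon(scale=5)  # scipy parametrizes by scale = 1/rate = mean
log_norm = np.log(dist.cdf(20) - dist.cdf(1))
print(log_norm)  # ≈ -0.2226; times the number of kept observations gives the ≈ 812.8 seen below
```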
And now with `Truncated`:

```python
with pm.Model() as trunc_model:
    λ = pm.Exponential("λ", lam=1/5)
    x = pm.Truncated(pm.Exponential, lower=minx, upper=maxx)('x', lam=1/λ, observed=obs)
    trace_trunc = pm.sample(2000, tune=1000, return_inferencedata=True)

az.summary(trace_trunc)
```
A lot of divergences and non-convergence!
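(For completeness, this one-liner is mine, not from the notebook; it counts the divergent transitions in the returned `InferenceData`:)

```python
# Total number of divergent transitions across all chains
int(trace_trunc.sample_stats.diverging.sum())
```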
Everything seems to be working well when looking at the `check_test_point`:

In [5]:

```python
trunc_point = trunc_model.check_test_point(
    test_point={'λ_log__': np.log(5)}
)
trunc_point
```

Out [5]:

```
λ_log__      -1.00
x         -7936.43
Name: Log-probability of test_point, dtype: float64
```

In [6]:

```python
manual_point = manual_model.check_test_point(
    test_point={'λ_log__': np.log(5)}
)
manual_point
```

Out [6]:

```
λ_log__      -1.00
x         -8749.23
Name: Log-probability of test_point, dtype: float64
```

The difference between the manual and the truncated models is exactly the correction term (I assume the potential term is ignored in `check_test_point`):

In [7]:

```python
model_diff = manual_point - trunc_point
model_diff
```

Out [7]:

```
λ_log__     0.0
x        -812.8
Name: Log-probability of test_point, dtype: float64
```

In [8]:

```python
exp_dist = pm.Exponential.dist(1/5)
norm_term = (-pm.math.logdiffexp(exp_dist.logcdf(maxx), exp_dist.logcdf(minx)) * x.size).eval()
norm_term
```

Out [8]:

```
array(812.80311981)
```

In [9]:

```python
trunc_dist = pm.Truncated(pm.Exponential, lower=minx, upper=maxx).dist(lam=1/5)
```

In [10]:

```python
trunc_dist.logp(obs).eval().sum()
```

Out [10]:

```
-7936.430976274112
```

In [11]:

```python
trunc_dist._normalization().eval() * len(obs)
```

Out [11]:

```
-812.8031198123141
```
I think I must be missing something obvious that happens during sampling. Why would the model evaluate correctly but fail to sample?
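One way I'd probe this (my own sketch, not in the notebook; it uses the compiled `Model.dlogp` function that the edit above points to) is to compare the gradients of the two models at the same test point:

```python
# Compare the gradient of the joint log-probability of both models at the
# test point used for check_test_point; if only the truncated model diverges
# while the logps agree, a discrepancy should show up here.
import numpy as np

test_point = {'λ_log__': np.log(5)}
print(manual_model.dlogp()(test_point))
print(trunc_model.dlogp()(test_point))
```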