
pm.Bound and TruncatedNormal generate wrong gradients #4417

Closed
@ricardoV94

Description


Edit: While trying to sketch a generic pm.Truncated class, I came across issues in the model.dlogp (and sampling) of models using pm.Bound and TruncatedNormal. These problems are described in the messages at the end of the thread.


I was playing around to see whether I could implement a generic Truncated class, similar to Bound but taking into account the extra normalization term. Everything is almost the same as Bound, except for the logp method and the fact that it allows observed values:

class _Truncated(Distribution):
    ...

    def _normalization(self):
        # Log of the normalization constant Z = F(upper) - F(lower),
        # built from the wrapped distribution's logcdf
        if self.lower is not None and self.upper is not None:
            lcdf_upper = self._wrapped.logcdf(self.upper)
            lcdf_lower = self._wrapped.logcdf(self.lower)
            return logdiffexp(lcdf_upper, lcdf_lower)

        if self.lower is not None:
            # log(1 - F(lower)), via log1mexp(x) = log(1 - exp(-x))
            return log1mexp(-self._wrapped.logcdf(self.lower))

        if self.upper is not None:
            return self._wrapped.logcdf(self.upper)

        return 0

    def logp(self, value):
        # Wrapped logp minus the log normalization term, restricted to the support
        logp = self._wrapped.logp(value) - self._normalization()
        bounds = []
        if self.lower is not None:
            bounds.append(value >= self.lower)
        if self.upper is not None:
            bounds.append(value <= self.upper)
        if len(bounds) > 0:
            return bound(logp, *bounds)
        else:
            return logp
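For reference, this implements the standard truncated density: for lower <= value <= upper,

log p_trunc(value) = log p(value) - log(F(upper) - F(lower)),

where F is the CDF of the wrapped distribution. The helpers are just numerically stable versions of the pieces: logdiffexp(a, b) = log(exp(a) - exp(b)) handles the two-sided case, and log1mexp(x) = log(1 - exp(-x)) turns -logcdf(lower) into log(1 - F(lower)) when only a lower bound is given.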

You can check all the changes in my fork: master...ricardoV94:truncated

Everything seems to work fine, but when I actually try to sample, something is definitely off. You can find my notebook here: https://gist.github.com/ricardoV94/269f07b016a5136f52a1e0238d0ec4e6

First, a manual implementation using pm.Potential, which adds the normalization correction -N * log(F(maxx) - F(minx)) directly to the model's log-probability:

import numpy as np
import pymc3 as pm
import arviz as az

# create data
np.random.seed(451)
x = np.random.exponential(3, size=5000)
minx = 1
maxx = 20

obs = x[np.where(~((x < minx) | (x > maxx)))]  # keep only values inside the truncation range

with pm.Model() as manual_model:
    λ = pm.Exponential("λ", lam=1/5)  # prior exponential with mean of 5
    x = pm.Exponential('x', lam=1/λ, observed=obs)  # observed exponential with mean λ

    # not part of the model, just used to get the logcdf
    exp_dist = pm.Exponential.dist(lam=1/λ)
    norm_term = pm.Potential("norm_term", -pm.math.logdiffexp(exp_dist.logcdf(maxx), exp_dist.logcdf(minx)) * x.size)

    trace_manual = pm.sample(2000, tune=1000, return_inferencedata=True)

az.summary(trace_manual)

[image: az.summary(trace_manual) output]

And now with the Truncated class:

with pm.Model() as trunc_model:
    λ = pm.Exponential("λ", lam=1/5)
    x = pm.Truncated(pm.Exponential, lower=minx, upper=maxx)('x', lam=1/λ, observed=obs)

    trace_trunc = pm.sample(2000, tune=1000, return_inferencedata=True)
az.summary(trace_trunc)

[image: az.summary(trace_trunc) output]

Lots of divergences and no convergence!

Yet everything seems to be working well when looking at check_test_point:

in[5]
trunc_point = trunc_model.check_test_point(
    test_point={'λ_log__': np.log(5)}
)
trunc_point
out[5]
λ_log__      -1.00
x         -7936.43
Name: Log-probability of test_point, dtype: float64

in[6]
manual_point = manual_model.check_test_point(
    test_point={'λ_log__': np.log(5)}
)
manual_point
out[6]
λ_log__      -1.00
x         -8749.23
Name: Log-probability of test_point, dtype: float64

The difference between the manual and the truncated models is exactly the correction term (I assume the Potential term is ignored by check_test_point):

in[7]
model_diff = manual_point - trunc_point
model_diff
out[7]
λ_log__      0.0
x         -812.8
Name: Log-probability of test_point, dtype: float64

in[8]
exp_dist = pm.Exponential.dist(1/5)
norm_term = (-pm.math.logdiffexp(exp_dist.logcdf(maxx), exp_dist.logcdf(minx)) * x.size).eval()
norm_term
out[8]
array(812.80311981)
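That assumption can be checked directly. A minimal sketch, using pymc3's compiled model.logp, which evaluates the total model log-probability (potentials included):

point = {'λ_log__': np.log(5)}
# model.logp includes Potential terms, so the two totals should match
# (≈ -7937.43 given the numbers above), even though check_test_point differs
manual_model.logp(point), trunc_model.logp(point)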

in[9]
trunc_dist = pm.Truncated(pm.Exponential, lower=minx, upper=maxx).dist(lam=1/5)

in[10]
trunc_dist.logp(obs).eval().sum()
out[10]
-7936.430976274112

in[11]
trunc_dist._normalization().eval() * len(obs)
out[11]
-812.8031198123141

I think I must be missing something obvious that happens during sampling. Why would the model evaluate correctly but fail to sample?
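Since the densities match, the next suspect is the gradient. A minimal sketch of that check, using the compiled model.dlogp mentioned in the edit at the top (the exact signature may vary between pymc3 versions):

point = {'λ_log__': np.log(5)}
# If the two logp graphs were equivalent up to a constant, these gradients
# would be identical; a mismatch here would explain the divergences
manual_model.dlogp(point), trunc_model.dlogp(point)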
