Skip to content

Commit 1656c38

Browse files
authored
update to use pm.Truncated (#423)
1 parent 9fad19c commit 1656c38

File tree

2 files changed

+156
-118
lines changed

2 files changed

+156
-118
lines changed

examples/generalized_linear_models/GLM-truncated-censored-regression.ipynb

Lines changed: 147 additions & 103 deletions
Large diffs are not rendered by default.

myst_nbs/generalized_linear_models/GLM-truncated-censored-regression.myst.md

Lines changed: 9 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ kernelspec:
1414
(GLM-truncated-censored-regression)=
1515
# Bayesian regression with truncated or censored data
1616

17-
:::{post} January, 2022
17+
:::{post} September, 2022
1818
:tags: censored, censoring, generalized linear model, regression, truncated, truncation
1919
:category: beginner
2020
:author: Benjamin T. Vincent
@@ -126,7 +126,7 @@ def linear_regression(x, y):
126126
with pm.Model() as model:
127127
slope = pm.Normal("slope", mu=0, sigma=1)
128128
intercept = pm.Normal("intercept", mu=0, sigma=1)
129-
σ = pm.HalfNormal("σ", sd=1)
129+
σ = pm.HalfNormal("σ", sigma=1)
130130
pm.Normal("obs", mu=slope * x + intercept, sigma=σ, observed=y)
131131
132132
return model
@@ -163,7 +163,7 @@ az.plot_posterior(trunc_linear_fit, var_names=["slope"], ref_val=slope, ax=ax[0]
163163
ax[0].set(title="Linear regression\n(truncated data)", xlabel="slope")
164164
165165
az.plot_posterior(cens_linear_fit, var_names=["slope"], ref_val=slope, ax=ax[1])
166-
ax[1].set(title="Linear regression\n(censored data)", xlabel="slope")
166+
ax[1].set(title="Linear regression\n(censored data)", xlabel="slope");
167167
```
168168

169169
To appreciate the extent of the problem (for this dataset) we can visualise the posterior predictive fits alongside the data.
@@ -217,15 +217,8 @@ def truncated_regression(x, y, bounds):
217217
slope = pm.Normal("slope", mu=0, sigma=1)
218218
intercept = pm.Normal("intercept", mu=0, sigma=1)
219219
σ = pm.HalfNormal("σ", sigma=1)
220-
221-
pm.TruncatedNormal(
222-
"obs",
223-
mu=slope * x + intercept,
224-
sigma=σ,
225-
observed=y,
226-
lower=bounds[0],
227-
upper=bounds[1],
228-
)
220+
normal_dist = pm.Normal.dist(mu=slope * x + intercept, sigma=σ)
221+
pm.Truncated("obs", normal_dist, lower=bounds[0], upper=bounds[1], observed=y)
229222
return model
230223
```
231224

@@ -260,7 +253,7 @@ def censored_regression(x, y, bounds):
260253
with pm.Model() as model:
261254
slope = pm.Normal("slope", mu=0, sigma=1)
262255
intercept = pm.Normal("intercept", mu=0, sigma=1)
263-
σ = pm.HalfNormal("σ", sd=1)
256+
σ = pm.HalfNormal("σ", sigma=1)
264257
y_latent = pm.Normal.dist(mu=slope * x + intercept, sigma=σ)
265258
obs = pm.Censored("obs", y_latent, lower=bounds[0], upper=bounds[1], observed=y)
266259
@@ -284,8 +277,8 @@ with pm.Model() as m:
284277
with pm.Model() as m_censored:
285278
pm.Censored("y", pm.Normal.dist(0, 2), lower=-1.0, upper=None)
286279
287-
logp_fn = m.logp
288-
logp_censored_fn = m_censored.logp
280+
logp_fn = m.compile_logp()
281+
logp_censored_fn = m_censored.compile_logp()
289282
290283
xi = np.hstack((np.linspace(-6, -1.01), [-1.0], np.linspace(-0.99, 6)))
291284
@@ -368,6 +361,7 @@ When looking into this topic, I found that most of the material out there focuse
368361
## Authors
369362
* Authored by [Benjamin T. Vincent](https://github.com/drbenvincent) in May 2021
370363
* Updated by [Benjamin T. Vincent](https://github.com/drbenvincent) in January 2022
364+
* Updated by [Benjamin T. Vincent](https://github.com/drbenvincent) in September 2022
371365

372366
+++
373367

0 commit comments

Comments
 (0)