In Stan, there is an option to write a generated quantities block for sample generation. Doing something similar in pymc3, however, seems to disturb the sampler, especially if the generated RV is discrete.
Consider the following simple example:
import numpy as np
import pymc3 as pm
import theano.tensor as tt
# Data
x = np.array([1.1, 1.9, 2.3, 1.8])
n = len(x)
sigmoid = lambda x: 1/(1 + tt.exp(-x))
np.random.seed(42)
with pm.Model() as model1:
    # prior
    mu = pm.Normal('mu', mu=0, tau=.001)
    sigma = pm.Uniform('sigma', lower=0, upper=10)
    # observed
    xi = pm.Normal('xi', mu=mu, tau=1/(sigma**2), observed=x)
    # inference
    trace = pm.sample(1000, njobs=5, tune=500, init=None)
    pm.traceplot(trace, varnames=['mu', 'sigma_interval_'])
The result is fairly normal:
Assigned NUTS to mu
Assigned NUTS to sigma_interval_
100%|██████████| 1000/1000 [00:01<00:00, 855.81it/s]
If I now add an additional RV node to the graph:
np.random.seed(42)
with pm.Model() as model1:
    # prior
    mu = pm.Normal('mu', mu=0, tau=.001)
    sigma = pm.Uniform('sigma', lower=0, upper=10)
    # observed
    xi = pm.Normal('xi', mu=mu, tau=1/(sigma**2), observed=x)
    # generation
    p = pm.Deterministic('p', sigmoid(mu))
    count = pm.Binomial('count', n=10, p=p, shape=10)
    # inference
    trace = pm.sample(1000, njobs=5, tune=500, init=None)
The output trace is quite unstable and converges much more slowly:
Assigned NUTS to mu
Assigned NUTS to sigma_interval_
Assigned Metropolis to count
100%|██████████| 1000/1000 [00:06<00:00, 153.00it/s]
If the added RV is continuous, the effect on the trace seems to be minimal:
np.random.seed(42)
with pm.Model() as model1:
    # prior
    mu = pm.Normal('mu', mu=0, tau=.001)
    sigma = pm.Uniform('sigma', lower=0, upper=10)
    # observed
    xi = pm.Normal('xi', mu=mu, tau=1/(sigma**2), observed=x)
    # generation
    p = pm.Deterministic('p', sigmoid(mu))
    xi_ = pm.Normal('xi_', mu=p, shape=10)
    # inference
    trace = pm.sample(1000, njobs=5, tune=500, init=None)
Assigned NUTS to mu
Assigned NUTS to sigma_interval_
Assigned NUTS to xi_
100%|██████████| 1000/1000 [00:07<00:00, 128.04it/s]
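For comparison, here is a rough sketch of the behaviour I would expect from a Stan-style generated quantities block: the extra draws are made from the posterior samples of mu after the fact, so they cannot feed back into the sampler. This is only an illustration under stated assumptions, not pymc3 code. The helper name generate_counts and the stand-in mu_draws (used here in place of trace['mu']) are hypothetical, and the sketch uses only the standard library so it runs without pymc3.

```python
import math
import random


def sigmoid(v):
    """Logistic function, the same transform as the theano lambda above."""
    return 1.0 / (1.0 + math.exp(-v))


def generate_counts(mu_draws, n_trials=10, size=10, seed=42):
    """Hypothetical helper: for each posterior draw of mu, draw `size`
    Binomial(n_trials, sigmoid(mu)) counts outside the sampler."""
    rng = random.Random(seed)
    counts = []
    for mu in mu_draws:
        p = sigmoid(mu)
        # One binomial draw is the number of successes in n_trials Bernoulli(p) trials.
        row = [sum(rng.random() < p for _ in range(n_trials)) for _ in range(size)]
        counts.append(row)
    return counts


# Stand-in for trace['mu']: made-up values, NOT real sampler output.
_rng = random.Random(0)
mu_draws = [_rng.gauss(1.8, 0.5) for _ in range(1000)]

counts = generate_counts(mu_draws)
print(len(counts), len(counts[0]))
```

Because the counts are generated only after sampling, mu and sigma are explored by NUTS alone and no Metropolis step is involved.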