Skip to content

Commit c496378

Browse files
author
Junpeng Lao
authored
Improve sample_ppc (#3053)
* Improve sample_ppc No need to specify shape kwarg in ObservedRV for sample_ppc to work Also close #3012 * Fix test and improved also sample_ppc_w * fix test * shape not need in test
1 parent c5f9dd4 commit c496378

File tree

6 files changed

+66
-41
lines changed

6 files changed

+66
-41
lines changed

pymc3/distributions/distribution.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -383,14 +383,23 @@ def _draw_value(param, point=None, givens=None, size=None):
383383
elif (hasattr(param, 'distribution') and
384384
hasattr(param.distribution, 'random') and
385385
param.distribution.random is not None):
386-
# reset the dist shape for ObservedRV
387386
if hasattr(param, 'observations'):
387+
# shape inspection for ObservedRV
388388
dist_tmp = param.distribution
389389
try:
390390
distshape = param.observations.shape.eval()
391391
except AttributeError:
392392
distshape = param.observations.shape
393+
393394
dist_tmp.shape = distshape
395+
try:
396+
dist_tmp.random(point=point, size=size)
397+
except (ValueError, TypeError):
398+
# reset shape to account for shape changes
399+
# with theano.shared inputs
400+
dist_tmp.shape = np.array([])
401+
val = dist_tmp.random(point=point, size=None)
402+
dist_tmp.shape = val.shape
394403
return dist_tmp.random(point=point, size=size)
395404
else:
396405
return param.distribution.random(point=point, size=size)

pymc3/distributions/mixture.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -159,10 +159,10 @@ def random_choice(*args, **kwargs):
159159

160160
w = draw_values([self.w], point=point)[0]
161161
comp_tmp = self._comp_samples(point=point, size=None)
162-
if self.shape.size == 0:
162+
if np.asarray(self.shape).size == 0:
163163
distshape = np.asarray(np.broadcast(w, comp_tmp).shape)[..., :-1]
164164
else:
165-
distshape = self.shape
165+
distshape = np.asarray(self.shape)
166166
w_samples = generate_samples(random_choice,
167167
w=w,
168168
broadcast_shape=w.shape[:-1] or (1,),

pymc3/sampling.py

Lines changed: 22 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -982,7 +982,8 @@ def _mp_sample(draws, tune, step, chains, cores, chain, random_seed,
982982
strace = _choose_backend(copy(trace), idx, model=model)
983983
else:
984984
strace = _choose_backend(None, idx, model=model)
985-
# TODO what is this for?
985+
# for user supply start value, fill-in missing value if the supplied
986+
# dict does not contain all parameters
986987
update_start_vals(start[idx - chain], model.test_point, model)
987988
if step.generates_stats and strace.supports_sampler_stats:
988989
strace.setup(draws + tune, idx + chain, step.stats_dtypes)
@@ -1121,18 +1122,26 @@ def sample_ppc(trace, samples=None, model=None, vars=None, size=None,
11211122
if progressbar:
11221123
indices = tqdm(indices, total=samples)
11231124

1125+
varnames = [var.name for var in vars]
1126+
1127+
# draw once to inspect the shape
1128+
var_values = list(zip(varnames,
1129+
draw_values(vars, point=model.test_point, size=size)))
1130+
ppc_trace = defaultdict(list)
1131+
for varname, value in var_values:
1132+
ppc_trace[varname] = np.zeros((samples,) + value.shape, value.dtype)
1133+
11241134
try:
1125-
ppc = defaultdict(list)
1126-
for idx in indices:
1135+
for slc, idx in enumerate(indices):
11271136
if nchain > 1:
11281137
chain_idx, point_idx = np.divmod(idx, len_trace)
11291138
param = trace._straces[chain_idx].point(point_idx)
11301139
else:
11311140
param = trace[idx]
11321141

1133-
for var in vars:
1134-
ppc[var.name].append(var.distribution.random(point=param,
1135-
size=size))
1142+
values = draw_values(vars, point=param, size=size)
1143+
for k, v in zip(vars, values):
1144+
ppc_trace[k.name][slc] = v
11361145

11371146
except KeyboardInterrupt:
11381147
pass
@@ -1141,7 +1150,7 @@ def sample_ppc(trace, samples=None, model=None, vars=None, size=None,
11411150
if progressbar:
11421151
indices.close()
11431152

1144-
return {k: np.asarray(v) for k, v in ppc.items()}
1153+
return ppc_trace
11451154

11461155

11471156
def sample_ppc_w(traces, samples=None, models=None, weights=None,
@@ -1259,8 +1268,12 @@ def sample_ppc_w(traces, samples=None, models=None, weights=None,
12591268
for idx in indices:
12601269
param = trace[idx]
12611270
var = variables[idx]
1262-
ppc[var.name].append(var.distribution.random(point=param,
1263-
size=size[idx]))
1271+
# TODO sample_ppc_w is currently only work for model with
1272+
# one observed.
1273+
ppc[var.name].append(draw_values([var],
1274+
point=param,
1275+
size=size[idx]
1276+
)[0])
12641277

12651278
except KeyboardInterrupt:
12661279
pass

pymc3/tests/test_distributions_random.py

Lines changed: 11 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
import theano
1111

1212
import pymc3 as pm
13+
from pymc3.distributions.distribution import draw_values
1314
from .helpers import SeededTest
1415
from .test_distributions import (
1516
build_model, Domain, product, R, Rplus, Rplusbig, Rplusdunif,
@@ -74,7 +75,7 @@ class TestDrawValues(SeededTest):
7475
def test_draw_scalar_parameters(self):
7576
with pm.Model():
7677
y = pm.Normal('y1', mu=0., sd=1.)
77-
mu, tau = pm.distributions.draw_values([y.distribution.mu, y.distribution.tau])
78+
mu, tau = draw_values([y.distribution.mu, y.distribution.tau])
7879
npt.assert_almost_equal(mu, 0)
7980
npt.assert_almost_equal(tau, 1)
8081

@@ -83,7 +84,7 @@ def test_draw_dependencies(self):
8384
x = pm.Normal('x', mu=0., sd=1.)
8485
exp_x = pm.Deterministic('exp_x', pm.math.exp(x))
8586

86-
x, exp_x = pm.distributions.draw_values([x, exp_x])
87+
x, exp_x = draw_values([x, exp_x])
8788
npt.assert_almost_equal(np.exp(x), exp_x)
8889

8990
def test_draw_order(self):
@@ -92,15 +93,15 @@ def test_draw_order(self):
9293
exp_x = pm.Deterministic('exp_x', pm.math.exp(x))
9394

9495
# Need to draw x before drawing log_x
95-
exp_x, x = pm.distributions.draw_values([exp_x, x])
96+
exp_x, x = draw_values([exp_x, x])
9697
npt.assert_almost_equal(np.exp(x), exp_x)
9798

9899
def test_draw_point_replacement(self):
99100
with pm.Model():
100101
mu = pm.Normal('mu', mu=0., tau=1e-3)
101102
sigma = pm.Gamma('sigma', alpha=1., beta=1., transform=None)
102103
y = pm.Normal('y', mu=mu, sd=sigma)
103-
mu2, tau2 = pm.distributions.draw_values([y.distribution.mu, y.distribution.tau],
104+
mu2, tau2 = draw_values([y.distribution.mu, y.distribution.tau],
104105
point={'mu': 5., 'sigma': 2.})
105106
npt.assert_almost_equal(mu2, 5)
106107
npt.assert_almost_equal(tau2, 1 / 2.**2)
@@ -110,7 +111,7 @@ def test_random_sample_returns_nd_array(self):
110111
mu = pm.Normal('mu', mu=0., tau=1e-3)
111112
sigma = pm.Gamma('sigma', alpha=1., beta=1., transform=None)
112113
y = pm.Normal('y', mu=mu, sd=sigma)
113-
mu, tau = pm.distributions.draw_values([y.distribution.mu, y.distribution.tau])
114+
mu, tau = draw_values([y.distribution.mu, y.distribution.tau])
114115
assert isinstance(mu, np.ndarray)
115116
assert isinstance(tau, np.ndarray)
116117

@@ -806,15 +807,15 @@ def test_mixture_random_shape():
806807
like0 = pm.Mixture('like0',
807808
w=w0,
808809
comp_dists=comp0,
809-
shape=y.shape,
810810
observed=y)
811811

812812
comp1 = pm.Poisson.dist(mu=np.ones((20, 2)),
813813
shape=(20, 2))
814814
w1 = pm.Dirichlet('w1', a=np.ones(2))
815815
like1 = pm.Mixture('like1',
816816
w=w1,
817-
comp_dists=comp1, observed=y)
817+
comp_dists=comp1,
818+
observed=y)
818819

819820
comp2 = pm.Poisson.dist(mu=np.ones(2))
820821
w2 = pm.Dirichlet('w2',
@@ -835,16 +836,12 @@ def test_mixture_random_shape():
835836
comp_dists=comp3,
836837
observed=y)
837838

838-
rand0 = like0.distribution.random(m.test_point, size=100)
839+
rand0, rand1, rand2, rand3 = draw_values([like0, like1, like2, like3],
840+
point=m.test_point,
841+
size=100)
839842
assert rand0.shape == (100, 20)
840-
841-
rand1 = like1.distribution.random(m.test_point, size=100)
842843
assert rand1.shape == (100, 20)
843-
844-
rand2 = like2.distribution.random(m.test_point, size=100)
845844
assert rand2.shape == (100, 20)
846-
847-
rand3 = like3.distribution.random(m.test_point, size=100)
848845
assert rand3.shape == (100, 20)
849846

850847
with m:

pymc3/tests/test_sampling.py

Lines changed: 9 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -214,7 +214,8 @@ def test_choose_backend_shortcut(self):
214214
class TestSamplePPC(SeededTest):
215215
def test_normal_scalar(self):
216216
with pm.Model() as model:
217-
a = pm.Normal('a', mu=0, sd=1)
217+
mu = pm.Normal('mu', 0., 1.)
218+
a = pm.Normal('a', mu=mu, sd=1, observed=0.)
218219
trace = pm.sample()
219220

220221
with model:
@@ -225,7 +226,8 @@ def test_normal_scalar(self):
225226
ppc = pm.sample_ppc(trace, samples=1000, vars=[a])
226227
assert 'a' in ppc
227228
assert ppc['a'].shape == (1000,)
228-
_, pval = stats.kstest(ppc['a'], stats.norm().cdf)
229+
_, pval = stats.kstest(ppc['a'],
230+
stats.norm(loc=0, scale=np.sqrt(2)).cdf)
229231
assert pval > 0.001
230232

231233
with model:
@@ -234,7 +236,9 @@ def test_normal_scalar(self):
234236

235237
def test_normal_vector(self):
236238
with pm.Model() as model:
237-
a = pm.Normal('a', mu=0, sd=1, shape=2)
239+
mu = pm.Normal('mu', 0., 1.)
240+
a = pm.Normal('a', mu=mu, sd=1,
241+
observed=np.array([.5, .2]))
238242
trace = pm.sample()
239243

240244
with model:
@@ -251,16 +255,9 @@ def test_normal_vector(self):
251255
assert ppc['a'].shape == (10, 4, 2)
252256

253257
def test_vector_observed(self):
254-
# This test was initially created to test whether observedRVs
255-
# can assert the shape automatically from the observed data.
256-
# It can make sample_ppc correct for RVs similar to below (i.e.,
257-
# some kind of broadcasting is involved). However, doing so makes
258-
# the application with `theano.shared` array as observed data
259-
# invalid (after the `.set_value` the RV shape could change).
260258
with pm.Model() as model:
261259
mu = pm.Normal('mu', mu=0, sd=1)
262260
a = pm.Normal('a', mu=mu, sd=1,
263-
shape=2, # necessary to make ppc sample correct
264261
observed=np.array([0., 1.]))
265262
trace = pm.sample()
266263

@@ -300,12 +297,12 @@ def test_sample_ppc_w(self):
300297

301298
with pm.Model() as model_0:
302299
mu = pm.Normal('mu', mu=0, sd=1)
303-
y = pm.Normal('y', mu=mu, sd=1, observed=data0, shape=500)
300+
y = pm.Normal('y', mu=mu, sd=1, observed=data0)
304301
trace_0 = pm.sample()
305302

306303
with pm.Model() as model_1:
307304
mu = pm.Normal('mu', mu=0, sd=1, shape=len(data0))
308-
y = pm.Normal('y', mu=mu, sd=1, observed=data0, shape=500)
305+
y = pm.Normal('y', mu=mu, sd=1, observed=data0)
309306
trace_1 = pm.sample()
310307

311308
traces = [trace_0, trace_0]

pymc3/tests/test_shared.py

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ def test_deterministic(self):
1212
pm.Normal('y', 0, 1, observed=X)
1313
model.logp(model.test_point)
1414

15-
def test_sample_ppc(self):
15+
def test_sample(self):
1616
x = np.random.normal(size=100)
1717
y = x + np.random.normal(scale=1e-2, size=100)
1818

@@ -23,10 +23,19 @@ def test_sample_ppc(self):
2323
with pm.Model() as model:
2424
b = pm.Normal('b', 0., 10.)
2525
pm.Normal('obs', b * x_shared, np.sqrt(1e-2), observed=y)
26+
prior_trace0 = pm.sample_prior_predictive(1000)
2627

2728
trace = pm.sample(1000, init=None, progressbar=False)
29+
pp_trace0 = pm.sample_ppc(trace, 1000)
2830

2931
x_shared.set_value(x_pred)
30-
pp_trace = pm.sample_ppc(trace, 1000)
32+
prior_trace1 = pm.sample_prior_predictive(1000)
33+
pp_trace1 = pm.sample_ppc(trace, 1000)
3134

32-
np.testing.assert_allclose(x_pred, pp_trace['obs'].mean(axis=0), atol=1e-1)
35+
assert prior_trace0['b'].shape == (1000,)
36+
assert prior_trace0['obs'].shape == (1000, 100)
37+
np.testing.assert_allclose(x, pp_trace0['obs'].mean(axis=0), atol=1e-1)
38+
39+
assert prior_trace1['b'].shape == (1000,)
40+
assert prior_trace1['obs'].shape == (1000, 200)
41+
np.testing.assert_allclose(x_pred, pp_trace1['obs'].mean(axis=0), atol=1e-1)

0 commit comments

Comments
 (0)