Skip to content

Commit 7b5cd32

Browse files
aloctavodia authored and Junpeng Lao committed
SMC: reduce autocorrelation posterior samples (#3047)
* reduce autocorrelation final trace
* add release note
* reduce time smc test
* reduce time smc test
* remove warning about samples being multiple of chains
1 parent 4dbd754 commit 7b5cd32

File tree

5 files changed

+100
-115
lines changed

5 files changed

+100
-115
lines changed

RELEASE-NOTES.md

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,12 +14,14 @@
1414
- Save and load traces without `pickle` using `pm.save_trace` and `pm.load_trace`
1515
- Add `Kumaraswamy` distribution
1616
- Rewrite parallel sampling of multiple chains on py3. This resolves
17-
long standing issues when tranferring large traces to the main process,
18-
avoids pickleing issues on UNIX, and allows us to show a progress bar
17+
long standing issues when transferring large traces to the main process,
18+
avoids pickling issues on UNIX, and allows us to show a progress bar
1919
for all chains. If parallel sampling is interrupted, we now return
2020
partial results.
2121
- Add `sample_prior_predictive` which allows for efficient sampling from
2222
the unconditioned model.
23+
- SMC: remove experimental warning, allow sampling using `sample`, reduce autocorrelation in the
24+
final trace.
2325

2426
### Fixes
2527

docs/source/notebooks/SMC2_gaussians.ipynb

Lines changed: 20 additions & 21 deletions
Large diffs are not rendered by default.

pymc3/step_methods/smc.py

Lines changed: 30 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,6 @@
1717
from tqdm import tqdm
1818

1919
import theano
20-
import warnings
2120

2221
from ..model import modelcontext
2322
from ..vartypes import discrete_types
@@ -61,11 +60,16 @@ class SMC(atext.ArrayStepSharedLLK):
6160
out_vars : list
6261
List of output variables for trace recording. If empty unobserved_RVs are taken.
6362
n_steps : int
64-
The number of steps of a Markov Chain. Only works if `tune_interval=0` otherwise it will be
65-
determined adaptively.
63+
The number of steps of a Markov Chain. If `tune_interval > 0` `n_steps` will be used for
64+
the first and last stages, and the number of steps of the intermediate states will be
65+
determined automatically. Otherwise, if `tune_interval = 0`, `n_steps` will be used for
66+
all stages.
6667
scaling : float
6768
Factor applied to the proposal distribution i.e. the step size of the Markov Chain. Only
68-
works if `tune_interval=0` otherwise it will be determined adaptively.
69+
works if `tune_interval=0`; otherwise it will be determined automatically.
70+
p_acc_rate : float
71+
Probability of not accepting a step. Used to compute `n_steps` when `tune_interval > 0`.
72+
It should be between 0 and 1.
6973
covariance : :class:`numpy.ndarray`
7074
(chains x chains)
7175
Initial Covariance matrix for proposal distribution, if None - identity matrix taken
@@ -80,8 +84,8 @@ class SMC(atext.ArrayStepSharedLLK):
8084
Chain) and the number of steps of a Markov Chain (i.e. `n_steps`).
8185
threshold : float
8286
Determines the change of beta from stage to stage, i.e. indirectly the number of stages,
83-
the higher the value of threshold the higher the number of stage. Defaults to 0.5. It should
84-
be between 0 and 1.
87+
the higher the value of threshold the higher the number of stages. Defaults to 0.5.
88+
It should be between 0 and 1.
8589
check_bound : boolean
8690
Check if current sample lies outside of variable definition; this speeds up computation as the
8791
forward model won't be executed. Default: True
@@ -100,7 +104,7 @@ class SMC(atext.ArrayStepSharedLLK):
100104
"""
101105
default_blocked = True
102106

103-
def __init__(self, vars=None, out_vars=None, n_steps=25, scaling=1.,
107+
def __init__(self, vars=None, out_vars=None, n_steps=25, scaling=1., p_acc_rate=0.001,
104108
covariance=None, likelihood_name='l_like__', proposal_name='MultivariateNormal',
105109
tune_interval=10, threshold=0.5, check_bound=True, model=None, random_seed=-1):
106110

@@ -143,6 +147,8 @@ def __init__(self, vars=None, out_vars=None, n_steps=25, scaling=1.,
143147
self.steps_until_tune = tune_interval
144148
self.population = [model.test_point]
145149
self.n_steps = n_steps
150+
self.n_steps_final = n_steps
151+
self.p_acc_rate = p_acc_rate
146152
self.stage_sample = 0
147153
self.accepted = 0
148154
self.beta = 0
@@ -176,7 +182,8 @@ def astep(self, q0):
176182
# compute n_steps
177183
if self.accepted == 0:
178184
acc_rate = 1 / float(self.tune_interval)
179-
self.n_steps = int(max(1, np.log(0.001) / np.log(1 - acc_rate)))
185+
self.n_steps = 1 + (np.ceil(np.log(self.p_acc_rate) /
186+
np.log(1 - acc_rate)).astype(int))
180187
# Reset counter
181188
self.steps_until_tune = self.tune_interval
182189
self.accepted = 0
@@ -392,22 +399,20 @@ def sample_smc(samples=1000, chains=100, step=None, start=None, homepath=None, s
392399
progressbar=False, model=None, random_seed=-1, rm_flag=True, **kwargs):
393400
"""Sequential Monte Carlo sampling
394401
395-
Samples the solution space with `chains` of Metropolis chains, where each chain has
396-
`n_steps`=`samples`/`chains` iterations.
402+
Samples the parameter space using a `chains` number of parallel Metropolis chains.
397403
Once finished, the sampled traces are evaluated:
398404
399405
(1) Based on the likelihoods of the final samples, chains are weighted
400406
(2) the weighted covariance of the ensemble is calculated and set as new proposal distribution
401407
(3) the variation in the ensemble is calculated and also the next tempering parameter (`beta`)
402-
(4) New `chains` Markov chains are seeded on the traces with high weight for n_steps iterations
408+
(4) New `chains` Markov chains are seeded on the traces with high weight for a given number of
409+
iterations; the number of iterations can be computed automatically.
403410
(5) Repeat until `beta` > 1.
404411
405412
Parameters
406413
----------
407414
samples : int
408-
The number of samples to draw from the last stage, i.e. the posterior. Defaults to 1000.
409-
The number of samples should be a multiple of `chains`, otherwise the returned number of
410-
draws will be the lowest closest multiple of `chains`.
415+
The number of samples to draw from the posterior (i.e. last stage). Defaults to 1000.
411416
chains : int
412417
Number of chains used to store samples in backend.
413418
step : :class:`SMC`
@@ -443,11 +448,6 @@ def sample_smc(samples=1000, chains=100, step=None, start=None, homepath=None, s
443448
`link <https://gji.oxfordjournals.org/content/194/3/1701.full>`__
444449
"""
445450

446-
remainder = samples % chains
447-
if remainder != 0:
448-
warnings.warn("'samples' {} is not a multiple of 'chains' {}. Hence, you will get {} "
449-
"draws from the posterior".format(samples, chains, samples - remainder))
450-
451451
model = modelcontext(model)
452452

453453
if random_seed != -1:
@@ -456,11 +456,6 @@ def sample_smc(samples=1000, chains=100, step=None, start=None, homepath=None, s
456456
if homepath is None:
457457
raise TypeError('Argument `homepath` should be path to result_directory.')
458458

459-
if 'n_jobs' in kwargs:
460-
cores = kwargs['n_jobs']
461-
warnings.warn(
462-
"The n_jobs argument has been deprecated. Use cores instead.",
463-
DeprecationWarning)
464459
if cores > 1:
465460
if not (chains / float(cores)).is_integer():
466461
raise TypeError('chains / cores has to be a whole number!')
@@ -525,6 +520,7 @@ def sample_smc(samples=1000, chains=100, step=None, start=None, homepath=None, s
525520

526521
step.population, step.array_population, step.likelihoods = step.select_end_points(
527522
mtrace, chains)
523+
528524
step.beta, step.old_beta, step.weights = step.calc_beta()
529525
#step.beta, step.old_beta, step.weights, sj = step.calc_beta()
530526
#step.sjs *= sj
@@ -558,11 +554,9 @@ def sample_smc(samples=1000, chains=100, step=None, start=None, homepath=None, s
558554
step.resampling_indexes = step.resample(chains)
559555
step.chain_previous_lpoint = step.get_chain_previous_lpoint(mtrace, chains)
560556

561-
if samples < chains:
562-
samples = 1
563-
else:
564-
samples = int(samples / chains)
565-
sample_args['draws'] = samples
557+
x_chains = nr.randint(0, chains, size=samples)
558+
559+
sample_args['draws'] = step.n_steps_final
566560
sample_args['step'] = step
567561
sample_args['stage_path'] = stage_handler.stage_path(step.stage)
568562
sample_args['x_chains'] = x_chains
@@ -572,7 +566,6 @@ def sample_smc(samples=1000, chains=100, step=None, start=None, homepath=None, s
572566

573567
#model.marginal_likelihood = step.sjs
574568
return stage_handler.create_result_trace(step.stage,
575-
idxs=range(samples),
576569
step=step,
577570
model=model)
578571

@@ -598,9 +591,9 @@ def _initial_population(samples, chains, model, variables):
598591

599592

600593
def _sample(draws, step=None, start=None, trace=None, chain=0, progressbar=True, model=None,
601-
random_seed=-1):
594+
random_seed=-1, chain_idx=0):
602595

603-
sampling = _iter_sample(draws, step, start, trace, chain, model, random_seed)
596+
sampling = _iter_sample(draws, step, start, trace, chain, model, random_seed, chain_idx)
604597

605598
if progressbar:
606599
sampling = tqdm(sampling, total=draws)
@@ -615,7 +608,7 @@ def _sample(draws, step=None, start=None, trace=None, chain=0, progressbar=True,
615608
return chain
616609

617610

618-
def _iter_sample(draws, step, start=None, trace=None, chain=0, model=None, random_seed=-1):
611+
def _iter_sample(draws, step, start=None, trace=None, chain=0, model=None, random_seed=-1, chain_idx=0):
619612
"""
620613
Modified from :func:`pymc3.sampling._iter_sample` to be more efficient with SMC algorithm.
621614
"""
@@ -637,7 +630,7 @@ def _iter_sample(draws, step, start=None, trace=None, chain=0, model=None, rando
637630

638631
point = pm.Point(start, model=model)
639632
step.chain_index = chain
640-
trace.setup(draws, chain)
633+
trace.setup(draws, chain_idx)
641634
for i in range(draws):
642635
point, out_list = step.step(point)
643636
trace.record(out_list)
@@ -669,6 +662,7 @@ def _iter_parallel_chains(draws, step, stage_path, progressbar, model, n_jobs, c
669662
if x_chains is None:
670663
x_chains = range(chains)
671664

665+
chain_idx = range(0, len(x_chains))
672666
pm._log.info('Initializing chain traces ...')
673667

674668
max_int = np.iinfo(np.int32).max
@@ -683,7 +677,8 @@ def _iter_parallel_chains(draws, step, stage_path, progressbar, model, n_jobs, c
683677
chain,
684678
False,
685679
model,
686-
rseed) for chain, rseed in zip(x_chains, random_seeds)]
680+
rseed,
681+
chain_idx) for chain, rseed, chain_idx in zip(x_chains, random_seeds, chain_idx)]
687682

688683
if draws < 10:
689684
chunksize = n_jobs

pymc3/tests/test_smc.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ def setup_class(self):
1717
super(TestSMC, self).setup_class()
1818
self.test_folder = mkdtemp(prefix='ATMIP_TEST')
1919

20-
self.samples = 2000
20+
self.samples = 1500
2121
self.chains = 200
2222
n = 4
2323
mu1 = np.ones(n) * (1. / 2)
@@ -47,19 +47,19 @@ def two_gaussians(x):
4747
self.muref = mu1
4848

4949

50-
@pytest.mark.parametrize(['n_jobs', 'stage'], [[1, 0], [2, 6]])
51-
def test_sample_n_core(self, n_jobs, stage):
50+
@pytest.mark.parametrize(['cores', 'stage'], [[1, 0], [2, 6]])
51+
def test_sample_n_core(self, cores, stage):
5252
step_kwargs = {'homepath': self.test_folder, 'stage': stage}
5353
with self.ATMIP_test:
5454
mtrace = pm.sample(draws=self.samples,
5555
chains=self.chains,
56-
cores=n_jobs,
56+
cores=cores,
5757
step = pm.SMC(),
5858
step_kwargs=step_kwargs)
5959

6060
x = mtrace.get_values('X')
6161
mu1d = np.abs(x).mean(axis=0)
62-
np.testing.assert_allclose(self.muref, mu1d, rtol=0., atol=0.06)
62+
np.testing.assert_allclose(self.muref, mu1d, rtol=0., atol=0.03)
6363
# Scenario IV Ching, J. & Chen, Y. 2007
6464
#assert np.round(np.log(self.ATMIP_test.marginal_likelihood)) == -12.0
6565

@@ -75,7 +75,7 @@ def test_stage_handler(self):
7575
self.chains,
7676
step,
7777
model=self.ATMIP_test)
78-
assert len(corrupted_chains) == 0
78+
assert len(corrupted_chains) == self.chains
7979

8080
rtrace = stage_handler.load_result_trace(model=self.ATMIP_test)
8181

pymc3/tests/test_step.py

Lines changed: 40 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -124,56 +124,46 @@ class TestStepMethods(object): # yield test doesn't work subclassing object
124124
-2.24238542e+00, -1.01648100e+00, -1.01648100e+00, -7.60912865e-01,
125125
1.44384812e+00, 2.07355127e+00, 1.91390340e+00, 1.66559696e+00]),
126126
smc.SMC: np.array(
127-
[-6.93241492e-01, -7.31306682e-01, -6.92278431e-01, -5.38971903e-01,
128-
-3.63864239e-01, -7.81639393e-01, -1.76442235e-01, -6.95182484e-01,
129-
-1.70834637e-01, -6.91297645e-01, -1.13549112e-01, -5.73823456e-01,
130-
-1.31522036e-01, -4.05141577e-01, 7.63062953e-04, -3.11034353e-01,
131-
9.97982970e-02, -1.56511470e-01, 6.68871285e-01, -1.15387717e+00,
132-
3.87738482e-01, -1.19314682e+00, -7.80174876e-03, -1.43622635e+00,
133-
4.33184436e-02, -1.93210174e+00, 4.03565846e-01, -1.93210174e+00,
134-
2.69998078e-01, -1.63371961e+00, 2.04993681e-01, -1.49152889e+00,
135-
-3.33970215e-01, -1.07618355e+00, -7.88408084e-01, -1.07618355e+00,
136-
-5.84146424e-02, -7.60238090e-01, -4.80486087e-01, -1.32865002e-01,
137-
-8.18802072e-01, -2.09393447e-01, -9.40225900e-01, -1.62505785e-01,
138-
-9.69990002e-01, -3.71928783e-01, -9.69990002e-01, -1.79001101e-02,
139-
-7.89632438e-02, -1.79001101e-02, -4.09611000e-01, 7.88749920e-02,
140-
-3.66674846e-01, 2.05150150e-01, -7.94202588e-01, 4.98861576e-01,
141-
-4.80532287e-01, 3.37660876e-01, -1.77206900e-01, -1.61506088e-01,
142-
-7.28509985e-01, 1.73439211e-01, -7.28509985e-01, 3.47487801e-02,
143-
-9.56879974e-01, -1.59855934e-02, -7.24175290e-01, -2.72376580e-01,
144-
-1.01682490e+00, -2.48958662e-01, -1.21307819e+00, 7.17910968e-02,
145-
-1.67818076e+00, 7.89068592e-02, -1.98491030e+00, 3.21678968e-01,
146-
-1.43814668e+00, -3.09112690e-02, -1.43814668e+00, -7.15322991e-01,
147-
-1.84024408e+00, -4.88590415e-01, -1.51673992e+00, -1.10944697e-01,
148-
-1.45569455e+00, 1.24275938e-01, -1.34887321e+00, -1.62491398e-01,
149-
-1.14603705e+00, -4.15902432e-01, -1.11446001e+00, 3.18922609e-03,
150-
-1.34720657e+00, -3.39934926e-01, -1.70915288e+00, -4.31481807e-02,
151-
-1.63971745e+00, 8.21483768e-01, -1.63752467e+00, 7.48596620e-01,
152-
-1.63752467e+00, 9.93980340e-01, -1.90938468e+00, 8.02255046e-01,
153-
-2.02030962e+00, 1.10650559e+00, -2.31344909e+00, -1.52251885e-02,
154-
-2.43257510e+00, -3.68430866e-01, -2.16907649e+00, -9.29511056e-01,
155-
-2.16907649e+00, -9.18969326e-01, -1.74398265e+00, -1.04959839e+00,
156-
-1.66283089e+00, -1.43625696e+00, -1.82055651e+00, -1.43625696e+00,
157-
-2.13665682e+00, -1.40327607e+00, -2.02511907e+00, -1.42094355e+00,
158-
-2.03288017e+00, -1.63685868e+00, -2.21063115e+00, -1.67193325e+00,
159-
-2.33677066e+00, -1.89232856e+00, -2.12598116e+00, -2.09817138e+00,
160-
-2.12598116e+00, -2.28443667e+00, -2.33324231e+00, -2.28443667e+00,
161-
-2.41326691e+00, -2.28443667e+00, -2.34146132e+00, -2.28443667e+00,
162-
-2.18383170e+00, -1.95866523e+00, -2.24356939e+00, -1.85803869e+00,
163-
-1.86676826e+00, -1.79709743e+00, -1.88295378e+00, -2.02148945e+00,
164-
-1.86017722e+00, -2.02148945e+00, -1.86017722e+00, -2.07361656e+00,
165-
-1.86017722e+00, -1.79664076e+00, -2.11641422e+00, -1.57798197e+00,
166-
-2.06769020e+00, -1.26676256e+00, -2.02636926e+00, -1.05742456e+00,
167-
-1.38989604e+00, -9.87916736e-01, -1.54158695e+00, -9.77082655e-01,
168-
-1.17364636e+00, -8.75178531e-01, -1.23600577e+00, -8.72739504e-01,
169-
-1.49246121e+00, -8.72739504e-01, -1.69458186e+00, -1.12213269e+00,
170-
-1.37029442e+00, -1.09619745e+00, -1.28062186e+00, -1.09619745e+00,
171-
-1.39608733e+00, -1.09619745e+00, -1.20813610e+00, -1.48145119e+00,
172-
-9.64693905e-01, -1.18526483e+00, -6.85433431e-01, -1.03724907e+00,
173-
-6.85433431e-01, -1.03724907e+00, -6.32619521e-01, -9.90567503e-01,
174-
-6.31553924e-01, -1.10462261e+00, -8.25572702e-01, -1.17635498e+00,
175-
-8.25572702e-01, -1.17181362e+00, -9.00418925e-01, -1.07920544e+00,
176-
-7.09372067e-01, -7.85138291e-01, -5.89973809e-01, -6.47909120e-01]),
127+
[ 0.94245927, 0.04320349, 0.16616453, -0.42667441, -0.49780471,
128+
0.65384837, -0.25387836, 0.38232654, 0.62490342, -0.21777828,
129+
-0.70756665, 0.9310788 , -0.03941721, -1.20854932, 0.39442244,
130+
0.24306076, -0.98310433, 2.2503327 , 0.54090823, 0.51685018,
131+
-1.32968792, 0.02445827, -0.62052594, -0.28014643, 0.75977904,
132+
-1.20233021, -1.80432242, -0.31547627, -0.33392375, -1.34380736,
133+
1.44597486, -0.15871648, -0.20727832, 0.99115736, 0.3445085 ,
134+
-0.89909578, -0.36983042, 0.16734659, 0.13228431, -0.16786514,
135+
-0.36268027, 0.13369353, -1.28444377, 1.2644179 , -0.47877275,
136+
-0.4411035 , 0.35735115, -1.27425973, -0.43213873, 0.70698702,
137+
-0.7805279 , -1.67705636, -0.10661104, -0.59947856, 0.02693047,
138+
-1.09062222, -0.73592286, -1.56822784, 0.97077952, -0.02149393,
139+
-0.26597767, -0.38710878, -0.09971606, -0.52523725, 1.64000249,
140+
-0.1287883 , 0.09555045, 0.04258323, -0.16771237, 0.79324588,
141+
-0.4439878 , -0.00328163, 0.01267578, 0.31817545, -2.48389205,
142+
-0.43794095, -0.18922707, 0.0042956 , 0.29387263, 0.66119344,
143+
-0.98277349, 0.4039511 , 0.13542066, -0.78467059, -0.24334413,
144+
-0.62519786, -0.79586084, -0.06190844, 0.11355637, 0.66110093,
145+
-2.10383759, 0.48608459, -0.47993295, 0.46791254, 2.01963317,
146+
0.12975299, 1.71604836, -0.09413096, 0.30744711, 0.15079852,
147+
0.31349994, 0.26575959, 0.763656 , -1.81526952, -0.22984443,
148+
1.10531065, 0.26065936, -0.22274362, -0.20853456, 0.32741584,
149+
0.08521911, -1.53866503, 0.28501159, -0.39016642, 0.09505455,
150+
-0.72955337, 1.46268494, 0.56252715, -1.63048738, 1.45718808,
151+
-0.01141763, 0.65826932, 1.8723026 , 0.90744555, 1.40586042,
152+
1.58765986, 0.06792152, -0.71397153, 0.22718523, -1.90281392,
153+
0.58708453, -0.77195137, -0.56979511, -0.6543881 , -1.3711677 ,
154+
-1.72706576, -0.41484231, 0.17460229, 0.74160523, 0.10991525,
155+
0.50297247, 1.04762338, -0.69148618, 1.23291629, -0.49797445,
156+
-0.24914585, 1.44290113, -0.23288806, -1.15495976, 0.63230627,
157+
-1.06229509, 0.18047975, -1.23701009, 0.10994274, -0.81730888,
158+
0.01827404, -0.22824212, -0.76809243, -1.36315643, 0.76097799,
159+
1.51091188, 0.46931765, 1.27261922, 0.98191306, 0.80721561,
160+
1.12844558, 1.86799414, 0.29913787, -1.49977561, 0.7551137 ,
161+
-1.0622067 , -0.46200335, -0.10271276, -0.63924651, 1.56074961,
162+
-0.53611858, -0.23229769, -0.74455411, -2.41567262, -0.96658159,
163+
-0.08795562, 0.08532369, -1.56005584, -0.99356212, 0.32678269,
164+
-0.87012306, 0.83897514, 0.9799229 , -1.27565975, -0.25761179,
165+
0.34968085, -0.95045095, 0.95192797, -1.5101461 , 0.04042998,
166+
-0.91145107, -0.91700215, 0.0825614 , 0.59658604, 0.64933802]),
177167
}
178168

179169
def setup_class(self):
@@ -214,7 +204,6 @@ def check_trace(self, step_method):
214204
progressbar=False,
215205
step=step_method(),
216206
step_kwargs={'homepath': self.temp_dir})
217-
218207
elif step_method.__name__ == 'NUTS':
219208
step = step_method(scaling=model.test_point)
220209
trace = sample(0, tune=n_steps,

0 commit comments

Comments
 (0)