
Commit 1e7d91f

ricardoV94 authored and twiecki committed
Make Metropolis cope better with multiple dimensions
Metropolis now updates each dimension sequentially and tunes a proposal scale parameter per dimension
1 parent 4810829 commit 1e7d91f

File tree

pymc/step_methods/metropolis.py
pymc/step_methods/mlda.py
pymc/tests/test_step.py

3 files changed: +128 −48 lines changed
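
For orientation before the diff: the commit's central idea is a Metropolis step that proposes and accepts/rejects one dimension at a time, each dimension with its own proposal scale. Below is a minimal standalone NumPy sketch of that idea — a simplified Metropolis-within-Gibbs-style update, not PyMC's actual implementation; `elemwise_metropolis_step`, `logp`, and all names here are illustrative:

```python
import numpy as np

def elemwise_metropolis_step(q, logp, scaling, rng):
    """One sweep: perturb each dimension in random order, accept/reject per dimension."""
    accepted = np.zeros(q.size, dtype=bool)
    for i in rng.permutation(q.size):  # the commit also shuffles the update order
        q_prop = q.copy()
        q_prop[i] = q[i] + rng.normal() * scaling[i]  # per-dimension proposal scale
        if np.log(rng.uniform()) < logp(q_prop) - logp(q):
            q = q_prop
            accepted[i] = True
    return q, accepted

# Example target: independent normals with very different scales, exactly the
# situation where a single shared proposal scale performs poorly.
sd = np.array([0.1, 1.0, 10.0])
logp = lambda x: -0.5 * np.sum((x / sd) ** 2)
rng = np.random.default_rng(0)
q, scaling = np.zeros(3), sd.copy()  # pretend tuning already matched the scales
for _ in range(1000):
    q, accepted = elemwise_metropolis_step(q, logp, scaling, rng)
```

With one shared scale, tuning is dominated by the widest dimension and the narrow ones mix slowly; updating and tuning per dimension removes that coupling, which is what the per-dimension `scaling`, `accepted_sum`, and the element-wise loop in the diff below implement.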

pymc/step_methods/metropolis.py

Lines changed: 85 additions & 36 deletions
```diff
@@ -168,10 +168,12 @@ def __init__(
             vars = model.value_vars
         else:
             vars = [model.rvs_to_values.get(var, var) for var in vars]
+
         vars = pm.inputvars(vars)
 
+        initial_values_shape = [initial_values[v.name].shape for v in vars]
         if S is None:
-            S = np.ones(sum(initial_values[v.name].size for v in vars))
+            S = np.ones(int(sum(np.prod(ivs) for ivs in initial_values_shape)))
 
         if proposal_dist is not None:
             self.proposal_dist = proposal_dist(S)
@@ -186,7 +188,6 @@ def __init__(
         self.tune = tune
         self.tune_interval = tune_interval
         self.steps_until_tune = tune_interval
-        self.accepted = 0
 
         # Determine type of variables
         self.discrete = np.concatenate(
@@ -195,11 +196,33 @@ def __init__(
         self.any_discrete = self.discrete.any()
         self.all_discrete = self.discrete.all()
 
-        # remember initial settings before tuning so they can be reset
-        self._untuned_settings = dict(
-            scaling=self.scaling, steps_until_tune=tune_interval, accepted=self.accepted
+        # Metropolis will try to handle one batched dimension at a time. This, however,
+        # is not safe for discrete multivariate distributions (looking at you Multinomial),
+        # due to high dependency among the support dimensions. For continuous multivariate
+        # distributions we assume they are being transformed in a way that makes each
+        # dimension semi-independent.
+        is_scalar = len(initial_values_shape) == 1 and initial_values_shape[0] == ()
+        self.elemwise_update = not (
+            is_scalar
+            or (
+                self.any_discrete
+                and max(getattr(model.values_to_rvs[var].owner.op, "ndim_supp", 1) for var in vars)
+                > 0
+            )
         )
+        if self.elemwise_update:
+            dims = int(sum(np.prod(ivs) for ivs in initial_values_shape))
+        else:
+            dims = 1
+        self.enum_dims = np.arange(dims, dtype=int)
+        self.accept_rate_iter = np.zeros(dims, dtype=float)
+        self.accepted_iter = np.zeros(dims, dtype=bool)
+        self.accepted_sum = np.zeros(dims, dtype=int)
 
+        # remember initial settings before tuning so they can be reset
+        self._untuned_settings = dict(scaling=self.scaling, steps_until_tune=tune_interval)
+
+        # TODO: This is not being used when compiling the logp function!
         self.mode = mode
 
         shared = pm.make_shared_replacements(initial_values, vars, model)
@@ -210,6 +233,7 @@ def reset_tuning(self):
         """Resets the tuned sampler parameters to their initial values."""
         for attr, initial_value in self._untuned_settings.items():
             setattr(self, attr, initial_value)
+        self.accepted_sum[:] = 0
         return
 
     def astep(self, q0: RaveledVars) -> Tuple[RaveledVars, List[Dict[str, Any]]]:
@@ -219,10 +243,10 @@ def astep(self, q0: RaveledVars) -> Tuple[RaveledVars, List[Dict[str, Any]]]:
 
         if not self.steps_until_tune and self.tune:
             # Tune scaling parameter
-            self.scaling = tune(self.scaling, self.accepted / float(self.tune_interval))
+            self.scaling = tune(self.scaling, self.accepted_sum / float(self.tune_interval))
             # Reset counter
             self.steps_until_tune = self.tune_interval
-            self.accepted = 0
+            self.accepted_sum[:] = 0
 
         delta = self.proposal_dist() * self.scaling
 
@@ -237,23 +261,36 @@ def astep(self, q0: RaveledVars) -> Tuple[RaveledVars, List[Dict[str, Any]]]:
         else:
             q = floatX(q0 + delta)
 
-        accept = self.delta_logp(q, q0)
-        q_new, accepted = metrop_select(accept, q, q0)
-
-        self.accepted += accepted
+        if self.elemwise_update:
+            q_temp = q0.copy()
+            # Shuffle order of updates (probably we don't need to do this in every step)
+            np.random.shuffle(self.enum_dims)
+            for i in self.enum_dims:
+                q_temp[i] = q[i]
+                accept_rate_i = self.delta_logp(q_temp, q0)
+                q_temp_, accepted_i = metrop_select(accept_rate_i, q_temp, q0)
+                q_temp[i] = q_temp_[i]
+                self.accept_rate_iter[i] = accept_rate_i
+                self.accepted_iter[i] = accepted_i
+                self.accepted_sum[i] += accepted_i
+            q = q_temp
+        else:
+            accept_rate = self.delta_logp(q, q0)
+            q, accepted = metrop_select(accept_rate, q, q0)
+            self.accept_rate_iter = accept_rate
+            self.accepted_iter = accepted
+            self.accepted_sum += accepted
 
         self.steps_until_tune -= 1
 
         stats = {
             "tune": self.tune,
-            "scaling": self.scaling,
-            "accept": np.exp(accept),
-            "accepted": accepted,
+            "scaling": np.mean(self.scaling),
+            "accept": np.mean(np.exp(self.accept_rate_iter)),
+            "accepted": np.mean(self.accepted_iter),
         }
 
-        q_new = RaveledVars(q_new, point_map_info)
-
-        return q_new, [stats]
+        return RaveledVars(q, point_map_info), [stats]
 
     @staticmethod
     def competence(var, has_grad):
@@ -275,26 +312,38 @@ def tune(scale, acc_rate):
     >0.95         x 10
 
     """
-    if acc_rate < 0.001:
+    return scale * np.where(
+        acc_rate < 0.001,
         # reduce by 90 percent
-        return scale * 0.1
-    elif acc_rate < 0.05:
-        # reduce by 50 percent
-        return scale * 0.5
-    elif acc_rate < 0.2:
-        # reduce by ten percent
-        return scale * 0.9
-    elif acc_rate > 0.95:
-        # increase by factor of ten
-        return scale * 10.0
-    elif acc_rate > 0.75:
-        # increase by double
-        return scale * 2.0
-    elif acc_rate > 0.5:
-        # increase by ten percent
-        return scale * 1.1
-
-    return scale
+        0.1,
+        np.where(
+            acc_rate < 0.05,
+            # reduce by 50 percent
+            0.5,
+            np.where(
+                acc_rate < 0.2,
+                # reduce by ten percent
+                0.9,
+                np.where(
+                    acc_rate > 0.95,
+                    # increase by factor of ten
+                    10.0,
+                    np.where(
+                        acc_rate > 0.75,
+                        # increase by double
+                        2.0,
+                        np.where(
+                            acc_rate > 0.5,
+                            # increase by ten percent
+                            1.1,
+                            # Do not change
+                            1.0,
+                        ),
+                    ),
+                ),
+            ),
+        ),
+    )
 
 
 class BinaryMetropolis(ArrayStep):
```
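
Because the acceptance counter is now a vector, `tune` had to be vectorized: the old if/elif ladder becomes nested `np.where` calls that pick a multiplicative factor elementwise, so a scalar acceptance rate and an array of per-dimension rates both work. A quick standalone check of that behavior (the `tune` here is a condensed transcription of the function in the diff above):

```python
import numpy as np

def tune(scale, acc_rate):
    # Condensed version of the nested np.where ladder above.
    return scale * np.where(acc_rate < 0.001, 0.1,
                   np.where(acc_rate < 0.05, 0.5,
                   np.where(acc_rate < 0.2, 0.9,
                   np.where(acc_rate > 0.95, 10.0,
                   np.where(acc_rate > 0.75, 2.0,
                   np.where(acc_rate > 0.5, 1.1, 1.0))))))

print(tune(1.0, 0.3))  # scalar input still works: 1.0 (no change)
print(tune(np.ones(4), np.array([0.0, 0.1, 0.6, 0.99])))
# per-dimension factors: [ 0.1  0.9  1.1 10. ]
```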

pymc/step_methods/mlda.py

Lines changed: 4 additions & 4 deletions
```diff
@@ -787,22 +787,22 @@ def astep(self, q0: RaveledVars) -> Tuple[RaveledVars, List[Dict[str, Any]]]:
         if isinstance(self.step_method_below, MLDA):
             self.base_tuning_stats = self.step_method_below.base_tuning_stats
         elif isinstance(self.step_method_below, MetropolisMLDA):
-            self.base_tuning_stats.append({"base_scaling": self.step_method_below.scaling})
+            self.base_tuning_stats.append({"base_scaling": np.mean(self.step_method_below.scaling)})
         elif isinstance(self.step_method_below, DEMetropolisZMLDA):
             self.base_tuning_stats.append(
                 {
-                    "base_scaling": self.step_method_below.scaling,
+                    "base_scaling": np.mean(self.step_method_below.scaling),
                     "base_lambda": self.step_method_below.lamb,
                 }
             )
         elif isinstance(self.step_method_below, CompoundStep):
             # Below method is CompoundStep
             for method in self.step_method_below.methods:
                 if isinstance(method, MetropolisMLDA):
-                    self.base_tuning_stats.append({"base_scaling": method.scaling})
+                    self.base_tuning_stats.append({"base_scaling": np.mean(method.scaling)})
                 elif isinstance(method, DEMetropolisZMLDA):
                     self.base_tuning_stats.append(
-                        {"base_scaling": method.scaling, "base_lambda": method.lamb}
+                        {"base_scaling": np.mean(method.scaling), "base_lambda": method.lamb}
                     )
 
         return q_new, [stats] + self.base_tuning_stats
```
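
The `mlda.py` changes are purely defensive: with element-wise updates, a base sampler's `scaling` may now be a vector with one entry per dimension, so MLDA's per-step `base_scaling` stat collapses it to a single number with `np.mean`. For instance (hypothetical values):

```python
import numpy as np

scaling = np.array([0.05, 0.9, 2.0])  # hypothetical per-dimension proposal scales
print(np.mean(scaling))               # one scalar "base_scaling" stat: 0.9833...
```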

pymc/tests/test_step.py

Lines changed: 39 additions & 8 deletions
```diff
@@ -35,7 +35,9 @@
     Beta,
     Binomial,
     Categorical,
+    Dirichlet,
     HalfNormal,
+    Multinomial,
     MvNormal,
     Normal,
 )
@@ -401,6 +403,40 @@ def test_tuning_reset(self):
             assert tuned != 0.1
             np.testing.assert_array_equal(idata.sample_stats["scaling"].sel(chain=c).values, tuned)
 
+    @pytest.mark.parametrize(
+        "batched_dist",
+        (
+            Binomial.dist(n=5, p=0.9),  # scalar case
+            Binomial.dist(n=np.arange(40) + 1, p=np.linspace(0.1, 0.9, 40), shape=(40,)),
+            Binomial.dist(
+                n=(np.arange(20) + 1)[::-1],
+                p=np.linspace(0.1, 0.9, 20),
+                shape=(
+                    2,
+                    20,
+                ),
+            ),
+            Dirichlet.dist(a=np.ones(3) * (np.arange(40) + 1)[:, None], shape=(40, 3)),
+            Dirichlet.dist(a=np.ones(3) * (np.arange(20) + 1)[:, None], shape=(2, 20, 3)),
+        ),
+    )
+    def test_elemwise_update(self, batched_dist):
+        with Model() as m:
+            m.register_rv(batched_dist, name="batched_dist")
+            step = pm.Metropolis([batched_dist])
+            assert step.elemwise_update == (batched_dist.ndim > 0)
+            trace = pm.sample(draws=1000, chains=2, step=step, random_seed=428)
+
+        assert az.rhat(trace).max()["batched_dist"].values < 1.1
+        assert az.ess(trace).min()["batched_dist"].values > 50
+
+    def test_multinomial_no_elemwise_update(self):
+        with Model() as m:
+            batched_dist = Multinomial("batched_dist", n=5, p=np.ones(4) / 4, shape=(10, 4))
+            with aesara.config.change_flags(mode=fast_unstable_sampling_mode):
+                step = pm.Metropolis([batched_dist])
+                assert not step.elemwise_update
+
 
 class TestDEMetropolisZ:
     def test_tuning_lambda_sequential(self):
@@ -1215,8 +1251,6 @@ def perform(self, node, inputs, outputs):
         mout = []
         coarse_models = []
 
-        rng = np.random.RandomState(seed)
-
         with Model() as coarse_model_0:
             if aesara.config.floatX == "float32":
                 Q = Data("Q", np.float32(0.0))
@@ -1234,8 +1268,6 @@ def perform(self, node, inputs, outputs):
 
         coarse_models.append(coarse_model_0)
 
-        rng = np.random.RandomState(seed)
-
         with Model() as coarse_model_1:
             if aesara.config.floatX == "float32":
                 Q = Data("Q", np.float32(0.0))
@@ -1253,8 +1285,6 @@ def perform(self, node, inputs, outputs):
 
         coarse_models.append(coarse_model_1)
 
-        rng = np.random.RandomState(seed)
-
         with Model() as model:
             if aesara.config.floatX == "float32":
                 Q = Data("Q", np.float32(0.0))
@@ -1312,8 +1342,9 @@ def perform(self, node, inputs, outputs):
             (nchains, ndraws * nsub)
         )
         Q_2_1 = np.concatenate(trace.get_sampler_stats("Q_2_1")).reshape((nchains, ndraws))
-        assert Q_1_0.mean(axis=1) == 0.0
-        assert Q_2_1.mean(axis=1) == 0.0
+        # This used to be a strict zero equality!
+        assert np.isclose(Q_1_0.mean(axis=1), 0.0, atol=1e-4)
+        assert np.isclose(Q_2_1.mean(axis=1), 0.0, atol=1e-4)
 
 
 class TestRVsAssignmentSteps:
```
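
Finally, a user-facing sketch of what the new tests exercise: sampling a batched variable with `pm.Metropolis` and checking that the element-wise path is taken. This mirrors `test_elemwise_update` above; the model and sizes are illustrative, and the exact API may differ across PyMC versions:

```python
import numpy as np
import pymc as pm

with pm.Model():
    # 10 independent binomials: batched and discrete, but univariate
    # (ndim_supp == 0), so the element-wise update path applies.
    x = pm.Binomial("x", n=np.arange(10) + 1, p=0.5, shape=(10,))
    step = pm.Metropolis([x])
    assert step.elemwise_update
    idata = pm.sample(draws=1000, chains=2, step=step, random_seed=428)
```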
