Skip to content

Commit 4810829

Browse files
ricardoV94twiecki
authored andcommitted
Move TestMetropolisProposal tests to TestMetropolis
1 parent 6750ea7 commit 4810829

File tree

1 file changed

+25
-27
lines changed

1 file changed

+25
-27
lines changed

pymc/tests/test_step.py

Lines changed: 25 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -174,33 +174,6 @@ def test_step_categorical(self, proposal):
174174
self.check_stat(check, idata, step.__class__.__name__)
175175

176176

177-
class TestMetropolisProposal:
178-
def test_proposal_choice(self):
179-
with aesara.config.change_flags(mode=fast_unstable_sampling_mode):
180-
_, model, _ = mv_simple()
181-
with model:
182-
initial_point = model.initial_point()
183-
initial_point_size = sum(initial_point[n.name].size for n in model.value_vars)
184-
185-
s = np.ones(initial_point_size)
186-
sampler = Metropolis(S=s)
187-
assert isinstance(sampler.proposal_dist, NormalProposal)
188-
s = np.diag(s)
189-
sampler = Metropolis(S=s)
190-
assert isinstance(sampler.proposal_dist, MultivariateNormalProposal)
191-
s[0, 0] = -s[0, 0]
192-
with pytest.raises(np.linalg.LinAlgError):
193-
sampler = Metropolis(S=s)
194-
195-
def test_mv_proposal(self):
196-
np.random.seed(42)
197-
cov = np.random.randn(5, 5)
198-
cov = cov.dot(cov.T)
199-
prop = MultivariateNormalProposal(cov)
200-
samples = np.array([prop() for _ in range(10000)])
201-
npt.assert_allclose(np.cov(samples.T), cov, rtol=0.2)
202-
203-
204177
class TestCompoundStep:
205178
samplers = (Metropolis, Slice, HamiltonianMC, NUTS, DEMetropolis)
206179

@@ -383,6 +356,31 @@ def test_parallelized_chains_are_random(self):
383356

384357

385358
class TestMetropolis:
359+
def test_proposal_choice(self):
360+
with aesara.config.change_flags(mode=fast_unstable_sampling_mode):
361+
_, model, _ = mv_simple()
362+
with model:
363+
initial_point = model.initial_point()
364+
initial_point_size = sum(initial_point[n.name].size for n in model.value_vars)
365+
366+
s = np.ones(initial_point_size)
367+
sampler = Metropolis(S=s)
368+
assert isinstance(sampler.proposal_dist, NormalProposal)
369+
s = np.diag(s)
370+
sampler = Metropolis(S=s)
371+
assert isinstance(sampler.proposal_dist, MultivariateNormalProposal)
372+
s[0, 0] = -s[0, 0]
373+
with pytest.raises(np.linalg.LinAlgError):
374+
sampler = Metropolis(S=s)
375+
376+
def test_mv_proposal(self):
377+
np.random.seed(42)
378+
cov = np.random.randn(5, 5)
379+
cov = cov.dot(cov.T)
380+
prop = MultivariateNormalProposal(cov)
381+
samples = np.array([prop() for _ in range(10000)])
382+
npt.assert_allclose(np.cov(samples.T), cov, rtol=0.2)
383+
386384
def test_tuning_reset(self):
387385
"""Re-use of the step method instance with cores=1 must not leak tuning information between chains."""
388386
with Model() as pmodel:

0 commit comments

Comments
 (0)