Skip to content

Commit 49ad534

Browse files
authored
Seed flaky test TestSamplePPC.test_normal_scalar (#6220)
1 parent a0f849a commit 49ad534

File tree

1 file changed

+6
-3
lines changed

1 file changed

+6
-3
lines changed

pymc/tests/test_sampling.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -630,8 +630,8 @@ def test_normal_scalar(self):
630630
ppc0 = pm.sample_posterior_predictive(
631631
10 * [model.initial_point()], return_inferencedata=False
632632
)
633-
# # deprecated argument is not introduced to fast version [2019/08/20:rpg]
634-
ppc = pm.sample_posterior_predictive(trace, var_names=["a"], return_inferencedata=False)
633+
assert "a" in ppc0
634+
assert len(ppc0["a"][0]) == 10
635635
# test empty ppc
636636
ppc = pm.sample_posterior_predictive(trace, var_names=[], return_inferencedata=False)
637637
assert len(ppc) == 0
@@ -641,7 +641,10 @@ def test_normal_scalar(self):
641641
assert ppc["a"].shape == (nchains, ndraws)
642642

643643
# test default case
644-
idata_ppc = pm.sample_posterior_predictive(trace, var_names=["a"])
644+
random_state = self.get_random_state()
645+
idata_ppc = pm.sample_posterior_predictive(
646+
trace, var_names=["a"], random_seed=random_state
647+
)
645648
ppc = idata_ppc.posterior_predictive
646649
assert "a" in ppc
647650
assert ppc["a"].shape == (nchains, ndraws)

0 commit comments

Comments
 (0)