Skip to content

Commit 52a04e7

Browse files
committed
cheaper test using point_logps
1 parent 38007c2 commit 52a04e7

File tree

1 file changed

+3
-6
lines changed

1 file changed

+3
-6
lines changed

pymc_experimental/tests/model/test_model_api.py

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -4,22 +4,19 @@
44
import pymc_experimental as pmx
55

66

7-
def test_sample():
7+
def test_logp():
88
"""Compare standard PyMC `with pm.Model()` context API against `pmx.model` decorator
9-
and a functional syntax. Checks whether the kwarge `coords` can be passed.
9+
and a functional syntax. Checks whether the kwarg `coords` can be passed.
1010
"""
1111
coords = {"obs": ["a", "b"]}
12-
kwargs = {"draws": 50, "tune": 50, "chains": 1, "random_seed": 1}
1312

1413
with pm.Model(coords=coords) as model:
1514
pm.Normal("x", 0.0, 1.0, dims="obs")
16-
idata = pm.sample(**kwargs)
1715

1816
@pmx.model(coords=coords)
1917
def model_wrapped():
2018
pm.Normal("x", 0.0, 1.0, dims="obs")
2119

2220
mw = model_wrapped()
23-
idata_wrapped = pm.sample(model=mw, **kwargs)
2421

25-
np.testing.assert_array_equal(idata.posterior.x, idata_wrapped.posterior.x)
22+
np.testing.assert_array_equal(model.point_logps(), mw.point_logps())

0 commit comments

Comments
 (0)