Skip to content

Commit aa552fc

Browse files
committed
fix tests
1 parent 07fac5e commit aa552fc

File tree

1 file changed

+7
-7
lines changed

1 file changed

+7
-7
lines changed

tests/sampling/test_jax.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -185,7 +185,7 @@ def test_get_log_likelihood():
185185
b_true = trace.log_likelihood.b.values
186186
a = np.array(trace.posterior.a)
187187
sigma_log_ = np.log(np.array(trace.posterior.sigma))
188-
b_jax = _get_log_likelihood(model, [a, sigma_log_])["b"]
188+
b_jax = _get_log_likelihood(model, {"a": a, "sigma_log__": sigma_log_})["b"]
189189

190190
assert np.allclose(b_jax.reshape(-1), b_true.reshape(-1))
191191

@@ -215,7 +215,7 @@ def test_get_jaxified_logp():
215215

216216
jax_fn = get_jaxified_logp(m)
217217
# This would underflow if not optimized
218-
assert not np.isinf(jax_fn((np.array(5000.0), np.array(5000.0))))
218+
assert not np.isinf(jax_fn(dict(x=np.array(5000.0), y=np.array(5000.0))))
219219

220220

221221
@pytest.fixture(scope="module")
@@ -302,19 +302,19 @@ def test_get_batched_jittered_initial_points():
302302
ips = _get_batched_jittered_initial_points(
303303
model=model, chains=1, random_seed=1, initvals=None, jitter=False
304304
)
305-
assert np.all(ips[0] == 0)
305+
assert np.all(ips["x"] == 0)
306306

307307
# Single chain
308308
ips = _get_batched_jittered_initial_points(model=model, chains=1, random_seed=1, initvals=None)
309309

310-
assert ips[0].shape == (2, 3)
311-
assert np.all(ips[0] != 0)
310+
assert ips["x"].shape == (2, 3)
311+
assert np.all(ips["x"] != 0)
312312

313313
# Multiple chains
314314
ips = _get_batched_jittered_initial_points(model=model, chains=2, random_seed=1, initvals=None)
315315

316-
assert ips[0].shape == (2, 2, 3)
317-
assert np.all(ips[0][0] != ips[0][1])
316+
assert ips["x"].shape == (2, 2, 3)
317+
assert np.all(ips["x"][0] != ips["x"][1])
318318

319319

320320
@pytest.mark.parametrize(

0 commit comments

Comments
 (0)