Skip to content

Commit 96dec49

Browse files
committed
add test
1 parent 854b752 commit 96dec49

File tree

2 files changed

+56
-14
lines changed

2 files changed

+56
-14
lines changed

pymc/variational/approximations.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -229,7 +229,13 @@ def create_shared_params(self, trace=None, size=None, jitter=1, start=None):
229229
for j in range(len(trace)):
230230
histogram[i] = DictToArrayBijection.map(trace.point(j, t)).data
231231
i += 1
232-
return dict(histogram=pytensor.shared(pm.floatX(histogram), "histogram"))
232+
return dict(
233+
histogram=pytensor.shared(
234+
pm.floatX(histogram),
235+
"histogram",
236+
shape=histogram.shape,
237+
)
238+
)
233239

234240
def _check_trace(self):
235241
trace = self._kwargs.get("trace", None)

tests/sampling/test_jax.py

Lines changed: 49 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -84,12 +84,17 @@ def test_jax_PosDefMatrix():
8484
[
8585
pytest.param(1),
8686
pytest.param(
87-
2, marks=pytest.mark.skipif(len(jax.devices()) < 2, reason="not enough devices")
87+
2,
88+
marks=pytest.mark.skipif(
89+
len(jax.devices()) < 2, reason="not enough devices"
90+
),
8891
),
8992
],
9093
)
9194
@pytest.mark.parametrize("postprocessing_vectorize", ["scan", "vmap"])
92-
def test_transform_samples(sampler, postprocessing_backend, chains, postprocessing_vectorize):
95+
def test_transform_samples(
96+
sampler, postprocessing_backend, chains, postprocessing_vectorize
97+
):
9398
pytensor.config.on_opt_error = "raise"
9499
np.random.seed(13244)
95100

@@ -236,7 +241,9 @@ def test_replace_shared_variables():
236241
x = pytensor.shared(5, name="shared_x")
237242

238243
new_x = _replace_shared_variables([x])
239-
shared_variables = [var for var in graph_inputs(new_x) if isinstance(var, SharedVariable)]
244+
shared_variables = [
245+
var for var in graph_inputs(new_x) if isinstance(var, SharedVariable)
246+
]
240247
assert not shared_variables
241248

242249
x.default_update = x + 1
@@ -263,7 +270,11 @@ def test_get_jaxified_logp():
263270
@pytest.fixture(scope="module")
264271
def model_test_idata_kwargs() -> pm.Model:
265272
with pm.Model(
266-
coords={"x_coord": ["a", "b"], "x_coord2": [1, 2], "z_coord": ["apple", "banana", "orange"]}
273+
coords={
274+
"x_coord": ["a", "b"],
275+
"x_coord2": [1, 2],
276+
"z_coord": ["apple", "banana", "orange"],
277+
}
267278
) as m:
268279
x = pm.Normal("x", shape=(2,), dims=["x_coord"])
269280
_ = pm.Normal("y", x, observed=[0, 0])
@@ -322,23 +333,30 @@ def test_idata_kwargs(
322333

323334
posterior = idata.get("posterior")
324335
assert posterior is not None
325-
x_dim_expected = idata_kwargs.get("dims", model_test_idata_kwargs.named_vars_to_dims)["x"][0]
336+
x_dim_expected = idata_kwargs.get(
337+
"dims", model_test_idata_kwargs.named_vars_to_dims
338+
)["x"][0]
326339
assert x_dim_expected is not None
327340
assert posterior["x"].dims[-1] == x_dim_expected
328341

329-
x_coords_expected = idata_kwargs.get("coords", model_test_idata_kwargs.coords)[x_dim_expected]
342+
x_coords_expected = idata_kwargs.get("coords", model_test_idata_kwargs.coords)[
343+
x_dim_expected
344+
]
330345
assert x_coords_expected is not None
331346
assert list(x_coords_expected) == list(posterior["x"].coords[x_dim_expected].values)
332347

333348
assert posterior["z"].dims[2] == "z_coord"
334349
assert np.all(
335-
posterior["z"].coords["z_coord"].values == np.array(["apple", "banana", "orange"])
350+
posterior["z"].coords["z_coord"].values
351+
== np.array(["apple", "banana", "orange"])
336352
)
337353

338354

339355
def test_get_batched_jittered_initial_points():
340356
with pm.Model() as model:
341-
x = pm.MvNormal("x", mu=np.zeros(3), cov=np.eye(3), shape=(2, 3), initval=np.zeros((2, 3)))
357+
x = pm.MvNormal(
358+
"x", mu=np.zeros(3), cov=np.eye(3), shape=(2, 3), initval=np.zeros((2, 3))
359+
)
342360

343361
# No jitter
344362
ips = _get_batched_jittered_initial_points(
@@ -347,13 +365,17 @@ def test_get_batched_jittered_initial_points():
347365
assert np.all(ips[0] == 0)
348366

349367
# Single chain
350-
ips = _get_batched_jittered_initial_points(model=model, chains=1, random_seed=1, initvals=None)
368+
ips = _get_batched_jittered_initial_points(
369+
model=model, chains=1, random_seed=1, initvals=None
370+
)
351371

352372
assert ips[0].shape == (2, 3)
353373
assert np.all(ips[0] != 0)
354374

355375
# Multiple chains
356-
ips = _get_batched_jittered_initial_points(model=model, chains=2, random_seed=1, initvals=None)
376+
ips = _get_batched_jittered_initial_points(
377+
model=model, chains=2, random_seed=1, initvals=None
378+
)
357379

358380
assert ips[0].shape == (2, 2, 3)
359381
assert np.all(ips[0][0] != ips[0][1])
@@ -372,7 +394,10 @@ def test_get_batched_jittered_initial_points():
372394
[
373395
pytest.param(1),
374396
pytest.param(
375-
2, marks=pytest.mark.skipif(len(jax.devices()) < 2, reason="not enough devices")
397+
2,
398+
marks=pytest.mark.skipif(
399+
len(jax.devices()) < 2, reason="not enough devices"
400+
),
376401
),
377402
],
378403
)
@@ -396,8 +421,12 @@ def test_seeding(chains, random_seed, sampler):
396421
assert all_equal
397422

398423
if chains > 1:
399-
assert np.all(result1.posterior["x"].sel(chain=0) != result1.posterior["x"].sel(chain=1))
400-
assert np.all(result2.posterior["x"].sel(chain=0) != result2.posterior["x"].sel(chain=1))
424+
assert np.all(
425+
result1.posterior["x"].sel(chain=0) != result1.posterior["x"].sel(chain=1)
426+
)
427+
assert np.all(
428+
result2.posterior["x"].sel(chain=0) != result2.posterior["x"].sel(chain=1)
429+
)
401430

402431

403432
@mock.patch("numpyro.infer.MCMC")
@@ -503,3 +532,10 @@ def test_convergence_warnings(caplog, nuts_sampler):
503532

504533
[record] = caplog.records
505534
assert re.match(r"There were \d+ divergences after tuning", record.message)
535+
536+
537+
@pytest.mark.parametrize("method", ["advi", "fullrank_advi"])
538+
def test_vi_sampling_jax(method):
539+
with pm.Model() as model:
540+
x = pm.Normal("x")
541+
pm.fit(10, method=method, fn_kwargs=dict(mode="JAX"))

0 commit comments

Comments
 (0)