Skip to content

Commit 994da6c

Browse files
committed
update pytensor version, make xfail more elaborate
1 parent 28cde7d commit 994da6c

File tree

1 file changed

+27
-36
lines changed

1 file changed

+27
-36
lines changed

tests/sampling/test_jax.py

Lines changed: 27 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -86,16 +86,12 @@ def test_jax_PosDefMatrix():
8686
pytest.param(1),
8787
pytest.param(
8888
2,
89-
marks=pytest.mark.skipif(
90-
len(jax.devices()) < 2, reason="not enough devices"
91-
),
89+
marks=pytest.mark.skipif(len(jax.devices()) < 2, reason="not enough devices"),
9290
),
9391
],
9492
)
9593
@pytest.mark.parametrize("postprocessing_vectorize", ["scan", "vmap"])
96-
def test_transform_samples(
97-
sampler, postprocessing_backend, chains, postprocessing_vectorize
98-
):
94+
def test_transform_samples(sampler, postprocessing_backend, chains, postprocessing_vectorize):
9995
pytensor.config.on_opt_error = "raise"
10096
np.random.seed(13244)
10197

@@ -242,9 +238,7 @@ def test_replace_shared_variables():
242238
x = pytensor.shared(5, name="shared_x")
243239

244240
new_x = _replace_shared_variables([x])
245-
shared_variables = [
246-
var for var in graph_inputs(new_x) if isinstance(var, SharedVariable)
247-
]
241+
shared_variables = [var for var in graph_inputs(new_x) if isinstance(var, SharedVariable)]
248242
assert not shared_variables
249243

250244
x.default_update = x + 1
@@ -332,30 +326,23 @@ def test_idata_kwargs(
332326

333327
posterior = idata.get("posterior")
334328
assert posterior is not None
335-
x_dim_expected = idata_kwargs.get(
336-
"dims", model_test_idata_kwargs.named_vars_to_dims
337-
)["x"][0]
329+
x_dim_expected = idata_kwargs.get("dims", model_test_idata_kwargs.named_vars_to_dims)["x"][0]
338330
assert x_dim_expected is not None
339331
assert posterior["x"].dims[-1] == x_dim_expected
340332

341-
x_coords_expected = idata_kwargs.get("coords", model_test_idata_kwargs.coords)[
342-
x_dim_expected
343-
]
333+
x_coords_expected = idata_kwargs.get("coords", model_test_idata_kwargs.coords)[x_dim_expected]
344334
assert x_coords_expected is not None
345335
assert list(x_coords_expected) == list(posterior["x"].coords[x_dim_expected].values)
346336

347337
assert posterior["z"].dims[2] == "z_coord"
348338
assert np.all(
349-
posterior["z"].coords["z_coord"].values
350-
== np.array(["apple", "banana", "orange"])
339+
posterior["z"].coords["z_coord"].values == np.array(["apple", "banana", "orange"])
351340
)
352341

353342

354343
def test_get_batched_jittered_initial_points():
355344
with pm.Model() as model:
356-
x = pm.MvNormal(
357-
"x", mu=np.zeros(3), cov=np.eye(3), shape=(2, 3), initval=np.zeros((2, 3))
358-
)
345+
x = pm.MvNormal("x", mu=np.zeros(3), cov=np.eye(3), shape=(2, 3), initval=np.zeros((2, 3)))
359346

360347
# No jitter
361348
ips = _get_batched_jittered_initial_points(
@@ -364,17 +351,13 @@ def test_get_batched_jittered_initial_points():
364351
assert np.all(ips[0] == 0)
365352

366353
# Single chain
367-
ips = _get_batched_jittered_initial_points(
368-
model=model, chains=1, random_seed=1, initvals=None
369-
)
354+
ips = _get_batched_jittered_initial_points(model=model, chains=1, random_seed=1, initvals=None)
370355

371356
assert ips[0].shape == (2, 3)
372357
assert np.all(ips[0] != 0)
373358

374359
# Multiple chains
375-
ips = _get_batched_jittered_initial_points(
376-
model=model, chains=2, random_seed=1, initvals=None
377-
)
360+
ips = _get_batched_jittered_initial_points(model=model, chains=2, random_seed=1, initvals=None)
378361

379362
assert ips[0].shape == (2, 2, 3)
380363
assert np.all(ips[0][0] != ips[0][1])
@@ -394,9 +377,7 @@ def test_get_batched_jittered_initial_points():
394377
pytest.param(1),
395378
pytest.param(
396379
2,
397-
marks=pytest.mark.skipif(
398-
len(jax.devices()) < 2, reason="not enough devices"
399-
),
380+
marks=pytest.mark.skipif(len(jax.devices()) < 2, reason="not enough devices"),
400381
),
401382
],
402383
)
@@ -420,12 +401,8 @@ def test_seeding(chains, random_seed, sampler):
420401
assert all_equal
421402

422403
if chains > 1:
423-
assert np.all(
424-
result1.posterior["x"].sel(chain=0) != result1.posterior["x"].sel(chain=1)
425-
)
426-
assert np.all(
427-
result2.posterior["x"].sel(chain=0) != result2.posterior["x"].sel(chain=1)
428-
)
404+
assert np.all(result1.posterior["x"].sel(chain=0) != result1.posterior["x"].sel(chain=1))
405+
assert np.all(result2.posterior["x"].sel(chain=0) != result2.posterior["x"].sel(chain=1))
429406

430407

431408
@mock.patch("numpyro.infer.MCMC")
@@ -555,7 +532,21 @@ def test_vi_sampling_jax(method):
555532
pm.fit(10, method=method, fn_kwargs=dict(mode="JAX"))
556533

557534

558-
@pytest.mark.xfail(reason="Due to https://github.com/pymc-devs/pytensor/issues/595")
535+
@pytest.mark.xfail(
536+
reason="""
537+
During equilibrium rewriter this error happens. Probably one of the routines in SVGD is problematic.
538+
539+
TypeError: The broadcast pattern of the output of scan
540+
(Matrix(float64, shape=(?, 1))) is inconsistent with the one provided in `output_info`
541+
(Vector(float64, shape=(?,))). The output on axis 0 is `True`, but it is `False` on axis
542+
1 in `output_info`. This can happen if one of the dimension is fixed to 1 in the input,
543+
while it is still variable in the output, or vice-verca. You have to make them consistent,
544+
e.g. using pytensor.tensor.{unbroadcast, specify_broadcastable}.
545+
546+
Instead of fixing this error it makes sense to rework the internals of the variational to utilize
547+
pytensor vectorize instead of scan.
548+
"""
549+
)
559550
def test_vi_sampling_jax_svgd():
560551
with pm.Model():
561552
x = pm.Normal("x")

0 commit comments

Comments
 (0)