@@ -86,16 +86,12 @@ def test_jax_PosDefMatrix():
86
86
pytest .param (1 ),
87
87
pytest .param (
88
88
2 ,
89
- marks = pytest .mark .skipif (
90
- len (jax .devices ()) < 2 , reason = "not enough devices"
91
- ),
89
+ marks = pytest .mark .skipif (len (jax .devices ()) < 2 , reason = "not enough devices" ),
92
90
),
93
91
],
94
92
)
95
93
@pytest .mark .parametrize ("postprocessing_vectorize" , ["scan" , "vmap" ])
96
- def test_transform_samples (
97
- sampler , postprocessing_backend , chains , postprocessing_vectorize
98
- ):
94
+ def test_transform_samples (sampler , postprocessing_backend , chains , postprocessing_vectorize ):
99
95
pytensor .config .on_opt_error = "raise"
100
96
np .random .seed (13244 )
101
97
@@ -242,9 +238,7 @@ def test_replace_shared_variables():
242
238
x = pytensor .shared (5 , name = "shared_x" )
243
239
244
240
new_x = _replace_shared_variables ([x ])
245
- shared_variables = [
246
- var for var in graph_inputs (new_x ) if isinstance (var , SharedVariable )
247
- ]
241
+ shared_variables = [var for var in graph_inputs (new_x ) if isinstance (var , SharedVariable )]
248
242
assert not shared_variables
249
243
250
244
x .default_update = x + 1
@@ -332,30 +326,23 @@ def test_idata_kwargs(
332
326
333
327
posterior = idata .get ("posterior" )
334
328
assert posterior is not None
335
- x_dim_expected = idata_kwargs .get (
336
- "dims" , model_test_idata_kwargs .named_vars_to_dims
337
- )["x" ][0 ]
329
+ x_dim_expected = idata_kwargs .get ("dims" , model_test_idata_kwargs .named_vars_to_dims )["x" ][0 ]
338
330
assert x_dim_expected is not None
339
331
assert posterior ["x" ].dims [- 1 ] == x_dim_expected
340
332
341
- x_coords_expected = idata_kwargs .get ("coords" , model_test_idata_kwargs .coords )[
342
- x_dim_expected
343
- ]
333
+ x_coords_expected = idata_kwargs .get ("coords" , model_test_idata_kwargs .coords )[x_dim_expected ]
344
334
assert x_coords_expected is not None
345
335
assert list (x_coords_expected ) == list (posterior ["x" ].coords [x_dim_expected ].values )
346
336
347
337
assert posterior ["z" ].dims [2 ] == "z_coord"
348
338
assert np .all (
349
- posterior ["z" ].coords ["z_coord" ].values
350
- == np .array (["apple" , "banana" , "orange" ])
339
+ posterior ["z" ].coords ["z_coord" ].values == np .array (["apple" , "banana" , "orange" ])
351
340
)
352
341
353
342
354
343
def test_get_batched_jittered_initial_points ():
355
344
with pm .Model () as model :
356
- x = pm .MvNormal (
357
- "x" , mu = np .zeros (3 ), cov = np .eye (3 ), shape = (2 , 3 ), initval = np .zeros ((2 , 3 ))
358
- )
345
+ x = pm .MvNormal ("x" , mu = np .zeros (3 ), cov = np .eye (3 ), shape = (2 , 3 ), initval = np .zeros ((2 , 3 )))
359
346
360
347
# No jitter
361
348
ips = _get_batched_jittered_initial_points (
@@ -364,17 +351,13 @@ def test_get_batched_jittered_initial_points():
364
351
assert np .all (ips [0 ] == 0 )
365
352
366
353
# Single chain
367
- ips = _get_batched_jittered_initial_points (
368
- model = model , chains = 1 , random_seed = 1 , initvals = None
369
- )
354
+ ips = _get_batched_jittered_initial_points (model = model , chains = 1 , random_seed = 1 , initvals = None )
370
355
371
356
assert ips [0 ].shape == (2 , 3 )
372
357
assert np .all (ips [0 ] != 0 )
373
358
374
359
# Multiple chains
375
- ips = _get_batched_jittered_initial_points (
376
- model = model , chains = 2 , random_seed = 1 , initvals = None
377
- )
360
+ ips = _get_batched_jittered_initial_points (model = model , chains = 2 , random_seed = 1 , initvals = None )
378
361
379
362
assert ips [0 ].shape == (2 , 2 , 3 )
380
363
assert np .all (ips [0 ][0 ] != ips [0 ][1 ])
@@ -394,9 +377,7 @@ def test_get_batched_jittered_initial_points():
394
377
pytest .param (1 ),
395
378
pytest .param (
396
379
2 ,
397
- marks = pytest .mark .skipif (
398
- len (jax .devices ()) < 2 , reason = "not enough devices"
399
- ),
380
+ marks = pytest .mark .skipif (len (jax .devices ()) < 2 , reason = "not enough devices" ),
400
381
),
401
382
],
402
383
)
@@ -420,12 +401,8 @@ def test_seeding(chains, random_seed, sampler):
420
401
assert all_equal
421
402
422
403
if chains > 1 :
423
- assert np .all (
424
- result1 .posterior ["x" ].sel (chain = 0 ) != result1 .posterior ["x" ].sel (chain = 1 )
425
- )
426
- assert np .all (
427
- result2 .posterior ["x" ].sel (chain = 0 ) != result2 .posterior ["x" ].sel (chain = 1 )
428
- )
404
+ assert np .all (result1 .posterior ["x" ].sel (chain = 0 ) != result1 .posterior ["x" ].sel (chain = 1 ))
405
+ assert np .all (result2 .posterior ["x" ].sel (chain = 0 ) != result2 .posterior ["x" ].sel (chain = 1 ))
429
406
430
407
431
408
@mock .patch ("numpyro.infer.MCMC" )
@@ -555,7 +532,21 @@ def test_vi_sampling_jax(method):
555
532
pm .fit (10 , method = method , fn_kwargs = dict (mode = "JAX" ))
556
533
557
534
558
- @pytest .mark .xfail (reason = "Due to https://github.com/pymc-devs/pytensor/issues/595" )
535
+ @pytest .mark .xfail (
536
+ reason = """
537
+ During equilibrium rewriter this error happens. Probably one of the routines in SVGD is problematic.
538
+
539
+ TypeError: The broadcast pattern of the output of scan
540
+ (Matrix(float64, shape=(?, 1))) is inconsistent with the one provided in `output_info`
541
+ (Vector(float64, shape=(?,))). The output on axis 0 is `True`, but it is `False` on axis
542
+ 1 in `output_info`. This can happen if one of the dimension is fixed to 1 in the input,
543
+ while it is still variable in the output, or vice-verca. You have to make them consistent,
544
+ e.g. using pytensor.tensor.{unbroadcast, specify_broadcastable}.
545
+
546
+ Instead of fixing this error it makes sense to rework the internals of the variational to utilize
547
+ pytensor vectorize instead of scan.
548
+ """
549
+ )
559
550
def test_vi_sampling_jax_svgd ():
560
551
with pm .Model ():
561
552
x = pm .Normal ("x" )
0 commit comments