@@ -84,12 +84,17 @@ def test_jax_PosDefMatrix():
84
84
[
85
85
pytest .param (1 ),
86
86
pytest .param (
87
- 2 , marks = pytest .mark .skipif (len (jax .devices ()) < 2 , reason = "not enough devices" )
87
+ 2 ,
88
+ marks = pytest .mark .skipif (
89
+ len (jax .devices ()) < 2 , reason = "not enough devices"
90
+ ),
88
91
),
89
92
],
90
93
)
91
94
@pytest .mark .parametrize ("postprocessing_vectorize" , ["scan" , "vmap" ])
92
- def test_transform_samples (sampler , postprocessing_backend , chains , postprocessing_vectorize ):
95
+ def test_transform_samples (
96
+ sampler , postprocessing_backend , chains , postprocessing_vectorize
97
+ ):
93
98
pytensor .config .on_opt_error = "raise"
94
99
np .random .seed (13244 )
95
100
@@ -236,7 +241,9 @@ def test_replace_shared_variables():
236
241
x = pytensor .shared (5 , name = "shared_x" )
237
242
238
243
new_x = _replace_shared_variables ([x ])
239
- shared_variables = [var for var in graph_inputs (new_x ) if isinstance (var , SharedVariable )]
244
+ shared_variables = [
245
+ var for var in graph_inputs (new_x ) if isinstance (var , SharedVariable )
246
+ ]
240
247
assert not shared_variables
241
248
242
249
x .default_update = x + 1
@@ -263,7 +270,11 @@ def test_get_jaxified_logp():
263
270
@pytest .fixture (scope = "module" )
264
271
def model_test_idata_kwargs () -> pm .Model :
265
272
with pm .Model (
266
- coords = {"x_coord" : ["a" , "b" ], "x_coord2" : [1 , 2 ], "z_coord" : ["apple" , "banana" , "orange" ]}
273
+ coords = {
274
+ "x_coord" : ["a" , "b" ],
275
+ "x_coord2" : [1 , 2 ],
276
+ "z_coord" : ["apple" , "banana" , "orange" ],
277
+ }
267
278
) as m :
268
279
x = pm .Normal ("x" , shape = (2 ,), dims = ["x_coord" ])
269
280
_ = pm .Normal ("y" , x , observed = [0 , 0 ])
@@ -322,23 +333,30 @@ def test_idata_kwargs(
322
333
323
334
posterior = idata .get ("posterior" )
324
335
assert posterior is not None
325
- x_dim_expected = idata_kwargs .get ("dims" , model_test_idata_kwargs .named_vars_to_dims )["x" ][0 ]
336
+ x_dim_expected = idata_kwargs .get (
337
+ "dims" , model_test_idata_kwargs .named_vars_to_dims
338
+ )["x" ][0 ]
326
339
assert x_dim_expected is not None
327
340
assert posterior ["x" ].dims [- 1 ] == x_dim_expected
328
341
329
- x_coords_expected = idata_kwargs .get ("coords" , model_test_idata_kwargs .coords )[x_dim_expected ]
342
+ x_coords_expected = idata_kwargs .get ("coords" , model_test_idata_kwargs .coords )[
343
+ x_dim_expected
344
+ ]
330
345
assert x_coords_expected is not None
331
346
assert list (x_coords_expected ) == list (posterior ["x" ].coords [x_dim_expected ].values )
332
347
333
348
assert posterior ["z" ].dims [2 ] == "z_coord"
334
349
assert np .all (
335
- posterior ["z" ].coords ["z_coord" ].values == np .array (["apple" , "banana" , "orange" ])
350
+ posterior ["z" ].coords ["z_coord" ].values
351
+ == np .array (["apple" , "banana" , "orange" ])
336
352
)
337
353
338
354
339
355
def test_get_batched_jittered_initial_points ():
340
356
with pm .Model () as model :
341
- x = pm .MvNormal ("x" , mu = np .zeros (3 ), cov = np .eye (3 ), shape = (2 , 3 ), initval = np .zeros ((2 , 3 )))
357
+ x = pm .MvNormal (
358
+ "x" , mu = np .zeros (3 ), cov = np .eye (3 ), shape = (2 , 3 ), initval = np .zeros ((2 , 3 ))
359
+ )
342
360
343
361
# No jitter
344
362
ips = _get_batched_jittered_initial_points (
@@ -347,13 +365,17 @@ def test_get_batched_jittered_initial_points():
347
365
assert np .all (ips [0 ] == 0 )
348
366
349
367
# Single chain
350
- ips = _get_batched_jittered_initial_points (model = model , chains = 1 , random_seed = 1 , initvals = None )
368
+ ips = _get_batched_jittered_initial_points (
369
+ model = model , chains = 1 , random_seed = 1 , initvals = None
370
+ )
351
371
352
372
assert ips [0 ].shape == (2 , 3 )
353
373
assert np .all (ips [0 ] != 0 )
354
374
355
375
# Multiple chains
356
- ips = _get_batched_jittered_initial_points (model = model , chains = 2 , random_seed = 1 , initvals = None )
376
+ ips = _get_batched_jittered_initial_points (
377
+ model = model , chains = 2 , random_seed = 1 , initvals = None
378
+ )
357
379
358
380
assert ips [0 ].shape == (2 , 2 , 3 )
359
381
assert np .all (ips [0 ][0 ] != ips [0 ][1 ])
@@ -372,7 +394,10 @@ def test_get_batched_jittered_initial_points():
372
394
[
373
395
pytest .param (1 ),
374
396
pytest .param (
375
- 2 , marks = pytest .mark .skipif (len (jax .devices ()) < 2 , reason = "not enough devices" )
397
+ 2 ,
398
+ marks = pytest .mark .skipif (
399
+ len (jax .devices ()) < 2 , reason = "not enough devices"
400
+ ),
376
401
),
377
402
],
378
403
)
@@ -396,8 +421,12 @@ def test_seeding(chains, random_seed, sampler):
396
421
assert all_equal
397
422
398
423
if chains > 1 :
399
- assert np .all (result1 .posterior ["x" ].sel (chain = 0 ) != result1 .posterior ["x" ].sel (chain = 1 ))
400
- assert np .all (result2 .posterior ["x" ].sel (chain = 0 ) != result2 .posterior ["x" ].sel (chain = 1 ))
424
+ assert np .all (
425
+ result1 .posterior ["x" ].sel (chain = 0 ) != result1 .posterior ["x" ].sel (chain = 1 )
426
+ )
427
+ assert np .all (
428
+ result2 .posterior ["x" ].sel (chain = 0 ) != result2 .posterior ["x" ].sel (chain = 1 )
429
+ )
401
430
402
431
403
432
@mock .patch ("numpyro.infer.MCMC" )
@@ -503,3 +532,10 @@ def test_convergence_warnings(caplog, nuts_sampler):
503
532
504
533
[record ] = caplog .records
505
534
assert re .match (r"There were \d+ divergences after tuning" , record .message )
535
+
536
+
537
+ @pytest .mark .parametrize ("method" , ["advi" , "fullrank_advi" ])
538
+ def test_vi_sampling_jax (method ):
539
+ with pm .Model () as model :
540
+ x = pm .Normal ("x" )
541
+ pm .fit (10 , method = method , fn_kwargs = dict (mode = "JAX" ))
0 commit comments