@@ -185,7 +185,7 @@ def test_get_log_likelihood():
185
185
b_true = trace .log_likelihood .b .values
186
186
a = np .array (trace .posterior .a )
187
187
sigma_log_ = np .log (np .array (trace .posterior .sigma ))
188
- b_jax = _get_log_likelihood (model , [ a , sigma_log_ ] )["b" ]
188
+ b_jax = _get_log_likelihood (model , { "a" : a , "sigma_log__" : sigma_log_ } )["b" ]
189
189
190
190
assert np .allclose (b_jax .reshape (- 1 ), b_true .reshape (- 1 ))
191
191
@@ -215,7 +215,7 @@ def test_get_jaxified_logp():
215
215
216
216
jax_fn = get_jaxified_logp (m )
217
217
# This would underflow if not optimized
218
- assert not np .isinf (jax_fn (( np .array (5000.0 ), np .array (5000.0 ))))
218
+ assert not np .isinf (jax_fn (dict ( x = np .array (5000.0 ), y = np .array (5000.0 ))))
219
219
220
220
221
221
@pytest .fixture (scope = "module" )
@@ -302,19 +302,19 @@ def test_get_batched_jittered_initial_points():
302
302
ips = _get_batched_jittered_initial_points (
303
303
model = model , chains = 1 , random_seed = 1 , initvals = None , jitter = False
304
304
)
305
- assert np .all (ips [0 ] == 0 )
305
+ assert np .all (ips ["x" ] == 0 )
306
306
307
307
# Single chain
308
308
ips = _get_batched_jittered_initial_points (model = model , chains = 1 , random_seed = 1 , initvals = None )
309
309
310
- assert ips [0 ].shape == (2 , 3 )
311
- assert np .all (ips [0 ] != 0 )
310
+ assert ips ["x" ].shape == (2 , 3 )
311
+ assert np .all (ips ["x" ] != 0 )
312
312
313
313
# Multiple chains
314
314
ips = _get_batched_jittered_initial_points (model = model , chains = 2 , random_seed = 1 , initvals = None )
315
315
316
- assert ips [0 ].shape == (2 , 2 , 3 )
317
- assert np .all (ips [0 ][0 ] != ips [0 ][1 ])
316
+ assert ips ["x" ].shape == (2 , 2 , 3 )
317
+ assert np .all (ips ["x" ][0 ] != ips ["x" ][1 ])
318
318
319
319
320
320
@pytest .mark .parametrize (
0 commit comments