Skip to content

Commit 0defe25

Browse files
committed
add expected fail to remember about svgd
1 parent 5f60008 commit 0defe25

File tree

1 file changed

+7
-0
lines changed

1 file changed

+7
-0
lines changed

tests/sampling/test_jax.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -478,3 +478,10 @@ def test_vi_sampling_jax(method):
478478
with pm.Model() as model:
479479
x = pm.Normal("x")
480480
pm.fit(10, method=method, fn_kwargs=dict(mode="JAX"))
481+
482+
483+
@pytest.mark.xfail(reason="Due to https://github.com/pymc-devs/pytensor/issues/595")
484+
def test_vi_sampling_jax_svgd():
485+
with pm.Model():
486+
x = pm.Normal("x")
487+
pm.fit(10, method="svgd", fn_kwargs=dict(mode="JAX"))

0 commit comments

Comments
 (0)