Skip to content

Commit 7378c72

Browse files
committed
add expected fail to remember about svgd
1 parent 96dec49 commit 7378c72

File tree

1 file changed

+7
-0
lines changed

1 file changed

+7
-0
lines changed

tests/sampling/test_jax.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -539,3 +539,10 @@ def test_vi_sampling_jax(method):
539539
with pm.Model() as model:
540540
x = pm.Normal("x")
541541
pm.fit(10, method=method, fn_kwargs=dict(mode="JAX"))
542+
543+
544+
@pytest.mark.xfail(reason="Due to https://github.com/pymc-devs/pytensor/issues/595")
545+
def test_vi_sampling_jax_svgd():
546+
with pm.Model():
547+
x = pm.Normal("x")
548+
pm.fit(10, method="svgd", fn_kwargs=dict(mode="JAX"))

0 commit comments

Comments
 (0)