Skip to content

Commit cf1aded

Browse files
Use optimized graph for JAX conversion
1 parent 2c07286 commit cf1aded

File tree

1 file changed

+1
-1
lines changed

1 file changed

+1
-1
lines changed

pymc3/sampling_jax.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ def sample_tfp_nuts(
4747

4848
seed = jax.random.PRNGKey(random_seed)
4949

50-
fgraph = aesara.graph.fg.FunctionGraph(model.free_RVs, [model.logpt])
50+
fgraph = model.logp.f.maker.fgraph
5151
fns = jax_funcify(fgraph)
5252
logp_fn_jax = fns[0]
5353

0 commit comments

Comments
 (0)