Skip to content

Commit 05a68be

Browse files
committed
Fix
1 parent af91b42 commit 05a68be

File tree

2 files changed

+16
-4
lines changed

2 files changed

+16
-4
lines changed

pymc_experimental/inference/smc/sampling.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,10 @@
2424
import jax
2525
import jax.numpy as jnp
2626
import numpy as np
27+
<<<<<<< Updated upstream
28+
=======
29+
from blackjax.smc import extend_params
30+
>>>>>>> Stashed changes
2731

2832
from blackjax.smc.resampling import systematic
2933
from pymc import draw, modelcontext, to_inference_data
@@ -126,17 +130,17 @@ def sample_smc_blackjax(
126130

127131
if kernel == "HMC":
128132
mcmc_kernel = blackjax.mcmc.hmc
129-
mcmc_parameters = dict(
133+
mcmc_parameters = extend_params(dict(
130134
step_size=inner_kernel_params["step_size"],
131135
inverse_mass_matrix=jnp.eye(posterior_dimensions),
132136
num_integration_steps=inner_kernel_params["integration_steps"],
133-
)
137+
))
134138
elif kernel == "NUTS":
135139
mcmc_kernel = blackjax.mcmc.nuts
136-
mcmc_parameters = dict(
140+
mcmc_parameters = extend_params(dict(
137141
step_size=inner_kernel_params["step_size"],
138142
inverse_mass_matrix=jnp.eye(posterior_dimensions),
139-
)
143+
))
140144
else:
141145
raise ValueError(f"Invalid kernel {kernel}, valid options are 'HMC' and 'NUTS'")
142146

tests/test_blackjax_smc.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,10 @@
1717
import pytensor.tensor as pt
1818
import pytest
1919
import scipy
20+
<<<<<<< Updated upstream:tests/test_blackjax_smc.py
21+
=======
22+
from blackjax.smc import extend_params
23+
>>>>>>> Stashed changes:pymc_experimental/tests/test_blackjax_smc.py
2024

2125
from numpy import dtype
2226
from xarray.core.utils import Frozen
@@ -80,7 +84,11 @@ def fast_model():
8084
("NUTS", False, {"step_size": 0.1}),
8185
],
8286
)
87+
<<<<<<< Updated upstream:tests/test_blackjax_smc.py
8388
@pytest.mark.xfail(reason="Still need to investigate")
89+
=======
90+
91+
>>>>>>> Stashed changes:pymc_experimental/tests/test_blackjax_smc.py
8492
def test_sample_smc_blackjax(kernel, check_for_integration_steps, inner_kernel_params):
8593
"""
8694
When running the two gaussians model

0 commit comments

Comments
 (0)