Skip to content

Commit 652bbde

Browse files
committed
fix
1 parent 05a68be commit 652bbde

File tree

2 files changed

+5
-16
lines changed

2 files changed

+5
-16
lines changed

pymc_experimental/inference/smc/sampling.py

Lines changed: 4 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -24,11 +24,7 @@
2424
import jax
2525
import jax.numpy as jnp
2626
import numpy as np
27-
<<<<<<< Updated upstream
28-
=======
2927
from blackjax.smc import extend_params
30-
>>>>>>> Stashed changes
31-
3228
from blackjax.smc.resampling import systematic
3329
from pymc import draw, modelcontext, to_inference_data
3430
from pymc.backends import NDArray
@@ -133,14 +129,14 @@ def sample_smc_blackjax(
133129
mcmc_parameters = extend_params(dict(
134130
step_size=inner_kernel_params["step_size"],
135131
inverse_mass_matrix=jnp.eye(posterior_dimensions),
136-
num_integration_steps=inner_kernel_params["integration_steps"],
137-
))
132+
num_integration_steps=inner_kernel_params["integration_steps"])
133+
)
138134
elif kernel == "NUTS":
139135
mcmc_kernel = blackjax.mcmc.nuts
140136
mcmc_parameters = extend_params(dict(
141137
step_size=inner_kernel_params["step_size"],
142-
inverse_mass_matrix=jnp.eye(posterior_dimensions),
143-
))
138+
inverse_mass_matrix=jnp.eye(posterior_dimensions))
139+
)
144140
else:
145141
raise ValueError(f"Invalid kernel {kernel}, valid options are 'HMC' and 'NUTS'")
146142

tests/test_blackjax_smc.py

Lines changed: 1 addition & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -17,10 +17,6 @@
1717
import pytensor.tensor as pt
1818
import pytest
1919
import scipy
20-
<<<<<<< Updated upstream:tests/test_blackjax_smc.py
21-
=======
22-
from blackjax.smc import extend_params
23-
>>>>>>> Stashed changes:pymc_experimental/tests/test_blackjax_smc.py
2420

2521
from numpy import dtype
2622
from xarray.core.utils import Frozen
@@ -84,11 +80,8 @@ def fast_model():
8480
("NUTS", False, {"step_size": 0.1}),
8581
],
8682
)
87-
<<<<<<< Updated upstream:tests/test_blackjax_smc.py
88-
@pytest.mark.xfail(reason="Still need to investigate")
89-
=======
9083

91-
>>>>>>> Stashed changes:pymc_experimental/tests/test_blackjax_smc.py
84+
9285
def test_sample_smc_blackjax(kernel, check_for_integration_steps, inner_kernel_params):
9386
"""
9487
When running the two gaussians model

0 commit comments

Comments
 (0)