File tree Expand file tree Collapse file tree 2 files changed +16
-4
lines changed
pymc_experimental/inference/smc Expand file tree Collapse file tree 2 files changed +16
-4
lines changed Original file line number Diff line number Diff line change 24
24
import jax
25
25
import jax .numpy as jnp
26
26
import numpy as np
27
+ < << << << Updated upstream
28
+ == == == =
29
+ from blackjax .smc import extend_params
30
+ > >> >> >> Stashed changes
27
31
28
32
from blackjax .smc .resampling import systematic
29
33
from pymc import draw , modelcontext , to_inference_data
@@ -126,17 +130,17 @@ def sample_smc_blackjax(
126
130
127
131
if kernel == "HMC" :
128
132
mcmc_kernel = blackjax .mcmc .hmc
129
- mcmc_parameters = dict (
133
+ mcmc_parameters = extend_params ( dict (
130
134
step_size = inner_kernel_params ["step_size" ],
131
135
inverse_mass_matrix = jnp .eye (posterior_dimensions ),
132
136
num_integration_steps = inner_kernel_params ["integration_steps" ],
133
- )
137
+ ))
134
138
elif kernel == "NUTS" :
135
139
mcmc_kernel = blackjax .mcmc .nuts
136
- mcmc_parameters = dict (
140
+ mcmc_parameters = extend_params ( dict (
137
141
step_size = inner_kernel_params ["step_size" ],
138
142
inverse_mass_matrix = jnp .eye (posterior_dimensions ),
139
- )
143
+ ))
140
144
else :
141
145
raise ValueError (f"Invalid kernel { kernel } , valid options are 'HMC' and 'NUTS'" )
142
146
Original file line number Diff line number Diff line change 17
17
import pytensor .tensor as pt
18
18
import pytest
19
19
import scipy
20
+ < << << << Updated upstream :tests / test_blackjax_smc .py
21
+ == == == =
22
+ from blackjax .smc import extend_params
23
+ > >> >> >> Stashed changes :pymc_experimental / tests / test_blackjax_smc .py
20
24
21
25
from numpy import dtype
22
26
from xarray .core .utils import Frozen
@@ -80,7 +84,11 @@ def fast_model():
80
84
("NUTS" , False , {"step_size" : 0.1 }),
81
85
],
82
86
)
87
+ < << << << Updated upstream :tests / test_blackjax_smc .py
83
88
@pytest .mark .xfail (reason = "Still need to investigate" )
89
+ == == == =
90
+
91
+ >> >> >> > Stashed changes :pymc_experimental / tests / test_blackjax_smc .py
84
92
def test_sample_smc_blackjax (kernel , check_for_integration_steps , inner_kernel_params ):
85
93
"""
86
94
When running the two gaussians model
You can’t perform that action at this time.
0 commit comments