File tree Expand file tree Collapse file tree 2 files changed +5
-16
lines changed
pymc_experimental/inference/smc Expand file tree Collapse file tree 2 files changed +5
-16
lines changed Original file line number Diff line number Diff line change 24
24
import jax
25
25
import jax .numpy as jnp
26
26
import numpy as np
27
- < << << << Updated upstream
28
- == == == =
29
27
from blackjax .smc import extend_params
30
- > >> >> >> Stashed changes
31
-
32
28
from blackjax .smc .resampling import systematic
33
29
from pymc import draw , modelcontext , to_inference_data
34
30
from pymc .backends import NDArray
@@ -133,14 +129,14 @@ def sample_smc_blackjax(
133
129
mcmc_parameters = extend_params (dict (
134
130
step_size = inner_kernel_params ["step_size" ],
135
131
inverse_mass_matrix = jnp .eye (posterior_dimensions ),
136
- num_integration_steps = inner_kernel_params ["integration_steps" ],
137
- ))
132
+ num_integration_steps = inner_kernel_params ["integration_steps" ])
133
+ )
138
134
elif kernel == "NUTS" :
139
135
mcmc_kernel = blackjax .mcmc .nuts
140
136
mcmc_parameters = extend_params (dict (
141
137
step_size = inner_kernel_params ["step_size" ],
142
- inverse_mass_matrix = jnp .eye (posterior_dimensions ),
143
- ))
138
+ inverse_mass_matrix = jnp .eye (posterior_dimensions ))
139
+ )
144
140
else :
145
141
raise ValueError (f"Invalid kernel { kernel } , valid options are 'HMC' and 'NUTS'" )
146
142
Original file line number Diff line number Diff line change 17
17
import pytensor .tensor as pt
18
18
import pytest
19
19
import scipy
20
- < << << << Updated upstream :tests / test_blackjax_smc .py
21
- == == == =
22
- from blackjax .smc import extend_params
23
- > >> >> >> Stashed changes :pymc_experimental / tests / test_blackjax_smc .py
24
20
25
21
from numpy import dtype
26
22
from xarray .core .utils import Frozen
@@ -84,11 +80,8 @@ def fast_model():
84
80
("NUTS" , False , {"step_size" : 0.1 }),
85
81
],
86
82
)
87
- < << << << Updated upstream :tests / test_blackjax_smc .py
88
- @pytest .mark .xfail (reason = "Still need to investigate" )
89
- == == == =
90
83
91
- >> >> >> > Stashed changes : pymc_experimental / tests / test_blackjax_smc . py
84
+
92
85
def test_sample_smc_blackjax (kernel , check_for_integration_steps , inner_kernel_params ):
93
86
"""
94
87
When running the two gaussians model
You can’t perform that action at this time.
0 commit comments