From 05a68be7ea0d6aaabf9d0da0a1a8281f83fa933b Mon Sep 17 00:00:00 2001 From: Carlos Iguaran Date: Thu, 15 Aug 2024 17:12:13 -0300 Subject: [PATCH 1/4] Fix --- pymc_experimental/inference/smc/sampling.py | 12 ++++++++---- tests/test_blackjax_smc.py | 8 ++++++++ 2 files changed, 16 insertions(+), 4 deletions(-) diff --git a/pymc_experimental/inference/smc/sampling.py b/pymc_experimental/inference/smc/sampling.py index 99488c85c..0afaab7e2 100644 --- a/pymc_experimental/inference/smc/sampling.py +++ b/pymc_experimental/inference/smc/sampling.py @@ -24,6 +24,10 @@ import jax import jax.numpy as jnp import numpy as np +<<<<<<< Updated upstream +======= +from blackjax.smc import extend_params +>>>>>>> Stashed changes from blackjax.smc.resampling import systematic from pymc import draw, modelcontext, to_inference_data @@ -126,17 +130,17 @@ def sample_smc_blackjax( if kernel == "HMC": mcmc_kernel = blackjax.mcmc.hmc - mcmc_parameters = dict( + mcmc_parameters = extend_params(dict( step_size=inner_kernel_params["step_size"], inverse_mass_matrix=jnp.eye(posterior_dimensions), num_integration_steps=inner_kernel_params["integration_steps"], - ) + )) elif kernel == "NUTS": mcmc_kernel = blackjax.mcmc.nuts - mcmc_parameters = dict( + mcmc_parameters = extend_params(dict( step_size=inner_kernel_params["step_size"], inverse_mass_matrix=jnp.eye(posterior_dimensions), - ) + )) else: raise ValueError(f"Invalid kernel {kernel}, valid options are 'HMC' and 'NUTS'") diff --git a/tests/test_blackjax_smc.py b/tests/test_blackjax_smc.py index 49db7de7f..e49bf3438 100644 --- a/tests/test_blackjax_smc.py +++ b/tests/test_blackjax_smc.py @@ -17,6 +17,10 @@ import pytensor.tensor as pt import pytest import scipy +<<<<<<< Updated upstream:tests/test_blackjax_smc.py +======= +from blackjax.smc import extend_params +>>>>>>> Stashed changes:pymc_experimental/tests/test_blackjax_smc.py from numpy import dtype from xarray.core.utils import Frozen @@ -80,7 +84,11 @@ def fast_model(): ("NUTS", False, {"step_size": 0.1}), ], ) +<<<<<<< Updated upstream:tests/test_blackjax_smc.py @pytest.mark.xfail(reason="Still need to investigate") +======= + +>>>>>>> Stashed changes:pymc_experimental/tests/test_blackjax_smc.py def test_sample_smc_blackjax(kernel, check_for_integration_steps, inner_kernel_params): """ When running the two gaussians model From 652bbde3788710edba1abb3a187edc30b2adb19f Mon Sep 17 00:00:00 2001 From: Carlos Iguaran Date: Thu, 15 Aug 2024 17:14:52 -0300 Subject: [PATCH 2/4] fix --- pymc_experimental/inference/smc/sampling.py | 12 ++++-------- tests/test_blackjax_smc.py | 9 +-------- 2 files changed, 5 insertions(+), 16 deletions(-) diff --git a/pymc_experimental/inference/smc/sampling.py b/pymc_experimental/inference/smc/sampling.py index 0afaab7e2..9c011841c 100644 --- a/pymc_experimental/inference/smc/sampling.py +++ b/pymc_experimental/inference/smc/sampling.py @@ -24,11 +24,7 @@ import jax import jax.numpy as jnp import numpy as np -<<<<<<< Updated upstream -======= from blackjax.smc import extend_params ->>>>>>> Stashed changes - from blackjax.smc.resampling import systematic from pymc import draw, modelcontext, to_inference_data from pymc.backends import NDArray @@ -133,14 +129,14 @@ def sample_smc_blackjax( mcmc_parameters = extend_params(dict( step_size=inner_kernel_params["step_size"], inverse_mass_matrix=jnp.eye(posterior_dimensions), - num_integration_steps=inner_kernel_params["integration_steps"], - )) + num_integration_steps=inner_kernel_params["integration_steps"]) + ) elif kernel == "NUTS": mcmc_kernel = blackjax.mcmc.nuts mcmc_parameters = extend_params(dict( step_size=inner_kernel_params["step_size"], - inverse_mass_matrix=jnp.eye(posterior_dimensions), - )) + inverse_mass_matrix=jnp.eye(posterior_dimensions)) + ) else: raise ValueError(f"Invalid kernel {kernel}, valid options are 'HMC' and 'NUTS'") diff --git a/tests/test_blackjax_smc.py b/tests/test_blackjax_smc.py index e49bf3438..4ad769534 100644 --- a/tests/test_blackjax_smc.py +++ b/tests/test_blackjax_smc.py @@ -17,10 +17,6 @@ import pytensor.tensor as pt import pytest import scipy -<<<<<<< Updated upstream:tests/test_blackjax_smc.py -======= -from blackjax.smc import extend_params ->>>>>>> Stashed changes:pymc_experimental/tests/test_blackjax_smc.py from numpy import dtype from xarray.core.utils import Frozen @@ -84,11 +80,8 @@ def fast_model(): ("NUTS", False, {"step_size": 0.1}), ], ) -<<<<<<< Updated upstream:tests/test_blackjax_smc.py -@pytest.mark.xfail(reason="Still need to investigate") -======= ->>>>>>> Stashed changes:pymc_experimental/tests/test_blackjax_smc.py + def test_sample_smc_blackjax(kernel, check_for_integration_steps, inner_kernel_params): """ When running the two gaussians model From 3b65bb14e11fb6e8b572ea66f5805de7d38e2110 Mon Sep 17 00:00:00 2001 From: Carlos Iguaran Date: Thu, 15 Aug 2024 17:15:55 -0300 Subject: [PATCH 3/4] format --- tests/test_blackjax_smc.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/tests/test_blackjax_smc.py b/tests/test_blackjax_smc.py index 4ad769534..e0187bd6f 100644 --- a/tests/test_blackjax_smc.py +++ b/tests/test_blackjax_smc.py @@ -80,8 +80,6 @@ def fast_model(): ("NUTS", False, {"step_size": 0.1}), ], ) - - def test_sample_smc_blackjax(kernel, check_for_integration_steps, inner_kernel_params): """ When running the two gaussians model From b9f6b3c0c0b571346107f1315db104557a8ae185 Mon Sep 17 00:00:00 2001 From: Carlos Iguaran Date: Fri, 16 Aug 2024 09:02:46 -0300 Subject: [PATCH 4/4] bump blackjax in test environment --- conda-envs/environment-test.yml | 2 +- pymc_experimental/inference/smc/sampling.py | 19 ++++++++++++------- 2 files changed, 13 insertions(+), 8 deletions(-) diff --git a/conda-envs/environment-test.yml b/conda-envs/environment-test.yml index 360a81991..6abe87b58 100644 --- a/conda-envs/environment-test.yml +++ b/conda-envs/environment-test.yml @@ -11,5 +11,5 @@ dependencies: - statsmodels - pip: - pymc>=5.16.1 # CI was failing to resolve - - blackjax + - blackjax>=1.2.3 - scikit-learn diff --git a/pymc_experimental/inference/smc/sampling.py b/pymc_experimental/inference/smc/sampling.py index 9c011841c..898db598b 100644 --- a/pymc_experimental/inference/smc/sampling.py +++ b/pymc_experimental/inference/smc/sampling.py @@ -24,6 +24,7 @@ import jax import jax.numpy as jnp import numpy as np + from blackjax.smc import extend_params from blackjax.smc.resampling import systematic from pymc import draw, modelcontext, to_inference_data @@ -126,16 +127,20 @@ def sample_smc_blackjax( if kernel == "HMC": mcmc_kernel = blackjax.mcmc.hmc - mcmc_parameters = extend_params(dict( - step_size=inner_kernel_params["step_size"], - inverse_mass_matrix=jnp.eye(posterior_dimensions), - num_integration_steps=inner_kernel_params["integration_steps"]) + mcmc_parameters = extend_params( + dict( + step_size=inner_kernel_params["step_size"], + inverse_mass_matrix=jnp.eye(posterior_dimensions), + num_integration_steps=inner_kernel_params["integration_steps"], + ) ) elif kernel == "NUTS": mcmc_kernel = blackjax.mcmc.nuts - mcmc_parameters = extend_params(dict( - step_size=inner_kernel_params["step_size"], - inverse_mass_matrix=jnp.eye(posterior_dimensions)) + mcmc_parameters = extend_params( + dict( + step_size=inner_kernel_params["step_size"], + inverse_mass_matrix=jnp.eye(posterior_dimensions), + ) ) else: raise ValueError(f"Invalid kernel {kernel}, valid options are 'HMC' and 'NUTS'")