From fa49b43f92f91ad4b22ca5ccca65c4f83ef4dca3 Mon Sep 17 00:00:00 2001
From: Jesse Grabowski
Date: Fri, 2 May 2025 20:26:15 +0800
Subject: [PATCH 1/4] Fix bug in `fit_MAP` when shared variables are used in
 graph

---
 pymc_extras/inference/find_map.py |  9 ++++++++-
 tests/test_find_map.py            | 23 +++++++++++++++++++++++
 2 files changed, 31 insertions(+), 1 deletion(-)

diff --git a/pymc_extras/inference/find_map.py b/pymc_extras/inference/find_map.py
index 063f6ce97..f2f0f45ad 100644
--- a/pymc_extras/inference/find_map.py
+++ b/pymc_extras/inference/find_map.py
@@ -15,6 +15,7 @@ from pymc.initial_point import make_initial_point_fn
 
 from pymc.model.transform.optimization import freeze_dims_and_data
 from pymc.pytensorf import join_nonshared_inputs
+from pymc.sampling.jax import _replace_shared_variables
 from pymc.util import get_default_varnames
 from pytensor.compile import Function
 from pytensor.compile.mode import Mode
@@ -146,7 +147,7 @@ def _compile_grad_and_hess_to_jax(
     orig_loss_fn = f_loss.vm.jit_fn
 
     @jax.jit
-    def loss_fn_jax_grad(x, *shared):
+    def loss_fn_jax_grad(x):
         return jax.value_and_grad(lambda x: orig_loss_fn(x)[0])(x)
 
     f_loss_and_grad = loss_fn_jax_grad
@@ -301,6 +302,12 @@ def scipy_optimize_funcs_from_loss(
         point=initial_point_dict, outputs=[loss], inputs=inputs
     )
 
+    # If we use pytensor gradients, we will use the pytensor function wrapper that handles shared variables. When
+    # computing jax gradients, we discard the function wrapper, so we can't handle shared variables --> rewrite them
+    # away.
+    if use_jax_gradients:
+        [loss] = _replace_shared_variables([loss])
+
     compute_grad = use_grad and not use_jax_gradients
     compute_hess = use_hess and not use_jax_gradients
     compute_hessp = use_hessp and not use_jax_gradients
diff --git a/tests/test_find_map.py b/tests/test_find_map.py
index 34c8fc766..adb081eea 100644
--- a/tests/test_find_map.py
+++ b/tests/test_find_map.py
@@ -1,5 +1,6 @@
 import numpy as np
 import pymc as pm
+import pytensor
 import pytensor.tensor as pt
 import pytest
 
@@ -101,3 +102,25 @@ def test_JAX_map(method, use_grad, use_hess, use_hessp, gradient_backend: Gradie
 
     assert np.isclose(mu_hat, 3, atol=0.5)
     assert np.isclose(np.exp(log_sigma_hat), 1.5, atol=0.5)
+
+
+def test_JAX_map_shared_variables():
+    with pm.Model() as m:
+        data = pytensor.shared(np.random.normal(loc=3, scale=1.5, size=100), name="shared_data")
+        mu = pm.Normal("mu")
+        sigma = pm.Exponential("sigma", 1)
+        y_hat = pm.Normal("y_hat", mu=mu, sigma=sigma, observed=data)
+
+        optimized_point = find_MAP(
+            method="L-BFGS-B",
+            use_grad=True,
+            use_hess=False,
+            use_hessp=False,
+            progressbar=False,
+            gradient_backend="jax",
+            compile_kwargs={"mode": "JAX"},
+        )
+    mu_hat, log_sigma_hat = optimized_point["mu"], optimized_point["sigma_log__"]
+
+    assert np.isclose(mu_hat, 3, atol=0.5)
+    assert np.isclose(np.exp(log_sigma_hat), 1.5, atol=0.5)

From 8974377d05f270b1118a1d52c806c9291f0d0fd6 Mon Sep 17 00:00:00 2001
From: Jesse Grabowski
Date: Fri, 2 May 2025 20:34:30 +0800
Subject: [PATCH 2/4] Delay jax import

---
 pymc_extras/inference/find_map.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/pymc_extras/inference/find_map.py b/pymc_extras/inference/find_map.py
index f2f0f45ad..f6eacaa61 100644
--- a/pymc_extras/inference/find_map.py
+++ b/pymc_extras/inference/find_map.py
@@ -15,7 +15,6 @@ from pymc.initial_point import make_initial_point_fn
 
 from pymc.model.transform.optimization import freeze_dims_and_data
 from pymc.pytensorf import join_nonshared_inputs
-from pymc.sampling.jax import _replace_shared_variables
 from pymc.util import get_default_varnames
 from pytensor.compile import Function
 from pytensor.compile.mode import Mode
@@ -306,6 +305,8 @@ def scipy_optimize_funcs_from_loss(
     # computing jax gradients, we discard the function wrapper, so we can't handle shared variables --> rewrite them
     # away.
     if use_jax_gradients:
+        from pymc.sampling.jax import _replace_shared_variables
+
         [loss] = _replace_shared_variables([loss])
 
     compute_grad = use_grad and not use_jax_gradients

From f1bb479989856e98a9e632b5f2fc388530928648 Mon Sep 17 00:00:00 2001
From: Jesse Grabowski
Date: Sat, 3 May 2025 04:04:49 +0800
Subject: [PATCH 3/4] Try version pinning dask

---
 conda-envs/environment-test.yml         | 2 +-
 conda-envs/windows-environment-test.yml | 2 +-
 2 files changed, 2 insertions(+), 2 deletions(-)

diff --git a/conda-envs/environment-test.yml b/conda-envs/environment-test.yml
index 450b46e30..260c7b0e5 100644
--- a/conda-envs/environment-test.yml
+++ b/conda-envs/environment-test.yml
@@ -6,7 +6,7 @@ dependencies:
 - pymc>=5.21
 - pytest-cov>=2.5
 - pytest>=3.0
-- dask
+- dask<2025.1.1
 - xhistogram
 - statsmodels
 - numba<=0.60.0
diff --git a/conda-envs/windows-environment-test.yml b/conda-envs/windows-environment-test.yml
index d2a3e8934..6a92aea55 100644
--- a/conda-envs/windows-environment-test.yml
+++ b/conda-envs/windows-environment-test.yml
@@ -6,7 +6,7 @@ dependencies:
 - pip
 - pytest-cov>=2.5
 - pytest>=3.0
-- dask
+- dask<2025.1.1
 - xhistogram
 - statsmodels
 - numba<=0.60.0

From 80fa48bcba2d9003fb99a5987fd139f9e86e71c6 Mon Sep 17 00:00:00 2001
From: Jesse Grabowski
Date: Sat, 3 May 2025 05:18:20 +0800
Subject: [PATCH 4/4] Pin dask version in setup.py

---
 setup.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/setup.py b/setup.py
index ee0607473..c7bd5945e 100644
--- a/setup.py
+++ b/setup.py
@@ -60,7 +60,7 @@
 
 
 extras_require = dict(
-    dask_histogram=["dask[complete]", "xhistogram"],
+    dask_histogram=["dask[complete]<2025.1.1", "xhistogram"],
     histogram=["xhistogram"],
 )
 extras_require["complete"] = sorted(set(itertools.chain.from_iterable(extras_require.values())))
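
---

Note on the first patch: the JAX gradient path jits the raw compiled loss
(`f_loss.vm.jit_fn`) directly, bypassing the pytensor Function wrapper that
normally supplies shared-variable storage, so shared variables must be
rewritten out of the graph before compilation. Below is a minimal sketch of
that kind of rewrite using only public pytensor APIs; the helper name
`replace_shared_with_constants` is hypothetical and is not pymc's actual
`_replace_shared_variables` implementation:

    import numpy as np
    import pytensor
    import pytensor.tensor as pt

    from pytensor.compile.sharedvalue import SharedVariable
    from pytensor.graph.basic import graph_inputs
    from pytensor.graph.replace import clone_replace

    def replace_shared_with_constants(outputs):
        # Find every shared variable feeding the graph and freeze its
        # current value into a constant, so the cloned graph has no
        # hidden inputs beyond its explicit ones.
        shared = [v for v in graph_inputs(outputs) if isinstance(v, SharedVariable)]
        replacements = {v: pt.constant(v.get_value(), name=v.name) for v in shared}
        return clone_replace(outputs, replace=replacements)

    data = pytensor.shared(np.ones(3), name="shared_data")
    mu = pt.dscalar("mu")
    loss = ((data - mu) ** 2).sum()

    # frozen_loss depends only on `mu`; jitting and differentiating the raw
    # function no longer needs anything to supply `shared_data`.
    [frozen_loss] = replace_shared_with_constants([loss])

One trade-off of freezing values this way: the rewritten graph no longer
reacts to later `set_value` calls on the shared variable, which is acceptable
here because the loss is recompiled on each `find_MAP` invocation.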