diff --git a/conda-envs/environment-test.yml b/conda-envs/environment-test.yml
index 450b46e30..260c7b0e5 100644
--- a/conda-envs/environment-test.yml
+++ b/conda-envs/environment-test.yml
@@ -6,7 +6,7 @@ dependencies:
 - pymc>=5.21
 - pytest-cov>=2.5
 - pytest>=3.0
-- dask
+- dask<2025.1.1
 - xhistogram
 - statsmodels
 - numba<=0.60.0
diff --git a/conda-envs/windows-environment-test.yml b/conda-envs/windows-environment-test.yml
index d2a3e8934..6a92aea55 100644
--- a/conda-envs/windows-environment-test.yml
+++ b/conda-envs/windows-environment-test.yml
@@ -6,7 +6,7 @@ dependencies:
 - pip
 - pytest-cov>=2.5
 - pytest>=3.0
-- dask
+- dask<2025.1.1
 - xhistogram
 - statsmodels
 - numba<=0.60.0
diff --git a/pymc_extras/inference/find_map.py b/pymc_extras/inference/find_map.py
index 063f6ce97..f6eacaa61 100644
--- a/pymc_extras/inference/find_map.py
+++ b/pymc_extras/inference/find_map.py
@@ -146,7 +146,7 @@ def _compile_grad_and_hess_to_jax(
     orig_loss_fn = f_loss.vm.jit_fn
 
     @jax.jit
-    def loss_fn_jax_grad(x, *shared):
+    def loss_fn_jax_grad(x):
         return jax.value_and_grad(lambda x: orig_loss_fn(x)[0])(x)
 
     f_loss_and_grad = loss_fn_jax_grad
@@ -301,6 +301,14 @@ def scipy_optimize_funcs_from_loss(
         point=initial_point_dict, outputs=[loss], inputs=inputs
     )
 
+    # If we use pytensor gradients, we will use the pytensor function wrapper that handles shared variables. When
+    # computing jax gradients, we discard the function wrapper, so we can't handle shared variables --> rewrite them
+    # away.
+    if use_jax_gradients:
+        from pymc.sampling.jax import _replace_shared_variables
+
+        [loss] = _replace_shared_variables([loss])
+
     compute_grad = use_grad and not use_jax_gradients
     compute_hess = use_hess and not use_jax_gradients
     compute_hessp = use_hessp and not use_jax_gradients
diff --git a/setup.py b/setup.py
index ee0607473..c7bd5945e 100644
--- a/setup.py
+++ b/setup.py
@@ -60,7 +60,7 @@
 
 
 extras_require = dict(
-    dask_histogram=["dask[complete]", "xhistogram"],
+    dask_histogram=["dask[complete]<2025.1.1", "xhistogram"],
     histogram=["xhistogram"],
 )
 extras_require["complete"] = sorted(set(itertools.chain.from_iterable(extras_require.values())))
diff --git a/tests/test_find_map.py b/tests/test_find_map.py
index 34c8fc766..adb081eea 100644
--- a/tests/test_find_map.py
+++ b/tests/test_find_map.py
@@ -1,5 +1,6 @@
 import numpy as np
 import pymc as pm
+import pytensor
 import pytensor.tensor as pt
 import pytest
 
@@ -101,3 +102,25 @@ def test_JAX_map(method, use_grad, use_hess, use_hessp, gradient_backend: Gradie
 
     assert np.isclose(mu_hat, 3, atol=0.5)
     assert np.isclose(np.exp(log_sigma_hat), 1.5, atol=0.5)
+
+
+def test_JAX_map_shared_variables():
+    with pm.Model() as m:
+        data = pytensor.shared(np.random.normal(loc=3, scale=1.5, size=100), name="shared_data")
+        mu = pm.Normal("mu")
+        sigma = pm.Exponential("sigma", 1)
+        y_hat = pm.Normal("y_hat", mu=mu, sigma=sigma, observed=data)
+
+        optimized_point = find_MAP(
+            method="L-BFGS-B",
+            use_grad=True,
+            use_hess=False,
+            use_hessp=False,
+            progressbar=False,
+            gradient_backend="jax",
+            compile_kwargs={"mode": "JAX"},
+        )
+    mu_hat, log_sigma_hat = optimized_point["mu"], optimized_point["sigma_log__"]
+
+    assert np.isclose(mu_hat, 3, atol=0.5)
+    assert np.isclose(np.exp(log_sigma_hat), 1.5, atol=0.5)
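
For context on the `find_map.py` change: the pytensor `Function` wrapper normally feeds shared variables into the compiled graph, but the JAX-gradient path discards that wrapper and calls the raw jitted function directly, so shared variables are rewritten out of the loss graph first via `pymc.sampling.jax._replace_shared_variables`. Below is a minimal, illustrative sketch of that idea using only public pytensor APIs; it is not the pymc implementation, and the helper name `replace_shared_with_constants` is made up for illustration.

```python
# Illustrative sketch only (assumed behavior, not the pymc implementation):
# shared variables are baked into the graph as constants so the stripped-down
# JAX function only has to deal with explicit inputs.
import numpy as np
import pytensor
import pytensor.tensor as pt
from pytensor.compile.sharedvalue import SharedVariable
from pytensor.graph.basic import graph_inputs
from pytensor.graph.replace import clone_replace


def replace_shared_with_constants(outputs):
    """Clone the graph, swapping each shared variable for a constant holding its current value."""
    shared = [v for v in graph_inputs(outputs) if isinstance(v, SharedVariable)]
    replacements = {v: pt.constant(v.get_value(borrow=True), name=v.name) for v in shared}
    return clone_replace(outputs, replace=replacements)


# A toy loss that mixes an explicit input with shared data, mirroring the new test
data = pytensor.shared(np.random.normal(loc=3, scale=1.5, size=100), name="shared_data")
mu = pt.dscalar("mu")
loss = ((data - mu) ** 2).sum()

[loss_no_shared] = replace_shared_with_constants([loss])
# `loss_no_shared` now depends only on `mu`, so it can be compiled in JAX mode and
# differentiated with jax.value_and_grad without the pytensor function wrapper
# that would otherwise have to supply the shared values.
```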