diff --git a/RELEASES.md b/RELEASES.md index 62240fa77..a24747fb7 100644 --- a/RELEASES.md +++ b/RELEASES.md @@ -16,6 +16,7 @@ - `ot.gaussian.bures_wasserstein_distance` can be batched (PR #680) - Backend implementation of `ot.dist` for (PR #701) - Updated documentation Quickstart guide and User guide with new API (PR #726) +- Fix jax version for auto-grad (PR #732) #### Closed issues - Fixed `ot.mapping` solvers which depended on deprecated `cvxpy` `ECOS` solver (PR #692, Issue #668) diff --git a/ot/backend.py b/ot/backend.py index d5f58bbcc..3d59639fa 100644 --- a/ot/backend.py +++ b/ot/backend.py @@ -1509,7 +1509,7 @@ def set_gradients(self, val, inputs, grads): aux = jnp.sum(ravelled_inputs * ravelled_grads) / 2 aux = aux - jax.lax.stop_gradient(aux) - (val,) = jax.tree_map(lambda z: z + aux, (val,)) + (val,) = jax.tree_util.tree_map(lambda z: z + aux, (val,)) return val def _detach(self, a):