From 44d46149e7f12cf4f712332f70fad0a6f78a0299 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?C=C3=A9dric=20Vincent-Cuaz?= Date: Tue, 10 Sep 2024 23:46:02 +0200 Subject: [PATCH 1/7] merge --- RELEASES.md | 2 ++ 1 file changed, 2 insertions(+) diff --git a/RELEASES.md b/RELEASES.md index cc18cc91b..277af7847 100644 --- a/RELEASES.md +++ b/RELEASES.md @@ -10,6 +10,8 @@ - Improved `ot.plot.plot1D_mat` (PR #649) - Added `nx.det` (PR #649) - `nx.sqrtm` is now broadcastable (takes ..., d, d) inputs (PR #649) +- restructure `ot.unbalanced` module (PR #658) +- add `ot.unbalanced.lbfgsb_unbalanced2` and add flexible reference measure `c` in all unbalanced solvers (PR #658) #### Closed issues - Fixed `ot.gaussian` ignoring weights when computing means (PR #649, Issue #648) From 62efd4abc3ca1f7074009490130c72dc4b13319b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?C=C3=A9dric=20Vincent-Cuaz?= Date: Fri, 25 Apr 2025 02:00:13 +0200 Subject: [PATCH 2/7] init commit --- ot/unbalanced/_lbfgs.py | 82 ++++++++++++++++++++++------------------- 1 file changed, 44 insertions(+), 38 deletions(-) diff --git a/ot/unbalanced/_lbfgs.py b/ot/unbalanced/_lbfgs.py index c4de87474..89ebb1ef8 100644 --- a/ot/unbalanced/_lbfgs.py +++ b/ot/unbalanced/_lbfgs.py @@ -206,26 +206,26 @@ def lbfgsb_unbalanced( loss matrix reg: float regularization term >=0 - c : array-like (dim_a, dim_b), optional (default = None) - Reference measure for the regularization. - If None, then use :math:`\mathbf{c} = \mathbf{a} \mathbf{b}^T`. reg_m: float or indexable object of length 1 or 2 Marginal relaxation term: nonnegative (including 0) but cannot be infinity. If :math:`\mathrm{reg_{m}}` is a scalar or an indexable object of length 1, then the same :math:`\mathrm{reg_{m}}` is applied to both marginal relaxations. If :math:`\mathrm{reg_{m}}` is an array, it must be a Numpy array. - reg_div: string, optional + c : array-like (dim_a, dim_b), optional (default = None) + Reference measure for the regularization. + If None, then use :math:`\mathbf{c} = \mathbf{a} \mathbf{b}^T`. + reg_div: string or pair of callable functions, optional (default = 'kl') Divergence used for regularization. Can take three values: 'entropy' (negative entropy), or 'kl' (Kullback-Leibler) or 'l2' (half-squared) or a tuple - of two calable functions returning the reg term and its derivative. + of two callable functions returning the reg term and its derivative. Note that the callable functions should be able to handle Numpy arrays - and not tesors from the backend - regm_div: string, optional + and not tensors from the backend + regm_div: string, optional (default = 'kl') Divergence to quantify the difference between the marginals. Can take three values: 'kl' (Kullback-Leibler) or 'l2' (half-squared) or 'tv' (Total Variation) - G0: array-like (dim_a, dim_b) - Initialization of the transport matrix + G0: array-like (dim_a, dim_b), optional (default = None) + Initialization of the transport matrix. None corresponds to uniform product. numItermax : int, optional Max number of iterations stopThr : float, optional @@ -267,26 +267,14 @@ def lbfgsb_unbalanced( ot.unbalanced.sinkhorn_unbalanced2 : Entropic regularized OT loss """ - # wrap the callable function to handle numpy arrays - if isinstance(reg_div, tuple): - f0, df0 = reg_div - try: - f0(G0) - df0(G0) - except BaseException: - warnings.warn( - "The callable functions should be able to handle numpy arrays, wrapper ar added to handle this which comes with overhead" - ) - - def f(x): - return nx.to_numpy(f0(nx.from_numpy(x, type_as=M0))) - - def df(x): - return nx.to_numpy(df0(nx.from_numpy(x, type_as=M0))) - - reg_div = (f, df) + # test settings + regm_div = regm_div.lower() + if regm_div not in ["kl", "l2", "tv"]: + raise ValueError( + "Unknown regm_div = {}. Must be either 'kl', 'l2' or 'tv'".format(regm_div) + ) - else: + if isinstance(reg_div, str): reg_div = reg_div.lower() if reg_div not in ["entropy", "kl", "l2"]: raise ValueError( @@ -295,16 +283,11 @@ def df(x): ) ) - regm_div = regm_div.lower() - if regm_div not in ["kl", "l2", "tv"]: - raise ValueError( - "Unknown regm_div = {}. Must be either 'kl', 'l2' or 'tv'".format(regm_div) - ) - + # convert all inputs to numpy arrays reg_m1, reg_m2 = get_parameter_pair(reg_m) M, a, b = list_to_array(M, a, b) - nx = get_backend(M, a, b) + nx = get_backend(M, a, b, G0) M0 = M dim_a, dim_b = M.shape @@ -315,10 +298,33 @@ def df(x): b = nx.ones(dim_b, type_as=M) / dim_b # convert to numpy - a, b, M, reg_m1, reg_m2, reg = nx.to_numpy(a, b, M, reg_m1, reg_m2, reg) + if nx.__name__ == "numpy": # remaining parameters which can be arrays + reg_m1, reg_m2, reg = nx.to_numpy(reg_m1, reg_m2, reg) + else: + a, b, M, reg_m1, reg_m2, reg = nx.to_numpy(a, b, M, reg_m1, reg_m2, reg) + G0 = a[:, None] * b[None, :] if G0 is None else nx.to_numpy(G0) c = a[:, None] * b[None, :] if c is None else nx.to_numpy(c) + # wrap the callable function to handle numpy arrays + if isinstance(reg_div, tuple): + f0, df0 = reg_div + try: + f0(G0) + df0(G0) + except BaseException: + warnings.warn( + "The callable functions should be able to handle numpy arrays, wrapper ar added to handle this which comes with overhead" + ) + + def f(x): + return nx.to_numpy(f0(nx.from_numpy(x, type_as=M0))) + + def df(x): + return nx.to_numpy(df0(nx.from_numpy(x, type_as=M0))) + + reg_div = (f, df) + _func = _get_loss_unbalanced(a, b, c, M, reg, reg_m1, reg_m2, reg_div, regm_div) res = minimize( @@ -411,9 +417,9 @@ def lbfgsb_unbalanced2( Divergence used for regularization. Can take three values: 'entropy' (negative entropy), or 'kl' (Kullback-Leibler) or 'l2' (half-squared) or a tuple - of two calable functions returning the reg term and its derivative. + of two callable functions returning the reg term and its derivative. Note that the callable functions should be able to handle Numpy arrays - and not tesors from the backend + and not tensors from the backend regm_div: string, optional Divergence to quantify the difference between the marginals. Can take three values: 'kl' (Kullback-Leibler) or 'l2' (half-squared) or 'tv' (Total Variation) From 6965c58e5cfc700d9472e324621d3f9fa31c8692 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?C=C3=A9dric=20Vincent-Cuaz?= Date: Fri, 25 Apr 2025 12:20:58 +0200 Subject: [PATCH 3/7] up --- ot/unbalanced/_lbfgs.py | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/ot/unbalanced/_lbfgs.py b/ot/unbalanced/_lbfgs.py index 89ebb1ef8..716c436b9 100644 --- a/ot/unbalanced/_lbfgs.py +++ b/ot/unbalanced/_lbfgs.py @@ -46,9 +46,9 @@ def _get_loss_unbalanced(a, b, c, M, reg, reg_m1, reg_m2, reg_div="kl", regm_div Divergence used for regularization. Can take three values: 'entropy' (negative entropy), or 'kl' (Kullback-Leibler) or 'l2' (half-squared) or a tuple - of two calable functions returning the reg term and its derivative. + of two callable functions returning the reg term and its derivative. Note that the callable functions should be able to handle Numpy arrays - and not tesors from the backend + and not tensors from the backend regm_div: string, optional Divergence to quantify the difference between the marginals. Can take three values: 'kl' (Kullback-Leibler) or 'l2' (half-squared) or 'tv' (Total Variation) @@ -405,26 +405,26 @@ def lbfgsb_unbalanced2( loss matrix reg: float regularization term >=0 - c : array-like (dim_a, dim_b), optional (default = None) - Reference measure for the regularization. - If None, then use :math:`\mathbf{c} = \mathbf{a} \mathbf{b}^T`. reg_m: float or indexable object of length 1 or 2 Marginal relaxation term: nonnegative (including 0) but cannot be infinity. If :math:`\mathrm{reg_{m}}` is a scalar or an indexable object of length 1, then the same :math:`\mathrm{reg_{m}}` is applied to both marginal relaxations. If :math:`\mathrm{reg_{m}}` is an array, it must be a Numpy array. - reg_div: string, optional + c : array-like (dim_a, dim_b), optional (default = None) + Reference measure for the regularization. + If None, then use :math:`\mathbf{c} = \mathbf{a} \mathbf{b}^T`. + reg_div: string or pair of callable functions, optional (default = 'kl') Divergence used for regularization. Can take three values: 'entropy' (negative entropy), or 'kl' (Kullback-Leibler) or 'l2' (half-squared) or a tuple of two callable functions returning the reg term and its derivative. Note that the callable functions should be able to handle Numpy arrays and not tensors from the backend - regm_div: string, optional + regm_div: string, optional (default = 'kl') Divergence to quantify the difference between the marginals. Can take three values: 'kl' (Kullback-Leibler) or 'l2' (half-squared) or 'tv' (Total Variation) - G0: array-like (dim_a, dim_b) - Initialization of the transport matrix + G0: array-like (dim_a, dim_b), optional (default = None) + Initialization of the transport matrix. None corresponds to uniform product. returnCost: string, optional (default = "linear") If `returnCost` = "linear", then return the linear part of the unbalanced OT loss. If `returnCost` = "total", then return the total unbalanced OT loss. From 9ed3b151f986eeb17f6c361b6827ee1315e53db5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?C=C3=A9dric=20Vincent-Cuaz?= Date: Sat, 26 Apr 2025 14:50:26 +0200 Subject: [PATCH 4/7] add fun_to_numpy in utils --- ot/unbalanced/_lbfgs.py | 22 +++++----------------- ot/utils.py | 40 ++++++++++++++++++++++++++++++++++++++++ test/test_utils.py | 15 +++++++++++++++ 3 files changed, 60 insertions(+), 17 deletions(-) diff --git a/ot/unbalanced/_lbfgs.py b/ot/unbalanced/_lbfgs.py index 716c436b9..f8478578e 100644 --- a/ot/unbalanced/_lbfgs.py +++ b/ot/unbalanced/_lbfgs.py @@ -9,12 +9,11 @@ # # License: MIT License -import warnings import numpy as np from scipy.optimize import minimize, Bounds from ..backend import get_backend -from ..utils import list_to_array, get_parameter_pair +from ..utils import list_to_array, get_parameter_pair, fun_to_numpy def _get_loss_unbalanced(a, b, c, M, reg, reg_m1, reg_m2, reg_div="kl", regm_div="kl"): @@ -306,24 +305,13 @@ def lbfgsb_unbalanced( G0 = a[:, None] * b[None, :] if G0 is None else nx.to_numpy(G0) c = a[:, None] * b[None, :] if c is None else nx.to_numpy(c) - # wrap the callable function to handle numpy arrays + # potentially convert the callable function to handle numpy arrays if isinstance(reg_div, tuple): f0, df0 = reg_div - try: - f0(G0) - df0(G0) - except BaseException: - warnings.warn( - "The callable functions should be able to handle numpy arrays, wrapper ar added to handle this which comes with overhead" - ) - - def f(x): - return nx.to_numpy(f0(nx.from_numpy(x, type_as=M0))) - - def df(x): - return nx.to_numpy(df0(nx.from_numpy(x, type_as=M0))) + f = fun_to_numpy(f0, G0, nx, warn=True) + df = fun_to_numpy(df0, G0, nx, warn=True) - reg_div = (f, df) + reg_div = (f, df) _func = _get_loss_unbalanced(a, b, c, M, reg, reg_m1, reg_m2, reg_div, regm_div) diff --git a/ot/utils.py b/ot/utils.py index 1f24fa33f..551ccf7f4 100644 --- a/ot/utils.py +++ b/ot/utils.py @@ -1473,3 +1473,43 @@ def check_number_threads(numThreads): 'numThreads should either be "max" or a strictly positive integer' ) return numThreads + + +def fun_to_numpy(fun, arr, nx, warn=True): + """Convert a function to a numpy function. + + Parameters + ---------- + fun : callable + The function to convert. + arr : array-like + The input to test the function. Can be from any backend. + nx : Backend + The backend to use for the conversion. + warn : bool, optional + Whether to raise a warning if the function is not compatible with numpy. + Default is True. + Returns + ------- + fun_numpy : callable + The converted function. + """ + if arr is None: + raise ValueError("arr should not be None to test fun") + + nx_arr = get_backend(arr) + if nx_arr.__name__ != "numpy": + arr = nx.to_numpy(arr) + try: + fun(arr) + return fun + except BaseException: + if warn: + warnings.warn( + "The callable function should be able to handle numpy arrays, a compatible function is created and comes with overhead" + ) + + def fun_numpy(x): + return nx.to_numpy(fun(nx.from_numpy(x))) + + return fun_numpy diff --git a/test/test_utils.py b/test/test_utils.py index 938fd6058..917c456e4 100644 --- a/test/test_utils.py +++ b/test/test_utils.py @@ -731,3 +731,18 @@ def test_exp_bures(nx): # exp_\Lambda(log_\Lambda(Sigma)) = Sigma Sigma_exp = ot.utils.exp_bures(Lambda, T - nx.eye(d, type_as=T)) np.testing.assert_allclose(nx.to_numpy(Sigma), nx.to_numpy(Sigma_exp), atol=1e-5) + + +def test_fun_to_numpy(nx): + arr = np.arange(5) + arrb = nx.from_numpy(arr) + + def fun(x): # backend function + return nx.sum(x) + + fun_numpy = ot.utils.fun_to_numpy(fun, arrb, nx, warn=True) + + res = nx.to_numpy(fun(arrb)) + res_np = fun_numpy(arr) + + np.testing.assert_allclose(res, res_np) From 4c78c6759fd6917eca213fbb87e30a1813f904a8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?C=C3=A9dric=20Vincent-Cuaz?= Date: Sat, 26 Apr 2025 14:53:05 +0200 Subject: [PATCH 5/7] complete tests --- test/test_utils.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/test/test_utils.py b/test/test_utils.py index 917c456e4..0b2769109 100644 --- a/test/test_utils.py +++ b/test/test_utils.py @@ -746,3 +746,6 @@ def fun(x): # backend function res_np = fun_numpy(arr) np.testing.assert_allclose(res, res_np) + + with pytest.raises(ValueError): + ot.utils.fun_to_numpy(fun, None, nx, warn=True) From 63c9d2d717984ca641a9fe503450099410784315 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?C=C3=A9dric=20Vincent-Cuaz?= Date: Sat, 26 Apr 2025 15:30:53 +0200 Subject: [PATCH 6/7] update releases --- RELEASES.md | 1 + 1 file changed, 1 insertion(+) diff --git a/RELEASES.md b/RELEASES.md index a24747fb7..d82b88b73 100644 --- a/RELEASES.md +++ b/RELEASES.md @@ -17,6 +17,7 @@ - Backend implementation of `ot.dist` for (PR #701) - Updated documentation Quickstart guide and User guide with new API (PR #726) - Fix jax version for auto-grad (PR #732) +- Fix reg_div function compatibility with numpy in `ot.unbalanced.lbfgsb_unbalanced` via new function `ot.utils.fun_to_numpy` (PR #731) #### Closed issues - Fixed `ot.mapping` solvers which depended on deprecated `cvxpy` `ECOS` solver (PR #692, Issue #668) From 9c314ec1a50e60c10aa3fe566498868e608a7f8b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?C=C3=A9dric=20Vincent-Cuaz?= Date: Sat, 26 Apr 2025 15:35:15 +0200 Subject: [PATCH 7/7] improve doc --- ot/unbalanced/_lbfgs.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/ot/unbalanced/_lbfgs.py b/ot/unbalanced/_lbfgs.py index f8478578e..ea273c7db 100644 --- a/ot/unbalanced/_lbfgs.py +++ b/ot/unbalanced/_lbfgs.py @@ -219,7 +219,8 @@ def lbfgsb_unbalanced( 'kl' (Kullback-Leibler) or 'l2' (half-squared) or a tuple of two callable functions returning the reg term and its derivative. Note that the callable functions should be able to handle Numpy arrays - and not tensors from the backend + and not tensors from the backend, otherwise functions will be converted to Numpy + leading to a computational overhead. regm_div: string, optional (default = 'kl') Divergence to quantify the difference between the marginals. Can take three values: 'kl' (Kullback-Leibler) or 'l2' (half-squared) or 'tv' (Total Variation) @@ -407,7 +408,8 @@ def lbfgsb_unbalanced2( 'kl' (Kullback-Leibler) or 'l2' (half-squared) or a tuple of two callable functions returning the reg term and its derivative. Note that the callable functions should be able to handle Numpy arrays - and not tensors from the backend + and not tensors from the backend, otherwise functions will be converted to Numpy + leading to a computational overhead. regm_div: string, optional (default = 'kl') Divergence to quantify the difference between the marginals. Can take three values: 'kl' (Kullback-Leibler) or 'l2' (half-squared) or 'tv' (Total Variation)