Skip to content

[MRG] Fix None init plan in unbalanced lbfgs solvers #731

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 16 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions RELEASES.md
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
- Backend implementation of `ot.dist` for (PR #701)
- Updated documentation Quickstart guide and User guide with new API (PR #726)
- Fix jax version for auto-grad (PR #732)
- Fix reg_div function compatibility with numpy in `ot.unbalanced.lbfgsb_unbalanced` via new function `ot.utils.fun_to_numpy` (PR #731)

#### Closed issues
- Fixed `ot.mapping` solvers which depended on deprecated `cvxpy` `ECOS` solver (PR #692, Issue #668)
Expand Down
94 changes: 45 additions & 49 deletions ot/unbalanced/_lbfgs.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,11 @@
#
# License: MIT License

import warnings
import numpy as np
from scipy.optimize import minimize, Bounds

from ..backend import get_backend
from ..utils import list_to_array, get_parameter_pair
from ..utils import list_to_array, get_parameter_pair, fun_to_numpy


def _get_loss_unbalanced(a, b, c, M, reg, reg_m1, reg_m2, reg_div="kl", regm_div="kl"):
Expand Down Expand Up @@ -46,9 +45,9 @@
Divergence used for regularization.
Can take three values: 'entropy' (negative entropy), or
'kl' (Kullback-Leibler) or 'l2' (half-squared) or a tuple
of two calable functions returning the reg term and its derivative.
of two callable functions returning the reg term and its derivative.
Note that the callable functions should be able to handle Numpy arrays
and not tesors from the backend
and not tensors from the backend
regm_div: string, optional
Divergence to quantify the difference between the marginals.
Can take three values: 'kl' (Kullback-Leibler) or 'l2' (half-squared) or 'tv' (Total Variation)
Expand Down Expand Up @@ -206,26 +205,27 @@
loss matrix
reg: float
regularization term >=0
c : array-like (dim_a, dim_b), optional (default = None)
Reference measure for the regularization.
If None, then use :math:`\mathbf{c} = \mathbf{a} \mathbf{b}^T`.
reg_m: float or indexable object of length 1 or 2
Marginal relaxation term: nonnegative (including 0) but cannot be infinity.
If :math:`\mathrm{reg_{m}}` is a scalar or an indexable object of length 1,
then the same :math:`\mathrm{reg_{m}}` is applied to both marginal relaxations.
If :math:`\mathrm{reg_{m}}` is an array, it must be a Numpy array.
reg_div: string, optional
c : array-like (dim_a, dim_b), optional (default = None)
Reference measure for the regularization.
If None, then use :math:`\mathbf{c} = \mathbf{a} \mathbf{b}^T`.
reg_div: string or pair of callable functions, optional (default = 'kl')
Divergence used for regularization.
Can take three values: 'entropy' (negative entropy), or
'kl' (Kullback-Leibler) or 'l2' (half-squared) or a tuple
of two calable functions returning the reg term and its derivative.
of two callable functions returning the reg term and its derivative.
Note that the callable functions should be able to handle Numpy arrays
and not tesors from the backend
regm_div: string, optional
and not tensors from the backend, otherwise functions will be converted to Numpy
leading to a computational overhead.
regm_div: string, optional (default = 'kl')
Divergence to quantify the difference between the marginals.
Can take three values: 'kl' (Kullback-Leibler) or 'l2' (half-squared) or 'tv' (Total Variation)
G0: array-like (dim_a, dim_b)
Initialization of the transport matrix
G0: array-like (dim_a, dim_b), optional (default = None)
Initialization of the transport matrix. None corresponds to uniform product.
numItermax : int, optional
Max number of iterations
stopThr : float, optional
Expand Down Expand Up @@ -267,26 +267,14 @@
ot.unbalanced.sinkhorn_unbalanced2 : Entropic regularized OT loss
"""

# wrap the callable function to handle numpy arrays
if isinstance(reg_div, tuple):
f0, df0 = reg_div
try:
f0(G0)
df0(G0)
except BaseException:
warnings.warn(
"The callable functions should be able to handle numpy arrays, wrapper ar added to handle this which comes with overhead"
)

def f(x):
return nx.to_numpy(f0(nx.from_numpy(x, type_as=M0)))

def df(x):
return nx.to_numpy(df0(nx.from_numpy(x, type_as=M0)))

reg_div = (f, df)
# test settings
regm_div = regm_div.lower()
if regm_div not in ["kl", "l2", "tv"]:
raise ValueError(
"Unknown regm_div = {}. Must be either 'kl', 'l2' or 'tv'".format(regm_div)
)

else:
if isinstance(reg_div, str):
reg_div = reg_div.lower()
if reg_div not in ["entropy", "kl", "l2"]:
raise ValueError(
Expand All @@ -295,16 +283,11 @@
)
)

regm_div = regm_div.lower()
if regm_div not in ["kl", "l2", "tv"]:
raise ValueError(
"Unknown regm_div = {}. Must be either 'kl', 'l2' or 'tv'".format(regm_div)
)

# convert all inputs to numpy arrays
reg_m1, reg_m2 = get_parameter_pair(reg_m)

M, a, b = list_to_array(M, a, b)
nx = get_backend(M, a, b)
nx = get_backend(M, a, b, G0)
M0 = M

dim_a, dim_b = M.shape
Expand All @@ -315,10 +298,22 @@
b = nx.ones(dim_b, type_as=M) / dim_b

# convert to numpy
a, b, M, reg_m1, reg_m2, reg = nx.to_numpy(a, b, M, reg_m1, reg_m2, reg)
if nx.__name__ == "numpy": # remaining parameters which can be arrays
reg_m1, reg_m2, reg = nx.to_numpy(reg_m1, reg_m2, reg)
else:
a, b, M, reg_m1, reg_m2, reg = nx.to_numpy(a, b, M, reg_m1, reg_m2, reg)

Check warning on line 304 in ot/unbalanced/_lbfgs.py

View check run for this annotation

Codecov / codecov/patch

ot/unbalanced/_lbfgs.py#L304

Added line #L304 was not covered by tests

G0 = a[:, None] * b[None, :] if G0 is None else nx.to_numpy(G0)
c = a[:, None] * b[None, :] if c is None else nx.to_numpy(c)

# potentially convert the callable function to handle numpy arrays
if isinstance(reg_div, tuple):
f0, df0 = reg_div
f = fun_to_numpy(f0, G0, nx, warn=True)
df = fun_to_numpy(df0, G0, nx, warn=True)

reg_div = (f, df)

_func = _get_loss_unbalanced(a, b, c, M, reg, reg_m1, reg_m2, reg_div, regm_div)

res = minimize(
Expand Down Expand Up @@ -399,26 +394,27 @@
loss matrix
reg: float
regularization term >=0
c : array-like (dim_a, dim_b), optional (default = None)
Reference measure for the regularization.
If None, then use :math:`\mathbf{c} = \mathbf{a} \mathbf{b}^T`.
reg_m: float or indexable object of length 1 or 2
Marginal relaxation term: nonnegative (including 0) but cannot be infinity.
If :math:`\mathrm{reg_{m}}` is a scalar or an indexable object of length 1,
then the same :math:`\mathrm{reg_{m}}` is applied to both marginal relaxations.
If :math:`\mathrm{reg_{m}}` is an array, it must be a Numpy array.
reg_div: string, optional
c : array-like (dim_a, dim_b), optional (default = None)
Reference measure for the regularization.
If None, then use :math:`\mathbf{c} = \mathbf{a} \mathbf{b}^T`.
reg_div: string or pair of callable functions, optional (default = 'kl')
Divergence used for regularization.
Can take three values: 'entropy' (negative entropy), or
'kl' (Kullback-Leibler) or 'l2' (half-squared) or a tuple
of two calable functions returning the reg term and its derivative.
of two callable functions returning the reg term and its derivative.
Note that the callable functions should be able to handle Numpy arrays
and not tesors from the backend
regm_div: string, optional
and not tensors from the backend, otherwise functions will be converted to Numpy
leading to a computational overhead.
regm_div: string, optional (default = 'kl')
Divergence to quantify the difference between the marginals.
Can take three values: 'kl' (Kullback-Leibler) or 'l2' (half-squared) or 'tv' (Total Variation)
G0: array-like (dim_a, dim_b)
Initialization of the transport matrix
G0: array-like (dim_a, dim_b), optional (default = None)
Initialization of the transport matrix. None corresponds to uniform product.
returnCost: string, optional (default = "linear")
If `returnCost` = "linear", then return the linear part of the unbalanced OT loss.
If `returnCost` = "total", then return the total unbalanced OT loss.
Expand Down
40 changes: 40 additions & 0 deletions ot/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1473,3 +1473,43 @@
'numThreads should either be "max" or a strictly positive integer'
)
return numThreads


def fun_to_numpy(fun, arr, nx, warn=True):
    """Return a version of ``fun`` that accepts and returns Numpy arrays.

    ``fun`` is probed once on a Numpy version of ``arr``. If that call
    succeeds, ``fun`` already handles Numpy arrays and is returned unchanged.
    Otherwise a wrapper is returned that converts its Numpy input to the
    backend ``nx``, calls ``fun`` and converts the result back to Numpy.

    Parameters
    ----------
    fun : callable
        The function to convert. It is called once on the test input.
    arr : array-like
        The input to test the function. Can be from any backend.
    nx : Backend
        The backend to use for the conversion.
    warn : bool, optional
        Whether to raise a warning if the function is not compatible with numpy.
        Default is True.

    Returns
    -------
    fun_numpy : callable
        ``fun`` itself when it handles Numpy arrays, otherwise a wrapper
        performing the Numpy <-> backend conversions around ``fun``.

    Raises
    ------
    ValueError
        If ``arr`` is None (there is nothing to probe ``fun`` with).
    """
    if arr is None:
        raise ValueError("arr should not be None to test fun")

    # The probe must be done on a Numpy array: convert only when needed.
    if get_backend(arr).__name__ != "numpy":
        arr = nx.to_numpy(arr)

    try:
        fun(arr)
    except Exception:
        # Only ordinary errors mean "not Numpy compatible". KeyboardInterrupt
        # and SystemExit are deliberately not caught so that they propagate.
        if warn:
            warnings.warn(
                "The callable function should be able to handle numpy arrays, a compatible function is created and comes with overhead"
            )
    else:
        # The probe succeeded: no wrapping (and no overhead) required.
        return fun

    def fun_numpy(x):
        # NOTE(review): from_numpy is called without `type_as`, so dtype and
        # device of the backend tensor are the backend defaults - confirm
        # this is what callers expect.
        return nx.to_numpy(fun(nx.from_numpy(x)))

    return fun_numpy

Check warning on line 1515 in ot/utils.py

View check run for this annotation

Codecov / codecov/patch

ot/utils.py#L1515

Added line #L1515 was not covered by tests
18 changes: 18 additions & 0 deletions test/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -731,3 +731,21 @@ def test_exp_bures(nx):
# exp_\Lambda(log_\Lambda(Sigma)) = Sigma
Sigma_exp = ot.utils.exp_bures(Lambda, T - nx.eye(d, type_as=T))
np.testing.assert_allclose(nx.to_numpy(Sigma), nx.to_numpy(Sigma_exp), atol=1e-5)


def test_fun_to_numpy(nx):
    """Check that ``fun_to_numpy`` yields a Numpy-callable equivalent of a backend function."""
    np_input = np.arange(5)
    backend_input = nx.from_numpy(np_input)

    def backend_sum(x):  # backend function
        return nx.sum(x)

    # Convert the backend function, probing it with a backend array.
    converted = ot.utils.fun_to_numpy(backend_sum, backend_input, nx, warn=True)

    # The converted function on Numpy data must agree with the backend result.
    backend_result = nx.to_numpy(backend_sum(backend_input))
    numpy_result = converted(np_input)
    np.testing.assert_allclose(backend_result, numpy_result)

    # A missing probe array is rejected.
    with pytest.raises(ValueError):
        ot.utils.fun_to_numpy(backend_sum, None, nx, warn=True)
Loading