Skip to content

Commit 81d7631

Browse files
authored
[MRG] More general solvers for `ot.solveand examples of different variants. (#620)
* add exaple and allow for functional regularizers * fix test since ow all is implemented * manuel regularizer available for exact and unbalanecd ot * exmaple with banaced manuel regularizer * upate documenation * pep8 * clenaup envelope instedaof implicit * big release file update
1 parent e75c9af commit 81d7631

File tree

6 files changed

+260
-31
lines changed

6 files changed

+260
-31
lines changed

RELEASES.md

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,29 @@
11
# Releases
22

3-
## 0.9.3dev
3+
## 0.9.4dev
44

55
#### New features
6-
+ `ot.gromov._gw.solve_gromov_linesearch` now has an argument to specifify if the matrices are symmetric in which case the computation can be done faster.
6+
+ `ot.gromov._gw.solve_gromov_linesearch` now has an argument to specify if the matrices are symmetric in which case the computation can be done faster (PR #607).
7+
+ Continuous entropic mapping (PR #613)
8+
+ New general unbalanced solvers for `ot.solve` and BFGS solver and illustrative example (PR #620)
9+
+ Add gradient computation with envelope theorem to sinkhorn solver of `ot.solve` with `grad='envelope'` (PR #605).
710

811
#### Closed issues
9-
- Fixed an issue with cost correction for mismatched labels in `ot.da.BaseTransport` fit methods. This fix addresses the original issue introduced PR #587 (PR #593)
1012
- Fix gpu compatibility of sr(F)GW solvers when `G0 is not None`(PR #596)
1113
- Fix doc and example for lowrank sinkhorn (PR #601)
1214
- Fix issue with empty weights for `ot.emd2` (PR #606, Issue #534)
1315
- Fix a sign error regarding the gradient of `ot.gromov._gw.fused_gromov_wasserstein2` and `ot.gromov._gw.gromov_wasserstein2` for the kl loss (PR #610)
1416
- Fix same sign error for sr(F)GW conditional gradient solvers (PR #611)
1517
- Split `test/test_gromov.py` into `test/gromov/` (PR #619)
1618

19+
## 0.9.3
20+
*January 2024*
21+
22+
23+
#### Closed issues
24+
- Fixed an issue with cost correction for mismatched labels in `ot.da.BaseTransport` fit methods. This fix addresses the original issue introduced PR #587 (PR #593)
25+
26+
1727
## 0.9.2
1828
*December 2023*
1929

examples/plot_solve_variants.py

Lines changed: 150 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,150 @@
1+
# -*- coding: utf-8 -*-
2+
"""
3+
======================================
4+
Optimal Transport solvers comparison
5+
======================================
6+
7+
This example illustrates the solutions returns for diffrent variants of exact,
8+
regularized and unbalanced OT solvers.
9+
"""
10+
11+
# Author: Remi Flamary <remi.flamary@unice.fr>
12+
#
13+
# License: MIT License
14+
# sphinx_gallery_thumbnail_number = 3
15+
16+
#%%
17+
18+
import numpy as np
19+
import matplotlib.pylab as pl
20+
import ot
21+
import ot.plot
22+
from ot.datasets import make_1D_gauss as gauss
23+
24+
##############################################################################
25+
# Generate data
26+
# -------------
27+
28+
29+
#%% parameters
30+
31+
n = 50 # nb bins
32+
33+
# bin positions
34+
x = np.arange(n, dtype=np.float64)
35+
36+
# Gaussian distributions
37+
a = 0.6 * gauss(n, m=15, s=5) + 0.4 * gauss(n, m=35, s=5) # m= mean, s= std
38+
b = gauss(n, m=25, s=5)
39+
40+
# loss matrix
41+
M = ot.dist(x.reshape((n, 1)), x.reshape((n, 1)))
42+
M /= M.max()
43+
44+
45+
##############################################################################
46+
# Plot distributions and loss matrix
47+
# ----------------------------------
48+
49+
#%% plot the distributions
50+
51+
pl.figure(1, figsize=(6.4, 3))
52+
pl.plot(x, a, 'b', label='Source distribution')
53+
pl.plot(x, b, 'r', label='Target distribution')
54+
pl.legend()
55+
56+
#%% plot distributions and loss matrix
57+
58+
pl.figure(2, figsize=(5, 5))
59+
ot.plot.plot1D_mat(a, b, M, 'Cost matrix M')
60+
61+
##############################################################################
62+
# Define Group lasso regularization and gradient
63+
# ------------------------------------------------
64+
# The groups are the first and second half of the columns of G
65+
66+
67+
def reg_gl(G): # group lasso + small l2 reg
68+
G1 = G[:n // 2, :]**2
69+
G2 = G[n // 2:, :]**2
70+
gl1 = np.sum(np.sqrt(np.sum(G1, 0)))
71+
gl2 = np.sum(np.sqrt(np.sum(G2, 0)))
72+
return gl1 + gl2 + 0.1 * np.sum(G**2)
73+
74+
75+
def grad_gl(G): # gradient of group lasso + small l2 reg
76+
G1 = G[:n // 2, :]
77+
G2 = G[n // 2:, :]
78+
gl1 = G1 / np.sqrt(np.sum(G1**2, 0, keepdims=True) + 1e-8)
79+
gl2 = G2 / np.sqrt(np.sum(G2**2, 0, keepdims=True) + 1e-8)
80+
return np.concatenate((gl1, gl2), axis=0) + 0.2 * G
81+
82+
83+
reg_type_gl = (reg_gl, grad_gl)
84+
85+
# %%
86+
# Set up parameters for solvers and solve
87+
# ---------------------------------------
88+
89+
lst_regs = ["No Reg.", "Entropic", "L2", "Group Lasso + L2"]
90+
lst_unbalanced = ["Balanced", "Unbalanced KL", 'Unbalanced L2', 'Unb. TV (Partial)'] # ["Balanced", "Unb. KL", "Unb. L2", "Unb L1 (partial)"]
91+
92+
lst_solvers = [ # name, param for ot.solve function
93+
# balanced OT
94+
('Exact OT', dict()),
95+
('Entropic Reg. OT', dict(reg=0.005)),
96+
('L2 Reg OT', dict(reg=1, reg_type='l2')),
97+
('Group Lasso Reg. OT', dict(reg=0.1, reg_type=reg_type_gl)),
98+
99+
100+
# unbalanced OT KL
101+
('Unbalanced KL No Reg.', dict(unbalanced=0.005)),
102+
('Unbalanced KL wit KL Reg.', dict(reg=0.0005, unbalanced=0.005, unbalanced_type='kl', reg_type='kl')),
103+
('Unbalanced KL with L2 Reg.', dict(reg=0.5, reg_type='l2', unbalanced=0.005, unbalanced_type='kl')),
104+
('Unbalanced KL with Group Lasso Reg.', dict(reg=0.1, reg_type=reg_type_gl, unbalanced=0.05, unbalanced_type='kl')),
105+
106+
# unbalanced OT L2
107+
('Unbalanced L2 No Reg.', dict(unbalanced=0.5, unbalanced_type='l2')),
108+
('Unbalanced L2 with KL Reg.', dict(reg=0.001, unbalanced=0.2, unbalanced_type='l2')),
109+
('Unbalanced L2 with L2 Reg.', dict(reg=0.1, reg_type='l2', unbalanced=0.2, unbalanced_type='l2')),
110+
('Unbalanced L2 with Group Lasso Reg.', dict(reg=0.05, reg_type=reg_type_gl, unbalanced=0.7, unbalanced_type='l2')),
111+
112+
# unbalanced OT TV
113+
('Unbalanced TV No Reg.', dict(unbalanced=0.1, unbalanced_type='tv')),
114+
('Unbalanced TV with KL Reg.', dict(reg=0.001, unbalanced=0.01, unbalanced_type='tv')),
115+
('Unbalanced TV with L2 Reg.', dict(reg=0.1, reg_type='l2', unbalanced=0.01, unbalanced_type='tv')),
116+
('Unbalanced TV with Group Lasso Reg.', dict(reg=0.02, reg_type=reg_type_gl, unbalanced=0.01, unbalanced_type='tv')),
117+
118+
]
119+
120+
lst_plans = []
121+
for (name, param) in lst_solvers:
122+
G = ot.solve(M, a, b, **param).plan
123+
lst_plans.append(G)
124+
125+
##############################################################################
126+
# Plot plans
127+
# ----------
128+
129+
pl.figure(3, figsize=(9, 9))
130+
131+
for i, bname in enumerate(lst_unbalanced):
132+
for j, rname in enumerate(lst_regs):
133+
pl.subplot(len(lst_unbalanced), len(lst_regs), i * len(lst_regs) + j + 1)
134+
135+
plan = lst_plans[i * len(lst_regs) + j]
136+
m2 = plan.sum(0)
137+
m1 = plan.sum(1)
138+
m1, m2 = m1 / a.max(), m2 / b.max()
139+
pl.imshow(plan, cmap='Greys')
140+
pl.plot(x, m2 * 10, 'r')
141+
pl.plot(m1 * 10, x, 'b')
142+
pl.plot(x, b / b.max() * 10, 'r', alpha=0.3)
143+
pl.plot(a / a.max() * 10, x, 'b', alpha=0.3)
144+
#pl.axis('off')
145+
pl.tick_params(left=False, right=False, labelleft=False,
146+
labelbottom=False, bottom=False)
147+
if i == 0:
148+
pl.title(rname)
149+
if j == 0:
150+
pl.ylabel(bname, fontsize=14)

ot/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,7 @@
5858
# utils functions
5959
from .utils import dist, unif, tic, toc, toq
6060

61-
__version__ = "0.9.3dev"
61+
__version__ = "0.9.4dev"
6262

6363
__all__ = ['emd', 'emd2', 'emd_1d', 'sinkhorn', 'sinkhorn2', 'utils',
6464
'datasets', 'bregman', 'lp', 'tic', 'toc', 'toq', 'gromov',

ot/solvers.py

Lines changed: 36 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
from .gaussian import empirical_bures_wasserstein_distance
2424
from .factored import factored_optimal_transport
2525
from .lowrank import lowrank_sinkhorn
26+
from .optim import cg
2627

2728
lst_method_lazy = ['1d', 'gaussian', 'lowrank', 'factored', 'geomloss', 'geomloss_auto', 'geomloss_tensorized', 'geomloss_online', 'geomloss_multiscale']
2829

@@ -57,13 +58,15 @@ def solve(M, a=None, b=None, reg=None, reg_type="KL", unbalanced=None,
5758
Regularization weight :math:`\lambda_r`, by default None (no reg., exact
5859
OT)
5960
reg_type : str, optional
60-
Type of regularization :math:`R` either "KL", "L2", "entropy", by default "KL"
61+
Type of regularization :math:`R` either "KL", "L2", "entropy",
62+
by default "KL". a tuple of functions can be provided for general
63+
solver (see :any:`cg`). This is only used when ``reg!=None``.
6164
unbalanced : float, optional
6265
Unbalanced penalization weight :math:`\lambda_u`, by default None
6366
(balanced OT)
6467
unbalanced_type : str, optional
6568
Type of unbalanced penalization function :math:`U` either "KL", "L2",
66-
"TV", by default "KL"
69+
"TV", by default "KL".
6770
method : str, optional
6871
Method for solving the problem when multiple algorithms are available,
6972
default None for automatic selection.
@@ -80,10 +83,10 @@ def solve(M, a=None, b=None, reg=None, reg_type="KL", unbalanced=None,
8083
verbose : bool, optional
8184
Print information in the solver, by default False
8285
grad : str, optional
83-
Type of gradient computation, either or 'autodiff' or 'implicit' used only for
86+
Type of gradient computation, either or 'autodiff' or 'envelope' used only for
8487
Sinkhorn solver. By default 'autodiff' provides gradients wrt all
8588
outputs (`plan, value, value_linear`) but with important memory cost.
86-
'implicit' provides gradients only for `value` and and other outputs are
89+
'envelope' provides gradients only for `value` and and other outputs are
8790
detached. This is useful for memory saving when only the value is needed.
8891
8992
Returns
@@ -140,13 +143,13 @@ def solve(M, a=None, b=None, reg=None, reg_type="KL", unbalanced=None,
140143
# or for original Sinkhorn paper formulation [2]
141144
res = ot.solve(M, a, b, reg=1.0, reg_type='entropy')
142145
143-
# Use implicit differentiation for memory saving
144-
res = ot.solve(M, a, b, reg=1.0, grad='implicit') # M, a, b are torch tensors
146+
# Use envelope theorem differentiation for memory saving
147+
res = ot.solve(M, a, b, reg=1.0, grad='envelope') # M, a, b are torch tensors
145148
res.value.backward() # only the value is differentiable
146149
147150
Note that by default the Sinkhorn solver uses automatic differentiation to
148151
compute the gradients of the values and plan. This can be changed with the
149-
`grad` parameter. The `implicit` mode computes the implicit gradients only
152+
`grad` parameter. The `envelope` mode computes the gradients only
150153
for the value and the other outputs are detached. This is useful for
151154
memory saving when only the gradient of value is needed.
152155
@@ -311,9 +314,22 @@ def solve(M, a=None, b=None, reg=None, reg_type="KL", unbalanced=None,
311314

312315
if unbalanced is None: # Balanced regularized OT
313316

314-
if reg_type.lower() in ['entropy', 'kl']:
317+
if isinstance(reg_type, tuple): # general solver
318+
319+
if max_iter is None:
320+
max_iter = 1000
321+
if tol is None:
322+
tol = 1e-9
323+
324+
plan, log = cg(a, b, M, reg=reg, f=reg_type[0], df=reg_type[1], numItermax=max_iter, stopThr=tol, log=True, verbose=verbose, G0=plan_init)
325+
326+
value_linear = nx.sum(M * plan)
327+
value = log['loss'][-1]
328+
potentials = (log['u'], log['v'])
329+
330+
elif reg_type.lower() in ['entropy', 'kl']:
315331

316-
if grad == 'implicit': # if implicit then detach the input
332+
if grad == 'envelope': # if envelope then detach the input
317333
M0, a0, b0 = M, a, b
318334
M, a, b = nx.detach(M, a, b)
319335

@@ -336,7 +352,7 @@ def solve(M, a=None, b=None, reg=None, reg_type="KL", unbalanced=None,
336352

337353
potentials = (log['log_u'], log['log_v'])
338354

339-
if grad == 'implicit': # set the gradient at convergence
355+
if grad == 'envelope': # set the gradient at convergence
340356

341357
value = nx.set_gradients(value, (M0, a0, b0),
342358
(plan, reg * (potentials[0] - potentials[0].mean()), reg * (potentials[1] - potentials[1].mean())))
@@ -359,7 +375,7 @@ def solve(M, a=None, b=None, reg=None, reg_type="KL", unbalanced=None,
359375

360376
else: # unbalanced AND regularized OT
361377

362-
if reg_type.lower() in ['kl'] and unbalanced_type.lower() == 'kl':
378+
if not isinstance(reg_type, tuple) and reg_type.lower() in ['kl'] and unbalanced_type.lower() == 'kl':
363379

364380
if max_iter is None:
365381
max_iter = 1000
@@ -374,14 +390,16 @@ def solve(M, a=None, b=None, reg=None, reg_type="KL", unbalanced=None,
374390

375391
potentials = (log['logu'], log['logv'])
376392

377-
elif reg_type.lower() in ['kl', 'l2', 'entropy'] and unbalanced_type.lower() in ['kl', 'l2']:
393+
elif (isinstance(reg_type, tuple) or reg_type.lower() in ['kl', 'l2', 'entropy']) and unbalanced_type.lower() in ['kl', 'l2', 'tv']:
378394

379395
if max_iter is None:
380396
max_iter = 1000
381397
if tol is None:
382398
tol = 1e-12
399+
if isinstance(reg_type, str):
400+
reg_type = reg_type.lower()
383401

384-
plan, log = lbfgsb_unbalanced(a, b, M, reg=reg, reg_m=unbalanced, reg_div=reg_type.lower(), regm_div=unbalanced_type.lower(), numItermax=max_iter, stopThr=tol, verbose=verbose, log=True)
402+
plan, log = lbfgsb_unbalanced(a, b, M, reg=reg, reg_m=unbalanced, reg_div=reg_type, regm_div=unbalanced_type.lower(), numItermax=max_iter, stopThr=tol, verbose=verbose, log=True, G0=plan_init)
385403

386404
value_linear = nx.sum(M * plan)
387405

@@ -962,10 +980,10 @@ def solve_sample(X_a, X_b, a=None, b=None, metric='sqeuclidean', reg=None, reg_t
962980
verbose : bool, optional
963981
Print information in the solver, by default False
964982
grad : str, optional
965-
Type of gradient computation, either or 'autodiff' or 'implicit' used only for
983+
Type of gradient computation, either or 'autodiff' or 'envelope' used only for
966984
Sinkhorn solver. By default 'autodiff' provides gradients wrt all
967985
outputs (`plan, value, value_linear`) but with important memory cost.
968-
'implicit' provides gradients only for `value` and and other outputs are
986+
'envelope' provides gradients only for `value` and and other outputs are
969987
detached. This is useful for memory saving when only the value is needed.
970988
971989
Returns
@@ -1034,13 +1052,13 @@ def solve_sample(X_a, X_b, a=None, b=None, metric='sqeuclidean', reg=None, reg_t
10341052
# lazy OT plan
10351053
lazy_plan = res.lazy_plan
10361054
1037-
# Use implicit differentiation for memory saving
1038-
res = ot.solve_sample(xa, xb, a, b, reg=1.0, grad='implicit')
1055+
# Use envelope theorem differentiation for memory saving
1056+
res = ot.solve_sample(xa, xb, a, b, reg=1.0, grad='envelope')
10391057
res.value.backward() # only the value is differentiable
10401058
10411059
Note that by default the Sinkhorn solver uses automatic differentiation to
10421060
compute the gradients of the values and plan. This can be changed with the
1043-
`grad` parameter. The `implicit` mode computes the implicit gradients only
1061+
`grad` parameter. The `envelope` mode computes the gradients only
10441062
for the value and the other outputs are detached. This is useful for
10451063
memory saving when only the gradient of value is needed.
10461064

0 commit comments

Comments
 (0)