Skip to content

Commit ef6c3c1

Browse files
rflamarylaudavid
andauthored
[MRG] New API ot.solve_sample (#563)
* new file for lr sinkhorn * lr sinkhorn, solve_sample, OTResultLazy * add test functions + small modif lr_sin/solve_sample * add import to __init__ * modify low rank, remove solve_sample,OTResultLazy * solve_sample + test functions * remove low rank from branch * new file for lr sinkhorn * lr sinkhorn, solve_sample, OTResultLazy * add test functions + small modif lr_sin/solve_sample * add import to __init__ * clean ot.solve_sample and remve lazy test cause not ilplemented yet * add factored and gaussian solvers * workin lazy sinkhorn with lazy tensor returned * stuff * merge master * update documùentation * beter documentation * pep8 * big update tests * debug small test * remarques cédri * small stuff --------- Co-authored-by: laudavid <laurene.gdavid@gmail.com>
1 parent 6f4a40d commit ef6c3c1

File tree

9 files changed

+642
-28
lines changed

9 files changed

+642
-28
lines changed

RELEASES.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
+ Add exact line-search for `gromov_wasserstein` and `fused_gromov_wasserstein` with KL loss (PR #556)
1717
+ Add KL loss to all semi-relaxed (Fused) Gromov-Wasserstein solvers (PR #559)
1818
+ Further upgraded unbalanced OT solvers for more flexibility and future use (PR #551)
19+
+ New API function `ot.solve_sample` for solving OT problems from empirical samples (PR #563)
1920

2021
#### Closed issues
2122
- Fix line search evaluating cost outside of the interpolation range (Issue #502, PR #504)

ot/__init__.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@
3636
from . import solvers
3737
from . import gaussian
3838

39+
3940
# OT functions
4041
from .lp import (emd, emd2, emd_1d, emd2_1d, wasserstein_1d,
4142
binary_search_circle, wasserstein_circle,
@@ -50,7 +51,7 @@
5051
gromov_barycenters, fused_gromov_wasserstein, fused_gromov_wasserstein2)
5152
from .weak import weak_optimal_transport
5253
from .factored import factored_optimal_transport
53-
from .solvers import solve, solve_gromov
54+
from .solvers import solve, solve_gromov, solve_sample
5455

5556
# utils functions
5657
from .utils import dist, unif, tic, toc, toq
@@ -65,7 +66,7 @@
6566
'sinkhorn_unbalanced2', 'sliced_wasserstein_distance', 'sliced_wasserstein_sphere',
6667
'gromov_wasserstein', 'gromov_wasserstein2', 'gromov_barycenters', 'fused_gromov_wasserstein',
6768
'fused_gromov_wasserstein2', 'max_sliced_wasserstein_distance', 'weak_optimal_transport',
68-
'factored_optimal_transport', 'solve', 'solve_gromov',
69+
'factored_optimal_transport', 'solve', 'solve_gromov','solve_sample',
6970
'smooth', 'stochastic', 'unbalanced', 'partial', 'regpath', 'solvers',
7071
'binary_search_circle', 'wasserstein_circle',
7172
'semidiscrete_wasserstein2_unif_circle', 'sliced_wasserstein_sphere_unif']

ot/bregman/_empirical.py

Lines changed: 47 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,12 +11,56 @@
1111

1212
import warnings
1313

14-
from ..utils import dist, list_to_array, unif
14+
from ..utils import dist, list_to_array, unif, LazyTensor
1515
from ..backend import get_backend
1616

1717
from ._sinkhorn import sinkhorn, sinkhorn2
1818

1919

20+
def get_sinkhorn_lazytensor(X_a, X_b, f, g, metric='sqeuclidean', reg=1e-1, nx=None):
21+
r""" Get a LazyTensor of Sinkhorn solution from the dual potentials
22+
23+
The returned LazyTensor is
24+
:math:`\mathbf{T} = exp( \mathbf{f} \mathbf{1}_b^\top + \mathbf{1}_a \mathbf{g}^\top - \mathbf{C}/reg)`, where :math:`\mathbf{C}` is the pairwise metric matrix between samples :math:`\mathbf{X}_a` and :math:`\mathbf{X}_b`.
25+
26+
Parameters
27+
----------
28+
X_a : array-like, shape (n_samples_a, dim)
29+
samples in the source domain
30+
X_b : array-like, shape (n_samples_b, dim)
31+
samples in the target domain
32+
f : array-like, shape (n_samples_a,)
33+
First dual potentials (log space)
34+
g : array-like, shape (n_samples_b,)
35+
Second dual potentials (log space)
36+
metric : str, default='sqeuclidean'
37+
Metric used for the cost matrix computation
38+
reg : float, default=1e-1
39+
Regularization term >0
40+
nx : Backend(), default=None
41+
Numerical backend used
42+
43+
44+
Returns
45+
-------
46+
T : LazyTensor
47+
Sinkhorn solution tensor
48+
"""
49+
50+
if nx is None:
51+
nx = get_backend(X_a, X_b, f, g)
52+
53+
shape = (X_a.shape[0], X_b.shape[0])
54+
55+
def func(i, j, X_a, X_b, f, g, metric, reg):
56+
C = dist(X_a[i], X_b[j], metric=metric)
57+
return nx.exp(f[i, None] + g[None, j] - C / reg)
58+
59+
T = LazyTensor(shape, func, X_a=X_a, X_b=X_b, f=f, g=g, metric=metric, reg=reg)
60+
61+
return T
62+
63+
2064
def empirical_sinkhorn(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean',
2165
numIterMax=10000, stopThr=1e-9, isLazy=False, batchSize=100, verbose=False,
2266
log=False, warn=True, warmstart=None, **kwargs):
@@ -198,6 +242,8 @@ def empirical_sinkhorn(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean',
198242
if log:
199243
dict_log["u"] = f
200244
dict_log["v"] = g
245+
dict_log["niter"] = i_ot
246+
dict_log["lazy_plan"] = get_sinkhorn_lazytensor(X_s, X_t, f, g, metric, reg)
201247
return (f, g, dict_log)
202248
else:
203249
return (f, g)

ot/gaussian.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -249,7 +249,7 @@ def bures_wasserstein_distance(ms, mt, Cs, Ct, log=False):
249249
Cs12 = nx.sqrtm(Cs)
250250

251251
B = nx.trace(Cs + Ct - 2 * nx.sqrtm(dots(Cs12, Ct, Cs12)))
252-
W = nx.sqrt(nx.norm(ms - mt)**2 + B)
252+
W = nx.sqrt(nx.maximum(nx.norm(ms - mt)**2 + B, 0))
253253

254254
if log:
255255
log = {}

0 commit comments

Comments
 (0)