
Commit ab12dd6

[MRG] Implement the continuous entropic mapping (#613)
* implement pooladian mapping
* update stuff
* use correct normalization for entropic mapping
* implement batches
* improve coverage
* comments from cedric + pep8
* move it back
1 parent 63e44e5 commit ab12dd6

File tree

14 files changed: +235 −73 lines changed

README.md

Lines changed: 2 additions & 0 deletions

@@ -355,3 +355,5 @@ distances between Gaussian distributions](https://hal.science/hal-03197398v2/fil
 [64] Ma, X., Chu, X., Wang, Y., Lin, Y., Zhao, J., Ma, L., & Zhu, W. (2023). [Fused Gromov-Wasserstein Graph Mixup for Graph-level Classifications](https://openreview.net/pdf?id=uqkUguNu40). In Thirty-seventh Conference on Neural Information Processing Systems.
 
 [65] Scetbon, M., Cuturi, M., & Peyré, G. (2021). [Low-Rank Sinkhorn Factorization](https://arxiv.org/pdf/2103.04737.pdf).
+
+[66] Pooladian, Aram-Alexandre, and Jonathan Niles-Weed. [Entropic estimation of optimal transport maps](https://arxiv.org/pdf/2109.12004.pdf). arXiv preprint arXiv:2109.12004 (2021).

benchmarks/__init__.py

Lines changed: 1 addition & 1 deletion

@@ -2,4 +2,4 @@
 from . import sinkhorn_knopp
 from . import emd
 
-__all__= ["benchmark", "sinkhorn_knopp", "emd"]
+__all__ = ["benchmark", "sinkhorn_knopp", "emd"]

benchmarks/emd.py

Lines changed: 1 addition & 1 deletion

@@ -34,7 +34,7 @@ def setup(n_samples):
         warmup_runs=warmup_runs
     )
     print(convert_to_html_table(
-        results,
+        results,
         param_name="Sample size",
         main_title=f"EMD - Averaged on {n_runs} runs"
     ))

benchmarks/sinkhorn_knopp.py

Lines changed: 1 addition & 1 deletion

@@ -36,7 +36,7 @@ def setup(n_samples):
         warmup_runs=warmup_runs
     )
     print(convert_to_html_table(
-        results,
+        results,
         param_name="Sample size",
         main_title=f"Sinkhorn Knopp - Averaged on {n_runs} runs"
     ))

docs/nb_run_conv

Lines changed: 31 additions & 32 deletions

@@ -17,22 +17,24 @@ import subprocess
 
 import os
 
-cache_file='cache_nbrun'
+cache_file = 'cache_nbrun'
+
+path_doc = 'source/auto_examples/'
+path_nb = '../notebooks/'
 
-path_doc='source/auto_examples/'
-path_nb='../notebooks/'
 
 def load_json(fname):
     try:
-        f=open(fname)
-        nb=json.load(f)
+        f = open(fname)
+        nb = json.load(f)
         f.close()
-    except (OSError, IOError) :
-        nb={}
+    except (OSError, IOError):
+        nb = {}
     return nb
 
-def save_json(fname,nb):
-    f=open(fname,'w')
+
+def save_json(fname, nb):
+    f = open(fname, 'w')
     f.write(json.dumps(nb))
     f.close()
 
@@ -44,39 +46,36 @@ def md5(fname):
             hash_md5.update(chunk)
     return hash_md5.hexdigest()
 
-def to_update(fname,cache):
+
+def to_update(fname, cache):
     if fname in cache:
-        if md5(path_doc+fname)==cache[fname]:
-            res=False
+        if md5(path_doc + fname) == cache[fname]:
+            res = False
         else:
-            res=True
+            res = True
     else:
-        res=True
-
+        res = True
+
     return res
 
-def update(fname,cache):
-
+
+def update(fname, cache):
+
     # jupyter nbconvert --to notebook --execute mynotebook.ipynb --output targte
-    subprocess.check_call(['cp',path_doc+fname,path_nb])
-    print(' '.join(['jupyter','nbconvert','--to','notebook','--ExecutePreprocessor.timeout=600','--execute',path_nb+fname,'--inplace']))
-    subprocess.check_call(['jupyter','nbconvert','--to','notebook','--ExecutePreprocessor.timeout=600','--execute',path_nb+fname,'--inplace'])
-    cache[fname]=md5(path_doc+fname)
-
+    subprocess.check_call(['cp', path_doc + fname, path_nb])
+    print(' '.join(['jupyter', 'nbconvert', '--to', 'notebook', '--ExecutePreprocessor.timeout=600', '--execute', path_nb + fname, '--inplace']))
+    subprocess.check_call(['jupyter', 'nbconvert', '--to', 'notebook', '--ExecutePreprocessor.timeout=600', '--execute', path_nb + fname, '--inplace'])
+    cache[fname] = md5(path_doc + fname)
 
 
-cache=load_json(cache_file)
+cache = load_json(cache_file)
 
-lst_file=glob.glob(path_doc+'*.ipynb')
+lst_file = glob.glob(path_doc + '*.ipynb')
 
-lst_file=[os.path.basename(name) for name in lst_file]
+lst_file = [os.path.basename(name) for name in lst_file]
 
 for fname in lst_file:
-    if to_update(fname,cache):
+    if to_update(fname, cache):
         print('Updating file: {}'.format(fname))
-        update(fname,cache)
-        save_json(cache_file,cache)
-
-
-
-
+        update(fname, cache)
+        save_json(cache_file, cache)

docs/rtd/conf.py

Lines changed: 1 addition & 1 deletion

@@ -3,4 +3,4 @@
 source_parsers = {'.md': CommonMarkParser}
 
 source_suffix = ['.md']
-master_doc = 'index'
+master_doc = 'index'

docs/source/conf.py

Lines changed: 8 additions & 12 deletions

@@ -22,16 +22,12 @@
     print("warning sphinx-gallery not installed")
 
 
-
-
-
-
 # !!!! allow readthedoc compilation
 try:
     from unittest.mock import MagicMock
 except ImportError:
     from mock import Mock as MagicMock
-## check whether in the source directory...
+# check whether in the source directory...
 #
 
 
@@ -42,7 +38,7 @@ def __getattr__(cls, name):
         return MagicMock()
 
 
-MOCK_MODULES = [ 'cupy']
+MOCK_MODULES = ['cupy']
 # 'autograd.numpy','pymanopt.manifolds','pymanopt.solvers',
 sys.modules.update((mod_name, Mock()) for mod_name in MOCK_MODULES)
 # !!!!
@@ -357,12 +353,12 @@ def __getattr__(cls, name):
 sphinx_gallery_conf = {
     'examples_dirs': ['../../examples', '../../examples/da'],
     'gallery_dirs': 'auto_examples',
-    'filename_pattern': 'plot_', #(?!barycenter_fgw)
-    'nested_sections' : False,
-    'backreferences_dir': 'gen_modules/backreferences',
-    'inspect_global_variables' : True,
-    'doc_module' : ('ot','numpy','scipy','pylab'),
+    'filename_pattern': 'plot_',  # (?!barycenter_fgw)
+    'nested_sections': False,
+    'backreferences_dir': 'gen_modules/backreferences',
+    'inspect_global_variables': True,
+    'doc_module': ('ot', 'numpy', 'scipy', 'pylab'),
     'matplotlib_animations': True,
     'reference_url': {
-        'ot': None}
+        'ot': None}
 }

ot/__init__.py

Lines changed: 1 addition & 1 deletion

@@ -68,7 +68,7 @@
            'sinkhorn_unbalanced2', 'sliced_wasserstein_distance', 'sliced_wasserstein_sphere',
            'gromov_wasserstein', 'gromov_wasserstein2', 'gromov_barycenters', 'fused_gromov_wasserstein',
            'fused_gromov_wasserstein2', 'max_sliced_wasserstein_distance', 'weak_optimal_transport',
-           'factored_optimal_transport', 'solve', 'solve_gromov','solve_sample',
+           'factored_optimal_transport', 'solve', 'solve_gromov', 'solve_sample',
            'smooth', 'stochastic', 'unbalanced', 'partial', 'regpath', 'solvers',
            'binary_search_circle', 'wasserstein_circle',
            'semidiscrete_wasserstein2_unif_circle', 'sliced_wasserstein_sphere_unif',

ot/da.py

Lines changed: 144 additions & 6 deletions

@@ -493,7 +493,7 @@ class label
 
         # pairwise distance
         self.cost_ = dist(Xs, Xt, metric=self.metric)
-        self.cost_ = cost_normalization(self.cost_, self.norm)
+        self.cost_, self.norm_cost_ = cost_normalization(self.cost_, self.norm, return_value=True)
 
         if (ys is not None) and (yt is not None):
 
@@ -1055,13 +1055,18 @@ class SinkhornTransport(BaseTransport):
         The ground metric for the Wasserstein problem
     norm : string, optional (default=None)
         If given, normalize the ground metric to avoid numerical errors that
-        can occur with large metric values.
+        can occur with large metric values. Accepted values are 'median',
+        'max', 'log' and 'loglog'.
     distribution_estimation : callable, optional (defaults to the uniform)
         The kind of distribution estimation to employ
-    out_of_sample_map : string, optional (default="ferradans")
+    out_of_sample_map : string, optional (default="continuous")
         The kind of out of sample mapping to apply to transport samples
         from a domain into another one. Currently the only possible option is
-        "ferradans" which uses the method proposed in :ref:`[6] <references-sinkhorntransport>`.
+        "ferradans" which uses the nearest neighbor method proposed in :ref:`[6]
+        <references-sinkhorntransport>` while "continuous" use the out of sample
+        method from :ref:`[66]
+        <references-sinkhorntransport>` and :ref:`[19]
+        <references-sinkhorntransport>`.
     limit_max: float, optional (default=np.infty)
         Controls the semi supervised mode. Transport between labeled source
        and target samples of different classes will exhibit an cost defined
@@ -1089,13 +1094,26 @@ class SinkhornTransport(BaseTransport):
    .. [6] Ferradans, S., Papadakis, N., Peyré, G., & Aujol, J. F. (2014).
        Regularized discrete optimal transport. SIAM Journal on Imaging
        Sciences, 7(3), 1853-1882.
+
+    .. [19] Seguy, V., Bhushan Damodaran, B., Flamary, R., Courty, N., Rolet, A.
+       & Blondel, M. Large-scale Optimal Transport and Mapping Estimation.
+       International Conference on Learning Representation (2018)
+
+    .. [66] Pooladian, Aram-Alexandre, and Jonathan Niles-Weed. "Entropic
+       estimation of optimal transport maps." arXiv preprint
+       arXiv:2109.12004 (2021).
+
    """
 
-    def __init__(self, reg_e=1., method="sinkhorn", max_iter=1000,
+    def __init__(self, reg_e=1., method="sinkhorn_log", max_iter=1000,
                  tol=10e-9, verbose=False, log=False,
                  metric="sqeuclidean", norm=None,
                  distribution_estimation=distribution_estimation_uniform,
-                 out_of_sample_map='ferradans', limit_max=np.infty):
+                 out_of_sample_map='continuous', limit_max=np.infty):
+
+        if out_of_sample_map not in ['ferradans', 'continuous']:
+            raise ValueError('Unknown out_of_sample_map method')
+
         self.reg_e = reg_e
         self.method = method
         self.max_iter = max_iter
@@ -1135,6 +1153,12 @@ class label
 
         super(SinkhornTransport, self).fit(Xs, ys, Xt, yt)
 
+        if self.out_of_sample_map == 'continuous':
+            self.log = True
+            if not self.method == 'sinkhorn_log':
+                self.method = 'sinkhorn_log'
+                warnings.warn("The method has been set to 'sinkhorn_log' as it is the only method available for out_of_sample_map='continuous'")
+
         # coupling estimation
         returned_ = sinkhorn(
             a=self.mu_s, b=self.mu_t, M=self.cost_, reg=self.reg_e,
@@ -1150,6 +1174,120 @@ class label
 
         return self
 
+    def transform(self, Xs=None, ys=None, Xt=None, yt=None, batch_size=128):
+        r"""Transports source samples :math:`\mathbf{X_s}` onto target ones :math:`\mathbf{X_t}`
+
+        Parameters
+        ----------
+        Xs : array-like, shape (n_source_samples, n_features)
+            The source input samples.
+        ys : array-like, shape (n_source_samples,)
+            The class labels for source samples
+        Xt : array-like, shape (n_target_samples, n_features)
+            The target input samples.
+        yt : array-like, shape (n_target_samples,)
+            The class labels for target. If some target samples are unlabelled, fill the
+            :math:`\mathbf{y_t}`'s elements with -1.
+
+            Warning: Note that, due to this convention -1 cannot be used as a
+            class label
+        batch_size : int, optional (default=128)
+            The batch size for out of sample inverse transform
+
+        Returns
+        -------
+        transp_Xs : array-like, shape (n_source_samples, n_features)
+            The transport source samples.
+        """
+        nx = self.nx
+
+        if self.out_of_sample_map == 'ferradans':
+            return super(SinkhornTransport, self).transform(Xs, ys, Xt, yt, batch_size)
+
+        else:  # self.out_of_sample_map == 'continuous':
+
+            # check the necessary inputs parameters are here
+            g = self.log_['log_v']
+
+            indices = nx.arange(Xs.shape[0])
+            batch_ind = [
+                indices[i:i + batch_size]
+                for i in range(0, len(indices), batch_size)]
+
+            transp_Xs = []
+            for bi in batch_ind:
+                # get the nearest neighbor in the source domain
+                M = dist(Xs[bi], self.xt_, metric=self.metric)
+
+                M = cost_normalization(M, self.norm, value=self.norm_cost_)
+
+                K = nx.exp(-M / self.reg_e + g[None, :])
+
+                transp_Xs_ = nx.dot(K, self.xt_) / nx.sum(K, axis=1)[:, None]
+
+                transp_Xs.append(transp_Xs_)
+
+            transp_Xs = nx.concatenate(transp_Xs, axis=0)
+
+            return transp_Xs
+
+    def inverse_transform(self, Xs=None, ys=None, Xt=None, yt=None, batch_size=128):
+        r"""Transports target samples :math:`\mathbf{X_t}` onto source samples :math:`\mathbf{X_s}`
+
+        Parameters
+        ----------
+        Xs : array-like, shape (n_source_samples, n_features)
+            The source input samples.
+        ys : array-like, shape (n_source_samples,)
+            The class labels for source samples
+        Xt : array-like, shape (n_target_samples, n_features)
+            The target input samples.
+        yt : array-like, shape (n_target_samples,)
+            The class labels for target. If some target samples are unlabelled, fill the
+            :math:`\mathbf{y_t}`'s elements with -1.
+
+            Warning: Note that, due to this convention -1 cannot be used as a
+            class label
+        batch_size : int, optional (default=128)
+            The batch size for out of sample inverse transform
+
+        Returns
+        -------
+        transp_Xt : array-like, shape (n_source_samples, n_features)
+            The transport target samples.
+        """
+
+        nx = self.nx
+
+        if self.out_of_sample_map == 'ferradans':
+            return super(SinkhornTransport, self).inverse_transform(Xs, ys, Xt, yt, batch_size)
+
+        else:  # self.out_of_sample_map == 'continuous':
+
+            f = self.log_['log_u']
+
+            indices = nx.arange(Xt.shape[0])
+            batch_ind = [
+                indices[i:i + batch_size]
+                for i in range(0, len(indices), batch_size)]
+
+            transp_Xt = []
+            for bi in batch_ind:
+
+                M = dist(Xt[bi], self.xs_, metric=self.metric)
+                M = cost_normalization(M, self.norm, value=self.norm_cost_)
+
+                K = nx.exp(-M / self.reg_e + f[None, :])
+
+                transp_Xt_ = nx.dot(K, self.xs_) / nx.sum(K, axis=1)[:, None]
+
+                transp_Xt.append(transp_Xt_)
+
+            transp_Xt = nx.concatenate(transp_Xt, axis=0)
+
+            return transp_Xt
+
 
 class EMDTransport(BaseTransport):
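
The new out_of_sample_map='continuous' option implements the entropic map estimator of Pooladian and Niles-Weed [66]: a new source point x is sent to a softmax-weighted average of the target samples, with weights exp(g_j - M(x, y_j) / reg) built from the target log-potential g returned by the log-domain Sinkhorn solver. The snippet below is a minimal usage sketch, not part of the commit; the sample data, sizes and the manual cross-check are illustrative assumptions, and the equivalence relies on the default settings of the class (norm=None, squared Euclidean metric).

# Minimal sketch (assumed example data): fit SinkhornTransport with the new
# out_of_sample_map='continuous' and map points not seen during fit.
import numpy as np
import ot

rng = np.random.RandomState(42)
Xs = rng.randn(50, 2)          # source samples (illustrative)
Xt = rng.randn(60, 2) + 3.0    # target samples (illustrative)
reg_e = 0.1

mapping = ot.da.SinkhornTransport(reg_e=reg_e, out_of_sample_map='continuous')
mapping.fit(Xs=Xs, Xt=Xt)      # forces method='sinkhorn_log' and log=True (see fit above)

Xs_new = rng.randn(5, 2)       # out-of-sample source points
Xs_mapped = mapping.transform(Xs=Xs_new)

# Manual cross-check of the entropic map used in transform():
#   T(x) = sum_j exp(g_j - M(x, y_j) / reg) y_j / sum_j exp(g_j - M(x, y_j) / reg)
g = mapping.log_['log_v']                        # target log-potential from sinkhorn_log
M = ot.dist(Xs_new, Xt, metric='sqeuclidean')    # default metric; norm=None so no rescaling
K = np.exp(-M / reg_e + g[None, :])
T_manual = K.dot(Xt) / K.sum(axis=1, keepdims=True)
print(np.allclose(Xs_mapped, T_manual))          # expected True under these defaults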

ot/gnn/__init__.py

Lines changed: 3 additions & 3 deletions

@@ -17,8 +17,8 @@
 # All submodules and packages
 
 
-from ._utils import (FGW_distance_to_templates,wasserstein_distance_to_templates)
+from ._utils import (FGW_distance_to_templates, wasserstein_distance_to_templates)
 
-from ._layers import (TFGWPooling,TWPooling)
+from ._layers import (TFGWPooling, TWPooling)
 
-__all__ = [ 'FGW_distance_to_templates', 'wasserstein_distance_to_templates','TFGWPooling','TWPooling']
+__all__ = ['FGW_distance_to_templates', 'wasserstein_distance_to_templates', 'TFGWPooling', 'TWPooling']
