Skip to content

[MRG] Implement the continuous entropic mapping #613

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 7 commits into from
Mar 29, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -355,3 +355,5 @@ distances between Gaussian distributions](https://hal.science/hal-03197398v2/fil
[64] Ma, X., Chu, X., Wang, Y., Lin, Y., Zhao, J., Ma, L., & Zhu, W. (2023). [Fused Gromov-Wasserstein Graph Mixup for Graph-level Classifications](https://openreview.net/pdf?id=uqkUguNu40). In Thirty-seventh Conference on Neural Information Processing Systems.

[65] Scetbon, M., Cuturi, M., & Peyré, G. (2021). [Low-Rank Sinkhorn Factorization](https://arxiv.org/pdf/2103.04737.pdf).

[66] Pooladian, Aram-Alexandre, and Jonathan Niles-Weed. [Entropic estimation of optimal transport maps](https://arxiv.org/pdf/2109.12004.pdf). arXiv preprint arXiv:2109.12004 (2021).
2 changes: 1 addition & 1 deletion benchmarks/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,4 +2,4 @@
from . import sinkhorn_knopp
from . import emd

__all__= ["benchmark", "sinkhorn_knopp", "emd"]
__all__ = ["benchmark", "sinkhorn_knopp", "emd"]
2 changes: 1 addition & 1 deletion benchmarks/emd.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ def setup(n_samples):
warmup_runs=warmup_runs
)
print(convert_to_html_table(
results,
results,
param_name="Sample size",
main_title=f"EMD - Averaged on {n_runs} runs"
))
2 changes: 1 addition & 1 deletion benchmarks/sinkhorn_knopp.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ def setup(n_samples):
warmup_runs=warmup_runs
)
print(convert_to_html_table(
results,
results,
param_name="Sample size",
main_title=f"Sinkhorn Knopp - Averaged on {n_runs} runs"
))
63 changes: 31 additions & 32 deletions docs/nb_run_conv
Original file line number Diff line number Diff line change
Expand Up @@ -17,22 +17,24 @@ import subprocess

import os

cache_file='cache_nbrun'
cache_file = 'cache_nbrun'

path_doc = 'source/auto_examples/'
path_nb = '../notebooks/'

path_doc='source/auto_examples/'
path_nb='../notebooks/'

def load_json(fname):
try:
f=open(fname)
nb=json.load(f)
f = open(fname)
nb = json.load(f)
f.close()
except (OSError, IOError) :
nb={}
except (OSError, IOError):
nb = {}
return nb

def save_json(fname,nb):
f=open(fname,'w')

def save_json(fname, nb):
f = open(fname, 'w')
f.write(json.dumps(nb))
f.close()

Expand All @@ -44,39 +46,36 @@ def md5(fname):
hash_md5.update(chunk)
return hash_md5.hexdigest()

def to_update(fname,cache):

def to_update(fname, cache):
if fname in cache:
if md5(path_doc+fname)==cache[fname]:
res=False
if md5(path_doc + fname) == cache[fname]:
res = False
else:
res=True
res = True
else:
res=True
res = True

return res

def update(fname,cache):


def update(fname, cache):

# jupyter nbconvert --to notebook --execute mynotebook.ipynb --output targte
subprocess.check_call(['cp',path_doc+fname,path_nb])
print(' '.join(['jupyter','nbconvert','--to','notebook','--ExecutePreprocessor.timeout=600','--execute',path_nb+fname,'--inplace']))
subprocess.check_call(['jupyter','nbconvert','--to','notebook','--ExecutePreprocessor.timeout=600','--execute',path_nb+fname,'--inplace'])
cache[fname]=md5(path_doc+fname)

subprocess.check_call(['cp', path_doc + fname, path_nb])
print(' '.join(['jupyter', 'nbconvert', '--to', 'notebook', '--ExecutePreprocessor.timeout=600', '--execute', path_nb + fname, '--inplace']))
subprocess.check_call(['jupyter', 'nbconvert', '--to', 'notebook', '--ExecutePreprocessor.timeout=600', '--execute', path_nb + fname, '--inplace'])
cache[fname] = md5(path_doc + fname)


cache=load_json(cache_file)
cache = load_json(cache_file)

lst_file=glob.glob(path_doc+'*.ipynb')
lst_file = glob.glob(path_doc + '*.ipynb')

lst_file=[os.path.basename(name) for name in lst_file]
lst_file = [os.path.basename(name) for name in lst_file]

for fname in lst_file:
if to_update(fname,cache):
if to_update(fname, cache):
print('Updating file: {}'.format(fname))
update(fname,cache)
save_json(cache_file,cache)




update(fname, cache)
save_json(cache_file, cache)
2 changes: 1 addition & 1 deletion docs/rtd/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,4 +3,4 @@
source_parsers = {'.md': CommonMarkParser}

source_suffix = ['.md']
master_doc = 'index'
master_doc = 'index'
20 changes: 8 additions & 12 deletions docs/source/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,16 +22,12 @@
print("warning sphinx-gallery not installed")






# !!!! allow readthedoc compilation
try:
from unittest.mock import MagicMock
except ImportError:
from mock import Mock as MagicMock
## check whether in the source directory...
# check whether in the source directory...
#


Expand All @@ -42,7 +38,7 @@ def __getattr__(cls, name):
return MagicMock()


MOCK_MODULES = [ 'cupy']
MOCK_MODULES = ['cupy']
# 'autograd.numpy','pymanopt.manifolds','pymanopt.solvers',
sys.modules.update((mod_name, Mock()) for mod_name in MOCK_MODULES)
# !!!!
Expand Down Expand Up @@ -357,12 +353,12 @@ def __getattr__(cls, name):
sphinx_gallery_conf = {
'examples_dirs': ['../../examples', '../../examples/da'],
'gallery_dirs': 'auto_examples',
'filename_pattern': 'plot_', #(?!barycenter_fgw)
'nested_sections' : False,
'backreferences_dir': 'gen_modules/backreferences',
'inspect_global_variables' : True,
'doc_module' : ('ot','numpy','scipy','pylab'),
'filename_pattern': 'plot_', # (?!barycenter_fgw)
'nested_sections': False,
'backreferences_dir': 'gen_modules/backreferences',
'inspect_global_variables': True,
'doc_module': ('ot', 'numpy', 'scipy', 'pylab'),
'matplotlib_animations': True,
'reference_url': {
'ot': None}
'ot': None}
}
2 changes: 1 addition & 1 deletion ot/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@
'sinkhorn_unbalanced2', 'sliced_wasserstein_distance', 'sliced_wasserstein_sphere',
'gromov_wasserstein', 'gromov_wasserstein2', 'gromov_barycenters', 'fused_gromov_wasserstein',
'fused_gromov_wasserstein2', 'max_sliced_wasserstein_distance', 'weak_optimal_transport',
'factored_optimal_transport', 'solve', 'solve_gromov','solve_sample',
'factored_optimal_transport', 'solve', 'solve_gromov', 'solve_sample',
'smooth', 'stochastic', 'unbalanced', 'partial', 'regpath', 'solvers',
'binary_search_circle', 'wasserstein_circle',
'semidiscrete_wasserstein2_unif_circle', 'sliced_wasserstein_sphere_unif',
Expand Down
150 changes: 144 additions & 6 deletions ot/da.py
Original file line number Diff line number Diff line change
Expand Up @@ -493,7 +493,7 @@ class label

# pairwise distance
self.cost_ = dist(Xs, Xt, metric=self.metric)
self.cost_ = cost_normalization(self.cost_, self.norm)
self.cost_, self.norm_cost_ = cost_normalization(self.cost_, self.norm, return_value=True)

if (ys is not None) and (yt is not None):

Expand Down Expand Up @@ -1055,13 +1055,18 @@ class SinkhornTransport(BaseTransport):
The ground metric for the Wasserstein problem
norm : string, optional (default=None)
If given, normalize the ground metric to avoid numerical errors that
can occur with large metric values.
can occur with large metric values. Accepted values are 'median',
'max', 'log' and 'loglog'.
distribution_estimation : callable, optional (defaults to the uniform)
The kind of distribution estimation to employ
out_of_sample_map : string, optional (default="ferradans")
out_of_sample_map : string, optional (default="continuous")
The kind of out of sample mapping to apply to transport samples
from a domain into another one. Currently the only possible option is
"ferradans" which uses the method proposed in :ref:`[6] <references-sinkhorntransport>`.
"ferradans" which uses the nearest neighbor method proposed in :ref:`[6]
<references-sinkhorntransport>` while "continuous" use the out of sample
method from :ref:`[66]
<references-sinkhorntransport>` and :ref:`[19]
<references-sinkhorntransport>`.
limit_max: float, optional (default=np.infty)
Controls the semi supervised mode. Transport between labeled source
and target samples of different classes will exhibit an cost defined
Expand Down Expand Up @@ -1089,13 +1094,26 @@ class SinkhornTransport(BaseTransport):
.. [6] Ferradans, S., Papadakis, N., Peyré, G., & Aujol, J. F. (2014).
Regularized discrete optimal transport. SIAM Journal on Imaging
Sciences, 7(3), 1853-1882.

.. [19] Seguy, V., Bhushan Damodaran, B., Flamary, R., Courty, N., Rolet, A.
& Blondel, M. Large-scale Optimal Transport and Mapping Estimation.
International Conference on Learning Representation (2018)

.. [66] Pooladian, Aram-Alexandre, and Jonathan Niles-Weed. "Entropic
estimation of optimal transport maps." arXiv preprint
arXiv:2109.12004 (2021).

"""

def __init__(self, reg_e=1., method="sinkhorn", max_iter=1000,
def __init__(self, reg_e=1., method="sinkhorn_log", max_iter=1000,
tol=10e-9, verbose=False, log=False,
metric="sqeuclidean", norm=None,
distribution_estimation=distribution_estimation_uniform,
out_of_sample_map='ferradans', limit_max=np.infty):
out_of_sample_map='continuous', limit_max=np.infty):

if out_of_sample_map not in ['ferradans', 'continuous']:
raise ValueError('Unknown out_of_sample_map method')

self.reg_e = reg_e
self.method = method
self.max_iter = max_iter
Expand Down Expand Up @@ -1135,6 +1153,12 @@ class label

super(SinkhornTransport, self).fit(Xs, ys, Xt, yt)

if self.out_of_sample_map == 'continuous':
self.log = True
if not self.method == 'sinkhorn_log':
self.method = 'sinkhorn_log'
warnings.warn("The method has been set to 'sinkhorn_log' as it is the only method available for out_of_sample_map='continuous'")

# coupling estimation
returned_ = sinkhorn(
a=self.mu_s, b=self.mu_t, M=self.cost_, reg=self.reg_e,
Expand All @@ -1150,6 +1174,120 @@ class label

return self

def transform(self, Xs=None, ys=None, Xt=None, yt=None, batch_size=128):
    r"""Transports source samples :math:`\mathbf{X_s}` onto target ones :math:`\mathbf{X_t}`

    When ``out_of_sample_map='ferradans'`` the mapping is delegated to the
    nearest-neighbor scheme of the parent class.  When it is
    ``'continuous'``, the entropic (barycentric) map of
    :ref:`[66] <references-sinkhorntransport>` is evaluated batch by batch
    from the dual potential estimated at fit time.

    Parameters
    ----------
    Xs : array-like, shape (n_source_samples, n_features)
        The source input samples.
    ys : array-like, shape (n_source_samples,)
        The class labels for source samples
    Xt : array-like, shape (n_target_samples, n_features)
        The target input samples.
    yt : array-like, shape (n_target_samples,)
        The class labels for target. If some target samples are unlabelled, fill the
        :math:`\mathbf{y_t}`'s elements with -1.

        Warning: Note that, due to this convention -1 cannot be used as a
        class label
    batch_size : int, optional (default=128)
        The batch size for out of sample inverse transform

    Returns
    -------
    transp_Xs : array-like, shape (n_source_samples, n_features)
        The transport source samples.
    """
    if self.out_of_sample_map == 'ferradans':
        return super(SinkhornTransport, self).transform(Xs, ys, Xt, yt, batch_size)

    # out_of_sample_map == 'continuous': entropic barycentric mapping
    nx = self.nx

    # dual potential associated with the target marginal (stored by fit)
    g = self.log_['log_v']

    n_samples = Xs.shape[0]
    mapped_batches = []
    for start in range(0, n_samples, batch_size):
        batch = Xs[start:start + batch_size]

        # cost from the new source points to the fitted target samples,
        # normalized consistently with the cost used during fit
        M = dist(batch, self.xt_, metric=self.metric)
        M = cost_normalization(M, self.norm, value=self.norm_cost_)

        # Gibbs kernel weighted by the target potential
        K = nx.exp(-M / self.reg_e + g[None, :])

        # barycentric projection of each point onto the target samples
        mapped_batches.append(nx.dot(K, self.xt_) / nx.sum(K, axis=1)[:, None])

    return nx.concatenate(mapped_batches, axis=0)

def inverse_transform(self, Xs=None, ys=None, Xt=None, yt=None, batch_size=128):
    r"""Transports target samples :math:`\mathbf{X_t}` onto source samples :math:`\mathbf{X_s}`

    When ``out_of_sample_map='ferradans'`` the mapping is delegated to the
    nearest-neighbor scheme of the parent class.  When it is
    ``'continuous'``, the entropic (barycentric) map of
    :ref:`[66] <references-sinkhorntransport>` is evaluated batch by batch,
    this time using the dual potential of the source marginal.

    Parameters
    ----------
    Xs : array-like, shape (n_source_samples, n_features)
        The source input samples.
    ys : array-like, shape (n_source_samples,)
        The class labels for source samples
    Xt : array-like, shape (n_target_samples, n_features)
        The target input samples.
    yt : array-like, shape (n_target_samples,)
        The class labels for target. If some target samples are unlabelled, fill the
        :math:`\mathbf{y_t}`'s elements with -1.

        Warning: Note that, due to this convention -1 cannot be used as a
        class label
    batch_size : int, optional (default=128)
        The batch size for out of sample inverse transform

    Returns
    -------
    transp_Xt : array-like, shape (n_target_samples, n_features)
        The transported target samples.
    """
    if self.out_of_sample_map == 'ferradans':
        return super(SinkhornTransport, self).inverse_transform(Xs, ys, Xt, yt, batch_size)

    # out_of_sample_map == 'continuous': entropic barycentric mapping
    nx = self.nx

    # dual potential associated with the source marginal (stored by fit)
    f = self.log_['log_u']

    n_samples = Xt.shape[0]
    mapped_batches = []
    for start in range(0, n_samples, batch_size):
        batch = Xt[start:start + batch_size]

        # cost from the new target points to the fitted source samples,
        # normalized consistently with the cost used during fit
        M = dist(batch, self.xs_, metric=self.metric)
        M = cost_normalization(M, self.norm, value=self.norm_cost_)

        # Gibbs kernel weighted by the source potential
        K = nx.exp(-M / self.reg_e + f[None, :])

        # barycentric projection of each point onto the source samples
        mapped_batches.append(nx.dot(K, self.xs_) / nx.sum(K, axis=1)[:, None])

    return nx.concatenate(mapped_batches, axis=0)


class EMDTransport(BaseTransport):

Expand Down
6 changes: 3 additions & 3 deletions ot/gnn/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,8 @@
# All submodules and packages


from ._utils import (FGW_distance_to_templates,wasserstein_distance_to_templates)
from ._utils import (FGW_distance_to_templates, wasserstein_distance_to_templates)

from ._layers import (TFGWPooling,TWPooling)
from ._layers import (TFGWPooling, TWPooling)

__all__ = [ 'FGW_distance_to_templates', 'wasserstein_distance_to_templates','TFGWPooling','TWPooling']
__all__ = ['FGW_distance_to_templates', 'wasserstein_distance_to_templates', 'TFGWPooling', 'TWPooling']
Loading