From 7053e008c7c82f919c1509dd92b10c0d1d87379f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?R=C3=A9mi=20Flamary?= Date: Fri, 8 Mar 2024 14:00:27 +0100 Subject: [PATCH 1/7] implemente pooladian mappng --- README.md | 2 + ot/da.py | 114 ++++++++++++++++++++++++++++++++++++++++++++++-- test/test_da.py | 17 ++++++++ 3 files changed, 129 insertions(+), 4 deletions(-) diff --git a/README.md b/README.md index 5fa4a6faa..88dce689a 100644 --- a/README.md +++ b/README.md @@ -355,3 +355,5 @@ distances between Gaussian distributions](https://hal.science/hal-03197398v2/fil [64] Ma, X., Chu, X., Wang, Y., Lin, Y., Zhao, J., Ma, L., & Zhu, W. (2023). [Fused Gromov-Wasserstein Graph Mixup for Graph-level Classifications](https://openreview.net/pdf?id=uqkUguNu40). In Thirty-seventh Conference on Neural Information Processing Systems. [65] Scetbon, M., Cuturi, M., & Peyré, G. (2021). [Low-Rank Sinkhorn Factorization](https://arxiv.org/pdf/2103.04737.pdf). + +[66] Pooladian, Aram-Alexandre, and Jonathan Niles-Weed. [Entropic estimation of optimal transport maps](https://arxiv.org/pdf/2109.12004.pdf). arXiv preprint arXiv:2109.12004 (2021). diff --git a/ot/da.py b/ot/da.py index 4f3d3bb96..a35a00371 100644 --- a/ot/da.py +++ b/ot/da.py @@ -1058,10 +1058,13 @@ class SinkhornTransport(BaseTransport): can occur with large metric values. distribution_estimation : callable, optional (defaults to the uniform) The kind of distribution estimation to employ - out_of_sample_map : string, optional (default="ferradans") + out_of_sample_map : string, optional (default="pooladian") The kind of out of sample mapping to apply to transport samples from a domain into another one. Currently the only possible option is - "ferradans" which uses the method proposed in :ref:`[6] `. + "ferradans" which uses the nearest neighbor method proposed in :ref:`[6] + ` while "pooladian" use the out of sample + method from :ref:`[66] + `. limit_max: float, optional (default=np.infty) Controls the semi supervised mode. Transport between labeled source and target samples of different classes will exhibit an cost defined @@ -1089,13 +1092,22 @@ class SinkhornTransport(BaseTransport): .. [6] Ferradans, S., Papadakis, N., Peyré, G., & Aujol, J. F. (2014). Regularized discrete optimal transport. SIAM Journal on Imaging Sciences, 7(3), 1853-1882. + + .. [66] Pooladian, Aram-Alexandre, and Jonathan Niles-Weed. "Entropic + estimation of optimal transport maps." arXiv preprint + arXiv:2109.12004 (2021). + """ - def __init__(self, reg_e=1., method="sinkhorn", max_iter=1000, + def __init__(self, reg_e=1., method="sinkhorn_log", max_iter=1000, tol=10e-9, verbose=False, log=False, metric="sqeuclidean", norm=None, distribution_estimation=distribution_estimation_uniform, - out_of_sample_map='ferradans', limit_max=np.infty): + out_of_sample_map='pooladian', limit_max=np.infty): + + if out_of_sample_map not in ['ferradans', 'pooladian']: + raise ValueError('Unknown out_of_sample_map method') + self.reg_e = reg_e self.method = method self.max_iter = max_iter @@ -1135,6 +1147,12 @@ class label super(SinkhornTransport, self).fit(Xs, ys, Xt, yt) + if self.out_of_sample_map == 'pooladian': + self.log = True + if not self.method == 'sinkhorn_log': + self.method = 'sinkhorn_log' + warnings.warn("The method has been set to 'sinkhorn_log' as it is the only method available for out_of_sample_map='pooladian'") + # coupling estimation returned_ = sinkhorn( a=self.mu_s, b=self.mu_t, M=self.cost_, reg=self.reg_e, @@ -1150,6 +1168,94 @@ class label return self + def transform(self, Xs=None, ys=None, Xt=None, yt=None, batch_size=128): + r"""Transports source samples :math:`\mathbf{X_s}` onto target ones :math:`\mathbf{X_t}` + + Parameters + ---------- + Xs : array-like, shape (n_source_samples, n_features) + The source input samples. + ys : array-like, shape (n_source_samples,) + The class labels for source samples + Xt : array-like, shape (n_target_samples, n_features) + The target input samples. + yt : array-like, shape (n_target_samples,) + The class labels for target. If some target samples are unlabelled, fill the + :math:`\mathbf{y_t}`'s elements with -1. + + Warning: Note that, due to this convention -1 cannot be used as a + class label + batch_size : int, optional (default=128) + The batch size for out of sample inverse transform + + Returns + ------- + transp_Xs : array-like, shape (n_source_samples, n_features) + The transport source samples. + """ + nx = self.nx + + if self.out_of_sample_map == 'ferradans': + return super(SinkhornTransport, self).transform(Xs, ys, Xt, yt, batch_size) + + else: # self.out_of_sample_map == 'pooladian': + + # check the necessary inputs parameters are here + g = self.log_['log_v'] + + M = dist(Xs, self.xt_, metric=self.metric) + M = cost_normalization(M, self.norm) + + K = nx.exp(-M / self.reg_e + g[None, :]) + + transp_Xs = nx.dot(K, self.xt_) / nx.sum(K, axis=1)[:, None] + + return transp_Xs + + def inverse_transform(self, Xs=None, ys=None, Xt=None, yt=None, batch_size=128): + r"""Transports target samples :math:`\mathbf{X_t}` onto source samples :math:`\mathbf{X_s}` + + Parameters + ---------- + Xs : array-like, shape (n_source_samples, n_features) + The source input samples. + ys : array-like, shape (n_source_samples,) + The class labels for source samples + Xt : array-like, shape (n_target_samples, n_features) + The target input samples. + yt : array-like, shape (n_target_samples,) + The class labels for target. If some target samples are unlabelled, fill the + :math:`\mathbf{y_t}`'s elements with -1. + + Warning: Note that, due to this convention -1 cannot be used as a + class label + batch_size : int, optional (default=128) + The batch size for out of sample inverse transform + + Returns + ------- + transp_Xt : array-like, shape (n_source_samples, n_features) + The transport target samples. + """ + + nx = self.nx + + if self.out_of_sample_map == 'ferradans': + return super(SinkhornTransport, self).inverse_transform(Xs, ys, Xt, yt, batch_size) + + else: # self.out_of_sample_map == 'pooladian': + + f = self.log_['log_u'] + + M = dist(Xt, self.xs_, metric=self.metric) + M = cost_normalization(M, self.norm) + + K = nx.exp(-M / self.reg_e + f[None, :]) + + transp_Xt = nx.dot(K, self.xs_) / nx.sum(K, axis=1)[:, None] + + return transp_Xt + class EMDTransport(BaseTransport): diff --git a/test/test_da.py b/test/test_da.py index 37b709473..1f3aa824b 100644 --- a/test/test_da.py +++ b/test/test_da.py @@ -346,6 +346,23 @@ def test_sinkhorn_transport_class(nx): otda.fit(Xs=Xs, ys=ys, Xt=Xt) assert len(otda.log_.keys()) != 0 + # test diffeernt transform and inverse transform + otda = ot.da.SinkhornTransport(out_of_sample_map='ferradans') + transp_Xs = otda.fit_transform(Xs=Xs, Xt=Xt) + assert_equal(transp_Xs.shape, Xs.shape) + transp_Xt = otda.inverse_transform(Xt=Xt) + assert_equal(transp_Xt.shape, Xt.shape) + + # test diffeernt transform + otda = ot.da.SinkhornTransport(out_of_sample_map='pooladian') + transp_Xs = otda.fit_transform(Xs=Xs, Xt=Xt) + assert_equal(transp_Xs.shape, Xs.shape) + transp_Xt = otda.inverse_transform(Xt=Xt) + assert_equal(transp_Xt.shape, Xt.shape) + + with pytest.raises(ValueError): + otda = ot.da.SinkhornTransport(out_of_sample_map='unknown') + @pytest.skip_backend("jax") @pytest.skip_backend("tf") From a7ce443f06b525d668b2e999cba914836b83b219 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?R=C3=A9mi=20Flamary?= Date: Fri, 8 Mar 2024 14:14:43 +0100 Subject: [PATCH 2/7] update stuff --- ot/da.py | 21 +++++++++++++-------- test/test_da.py | 13 ++++++++----- 2 files changed, 21 insertions(+), 13 deletions(-) diff --git a/ot/da.py b/ot/da.py index a35a00371..5a87eaab8 100644 --- a/ot/da.py +++ b/ot/da.py @@ -1058,12 +1058,13 @@ class SinkhornTransport(BaseTransport): can occur with large metric values. distribution_estimation : callable, optional (defaults to the uniform) The kind of distribution estimation to employ - out_of_sample_map : string, optional (default="pooladian") + out_of_sample_map : string, optional (default="continuous") The kind of out of sample mapping to apply to transport samples from a domain into another one. Currently the only possible option is "ferradans" which uses the nearest neighbor method proposed in :ref:`[6] - ` while "pooladian" use the out of sample + ` while "continuous" use the out of sample method from :ref:`[66] + ` and :ref:`[19] `. limit_max: float, optional (default=np.infty) Controls the semi supervised mode. Transport between labeled source @@ -1093,6 +1094,10 @@ class SinkhornTransport(BaseTransport): Regularized discrete optimal transport. SIAM Journal on Imaging Sciences, 7(3), 1853-1882. + .. [19] Seguy, V., Bhushan Damodaran, B., Flamary, R., Courty, N., Rolet, A. + & Blondel, M. Large-scale Optimal Transport and Mapping Estimation. + International Conference on Learning Representation (2018) + .. [66] Pooladian, Aram-Alexandre, and Jonathan Niles-Weed. "Entropic estimation of optimal transport maps." arXiv preprint arXiv:2109.12004 (2021). @@ -1103,9 +1108,9 @@ def __init__(self, reg_e=1., method="sinkhorn_log", max_iter=1000, tol=10e-9, verbose=False, log=False, metric="sqeuclidean", norm=None, distribution_estimation=distribution_estimation_uniform, - out_of_sample_map='pooladian', limit_max=np.infty): + out_of_sample_map='continuous', limit_max=np.infty): - if out_of_sample_map not in ['ferradans', 'pooladian']: + if out_of_sample_map not in ['ferradans', 'continuous']: raise ValueError('Unknown out_of_sample_map method') self.reg_e = reg_e @@ -1147,11 +1152,11 @@ class label super(SinkhornTransport, self).fit(Xs, ys, Xt, yt) - if self.out_of_sample_map == 'pooladian': + if self.out_of_sample_map == 'continuous': self.log = True if not self.method == 'sinkhorn_log': self.method = 'sinkhorn_log' - warnings.warn("The method has been set to 'sinkhorn_log' as it is the only method available for out_of_sample_map='pooladian'") + warnings.warn("The method has been set to 'sinkhorn_log' as it is the only method available for out_of_sample_map='continuous'") # coupling estimation returned_ = sinkhorn( @@ -1198,7 +1203,7 @@ class label if self.out_of_sample_map == 'ferradans': return super(SinkhornTransport, self).transform(Xs, ys, Xt, yt, batch_size) - else: # self.out_of_sample_map == 'pooladian': + else: # self.out_of_sample_map == 'continuous': # check the necessary inputs parameters are here g = self.log_['log_v'] @@ -1243,7 +1248,7 @@ class label if self.out_of_sample_map == 'ferradans': return super(SinkhornTransport, self).inverse_transform(Xs, ys, Xt, yt, batch_size) - else: # self.out_of_sample_map == 'pooladian': + else: # self.out_of_sample_map == 'continuous': f = self.log_['log_u'] diff --git a/test/test_da.py b/test/test_da.py index 1f3aa824b..4ca76eed8 100644 --- a/test/test_da.py +++ b/test/test_da.py @@ -354,11 +354,14 @@ def test_sinkhorn_transport_class(nx): assert_equal(transp_Xt.shape, Xt.shape) # test diffeernt transform - otda = ot.da.SinkhornTransport(out_of_sample_map='pooladian') - transp_Xs = otda.fit_transform(Xs=Xs, Xt=Xt) - assert_equal(transp_Xs.shape, Xs.shape) - transp_Xt = otda.inverse_transform(Xt=Xt) - assert_equal(transp_Xt.shape, Xt.shape) + otda = ot.da.SinkhornTransport(out_of_sample_map='continuous') + transp_Xs2 = otda.fit_transform(Xs=Xs, Xt=Xt) + assert_equal(transp_Xs2.shape, Xs.shape) + transp_Xt2 = otda.inverse_transform(Xt=Xt) + assert_equal(transp_Xt2.shape, Xt.shape) + + np.testing.assert_almost_equal(nx.to_numpy(transp_Xs), nx.to_numpy(transp_Xs2), decimal=5) + np.testing.assert_almost_equal(nx.to_numpy(transp_Xt), nx.to_numpy(transp_Xt2), decimal=5) with pytest.raises(ValueError): otda = ot.da.SinkhornTransport(out_of_sample_map='unknown') From c356874491a509e6912f23ff4ff04e3b8fa00599 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?R=C3=A9mi=20Flamary?= Date: Fri, 8 Mar 2024 14:26:00 +0100 Subject: [PATCH 3/7] have corect normalization for entropic mapping --- ot/da.py | 6 +++--- ot/utils.py | 15 +++++++++++---- 2 files changed, 14 insertions(+), 7 deletions(-) diff --git a/ot/da.py b/ot/da.py index 5a87eaab8..1ae606589 100644 --- a/ot/da.py +++ b/ot/da.py @@ -493,7 +493,7 @@ class label # pairwise distance self.cost_ = dist(Xs, Xt, metric=self.metric) - self.cost_ = cost_normalization(self.cost_, self.norm) + self.cost_, self.norm_cost_ = cost_normalization(self.cost_, self.norm, return_value=True) if (ys is not None) and (yt is not None): @@ -1209,7 +1209,7 @@ class label g = self.log_['log_v'] M = dist(Xs, self.xt_, metric=self.metric) - M = cost_normalization(M, self.norm) + M = cost_normalization(M, self.norm, value=self.norm_cost_) K = nx.exp(-M / self.reg_e + g[None, :]) @@ -1253,7 +1253,7 @@ class label f = self.log_['log_u'] M = dist(Xt, self.xs_, metric=self.metric) - M = cost_normalization(M, self.norm) + M = cost_normalization(M, self.norm, value=self.norm_cost_) K = nx.exp(-M / self.reg_e + f[None, :]) diff --git a/ot/utils.py b/ot/utils.py index 404a9f2db..04c0e550e 100644 --- a/ot/utils.py +++ b/ot/utils.py @@ -360,7 +360,7 @@ def dist0(n, method='lin_square'): return res -def cost_normalization(C, norm=None): +def cost_normalization(C, norm=None, return_value=False, value=None): r""" Apply normalization to the loss matrix Parameters @@ -382,9 +382,13 @@ def cost_normalization(C, norm=None): if norm is None: pass elif norm == "median": - C /= float(nx.median(C)) + if value is None: + value = nx.median(C) + C /= value elif norm == "max": - C /= float(nx.max(C)) + if value is None: + value = nx.max(C) + C /= float(value) elif norm == "log": C = nx.log(1 + C) elif norm == "loglog": @@ -393,7 +397,10 @@ def cost_normalization(C, norm=None): raise ValueError('Norm %s is not a valid option.\n' 'Valid options are:\n' 'median, max, log, loglog' % norm) - return C + if return_value: + return C, value + else: + return C def dots(*args): From 255789cee8f2f418d419f0221cf319254409bea5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?R=C3=A9mi=20Flamary?= Date: Fri, 8 Mar 2024 14:39:50 +0100 Subject: [PATCH 4/7] implement batches --- ot/da.py | 42 ++++++++++++++++++++++++++++++++++-------- 1 file changed, 34 insertions(+), 8 deletions(-) diff --git a/ot/da.py b/ot/da.py index 1ae606589..9214f0912 100644 --- a/ot/da.py +++ b/ot/da.py @@ -1208,12 +1208,25 @@ class label # check the necessary inputs parameters are here g = self.log_['log_v'] - M = dist(Xs, self.xt_, metric=self.metric) - M = cost_normalization(M, self.norm, value=self.norm_cost_) + indices = nx.arange(Xs.shape[0]) + batch_ind = [ + indices[i:i + batch_size] + for i in range(0, len(indices), batch_size)] - K = nx.exp(-M / self.reg_e + g[None, :]) + transp_Xs = [] + for bi in batch_ind: + # get the nearest neighbor in the source domain + M = dist(Xs[bi], self.xt_, metric=self.metric) - transp_Xs = nx.dot(K, self.xt_) / nx.sum(K, axis=1)[:, None] + M = cost_normalization(M, self.norm, value=self.norm_cost_) + + K = nx.exp(-M / self.reg_e + g[None, :]) + + transp_Xs_ = nx.dot(K, self.xt_) / nx.sum(K, axis=1)[:, None] + + transp_Xs.append(transp_Xs_) + + transp_Xs = nx.concatenate(transp_Xs, axis=0) return transp_Xs @@ -1252,12 +1265,25 @@ class label f = self.log_['log_u'] - M = dist(Xt, self.xs_, metric=self.metric) - M = cost_normalization(M, self.norm, value=self.norm_cost_) + indices = nx.arange(Xt.shape[0]) + batch_ind = [ + indices[i:i + batch_size] + for i in range(0, len(indices), batch_size + )] + + transp_Xt = [] + for bi in batch_ind: + + M = dist(Xt[bi], self.xs_, metric=self.metric) + M = cost_normalization(M, self.norm, value=self.norm_cost_) + + K = nx.exp(-M / self.reg_e + f[None, :]) + + transp_Xt_ = nx.dot(K, self.xs_) / nx.sum(K, axis=1)[:, None] - K = nx.exp(-M / self.reg_e + f[None, :]) + transp_Xt.append(transp_Xt_) - transp_Xt = nx.dot(K, self.xs_) / nx.sum(K, axis=1)[:, None] + transp_Xt = nx.concatenate(transp_Xt, axis=0) return transp_Xt From 94333a01250b5330064bc842fd70d932751b6d25 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?R=C3=A9mi=20Flamary?= Date: Fri, 8 Mar 2024 16:03:20 +0100 Subject: [PATCH 5/7] improve coverage --- test/test_da.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/test_da.py b/test/test_da.py index 4ca76eed8..0e51bda22 100644 --- a/test/test_da.py +++ b/test/test_da.py @@ -354,7 +354,7 @@ def test_sinkhorn_transport_class(nx): assert_equal(transp_Xt.shape, Xt.shape) # test diffeernt transform - otda = ot.da.SinkhornTransport(out_of_sample_map='continuous') + otda = ot.da.SinkhornTransport(out_of_sample_map='continuous', method='sinkhorn') transp_Xs2 = otda.fit_transform(Xs=Xs, Xt=Xt) assert_equal(transp_Xs2.shape, Xs.shape) transp_Xt2 = otda.inverse_transform(Xt=Xt) From ea8cb64ef146810b10237fc30baa44fe56b451c6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?R=C3=A9mi=20Flamary?= Date: Fri, 29 Mar 2024 16:36:10 +0100 Subject: [PATCH 6/7] comments cedric + pep8 --- benchmarks/__init__.py | 2 +- benchmarks/emd.py | 2 +- benchmarks/sinkhorn_knopp.py | 2 +- docs/nb_run_conv | 63 ++++++++++++++++++------------------ docs/rtd/conf.py | 2 +- docs/source/conf.py | 20 +++++------- ot/__init__.py | 2 +- ot/da.py | 3 +- ot/gnn/__init__.py | 6 ++-- ot/lp/__init__.py | 21 ++++++------ setup.py | 2 +- 11 files changed, 60 insertions(+), 65 deletions(-) diff --git a/benchmarks/__init__.py b/benchmarks/__init__.py index 37f5e569a..9dc687db4 100644 --- a/benchmarks/__init__.py +++ b/benchmarks/__init__.py @@ -2,4 +2,4 @@ from . import sinkhorn_knopp from . import emd -__all__= ["benchmark", "sinkhorn_knopp", "emd"] +__all__ = ["benchmark", "sinkhorn_knopp", "emd"] diff --git a/benchmarks/emd.py b/benchmarks/emd.py index 9f6486300..861dab332 100644 --- a/benchmarks/emd.py +++ b/benchmarks/emd.py @@ -34,7 +34,7 @@ def setup(n_samples): warmup_runs=warmup_runs ) print(convert_to_html_table( - results, + results, param_name="Sample size", main_title=f"EMD - Averaged on {n_runs} runs" )) diff --git a/benchmarks/sinkhorn_knopp.py b/benchmarks/sinkhorn_knopp.py index 3a1ef3f37..ef0f22b90 100644 --- a/benchmarks/sinkhorn_knopp.py +++ b/benchmarks/sinkhorn_knopp.py @@ -36,7 +36,7 @@ def setup(n_samples): warmup_runs=warmup_runs ) print(convert_to_html_table( - results, + results, param_name="Sample size", main_title=f"Sinkhorn Knopp - Averaged on {n_runs} runs" )) diff --git a/docs/nb_run_conv b/docs/nb_run_conv index ad5e432d3..adb47ace0 100755 --- a/docs/nb_run_conv +++ b/docs/nb_run_conv @@ -17,22 +17,24 @@ import subprocess import os -cache_file='cache_nbrun' +cache_file = 'cache_nbrun' + +path_doc = 'source/auto_examples/' +path_nb = '../notebooks/' -path_doc='source/auto_examples/' -path_nb='../notebooks/' def load_json(fname): try: - f=open(fname) - nb=json.load(f) + f = open(fname) + nb = json.load(f) f.close() - except (OSError, IOError) : - nb={} + except (OSError, IOError): + nb = {} return nb -def save_json(fname,nb): - f=open(fname,'w') + +def save_json(fname, nb): + f = open(fname, 'w') f.write(json.dumps(nb)) f.close() @@ -44,39 +46,36 @@ def md5(fname): hash_md5.update(chunk) return hash_md5.hexdigest() -def to_update(fname,cache): + +def to_update(fname, cache): if fname in cache: - if md5(path_doc+fname)==cache[fname]: - res=False + if md5(path_doc + fname) == cache[fname]: + res = False else: - res=True + res = True else: - res=True - + res = True + return res -def update(fname,cache): - + +def update(fname, cache): + # jupyter nbconvert --to notebook --execute mynotebook.ipynb --output targte - subprocess.check_call(['cp',path_doc+fname,path_nb]) - print(' '.join(['jupyter','nbconvert','--to','notebook','--ExecutePreprocessor.timeout=600','--execute',path_nb+fname,'--inplace'])) - subprocess.check_call(['jupyter','nbconvert','--to','notebook','--ExecutePreprocessor.timeout=600','--execute',path_nb+fname,'--inplace']) - cache[fname]=md5(path_doc+fname) - + subprocess.check_call(['cp', path_doc + fname, path_nb]) + print(' '.join(['jupyter', 'nbconvert', '--to', 'notebook', '--ExecutePreprocessor.timeout=600', '--execute', path_nb + fname, '--inplace'])) + subprocess.check_call(['jupyter', 'nbconvert', '--to', 'notebook', '--ExecutePreprocessor.timeout=600', '--execute', path_nb + fname, '--inplace']) + cache[fname] = md5(path_doc + fname) -cache=load_json(cache_file) +cache = load_json(cache_file) -lst_file=glob.glob(path_doc+'*.ipynb') +lst_file = glob.glob(path_doc + '*.ipynb') -lst_file=[os.path.basename(name) for name in lst_file] +lst_file = [os.path.basename(name) for name in lst_file] for fname in lst_file: - if to_update(fname,cache): + if to_update(fname, cache): print('Updating file: {}'.format(fname)) - update(fname,cache) - save_json(cache_file,cache) - - - - + update(fname, cache) + save_json(cache_file, cache) diff --git a/docs/rtd/conf.py b/docs/rtd/conf.py index 814db75a8..cf6479bf5 100644 --- a/docs/rtd/conf.py +++ b/docs/rtd/conf.py @@ -3,4 +3,4 @@ source_parsers = {'.md': CommonMarkParser} source_suffix = ['.md'] -master_doc = 'index' \ No newline at end of file +master_doc = 'index' diff --git a/docs/source/conf.py b/docs/source/conf.py index 6452cf857..c51b96ec4 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -22,16 +22,12 @@ print("warning sphinx-gallery not installed") - - - - # !!!! allow readthedoc compilation try: from unittest.mock import MagicMock except ImportError: from mock import Mock as MagicMock - ## check whether in the source directory... + # check whether in the source directory... # @@ -42,7 +38,7 @@ def __getattr__(cls, name): return MagicMock() -MOCK_MODULES = [ 'cupy'] +MOCK_MODULES = ['cupy'] # 'autograd.numpy','pymanopt.manifolds','pymanopt.solvers', sys.modules.update((mod_name, Mock()) for mod_name in MOCK_MODULES) # !!!! @@ -357,12 +353,12 @@ def __getattr__(cls, name): sphinx_gallery_conf = { 'examples_dirs': ['../../examples', '../../examples/da'], 'gallery_dirs': 'auto_examples', - 'filename_pattern': 'plot_', #(?!barycenter_fgw) - 'nested_sections' : False, - 'backreferences_dir': 'gen_modules/backreferences', - 'inspect_global_variables' : True, - 'doc_module' : ('ot','numpy','scipy','pylab'), + 'filename_pattern': 'plot_', # (?!barycenter_fgw) + 'nested_sections': False, + 'backreferences_dir': 'gen_modules/backreferences', + 'inspect_global_variables': True, + 'doc_module': ('ot', 'numpy', 'scipy', 'pylab'), 'matplotlib_animations': True, 'reference_url': { - 'ot': None} + 'ot': None} } diff --git a/ot/__init__.py b/ot/__init__.py index 9a63b5f6f..1c10efafd 100644 --- a/ot/__init__.py +++ b/ot/__init__.py @@ -68,7 +68,7 @@ 'sinkhorn_unbalanced2', 'sliced_wasserstein_distance', 'sliced_wasserstein_sphere', 'gromov_wasserstein', 'gromov_wasserstein2', 'gromov_barycenters', 'fused_gromov_wasserstein', 'fused_gromov_wasserstein2', 'max_sliced_wasserstein_distance', 'weak_optimal_transport', - 'factored_optimal_transport', 'solve', 'solve_gromov','solve_sample', + 'factored_optimal_transport', 'solve', 'solve_gromov', 'solve_sample', 'smooth', 'stochastic', 'unbalanced', 'partial', 'regpath', 'solvers', 'binary_search_circle', 'wasserstein_circle', 'semidiscrete_wasserstein2_unif_circle', 'sliced_wasserstein_sphere_unif', diff --git a/ot/da.py b/ot/da.py index 9214f0912..e4adaa546 100644 --- a/ot/da.py +++ b/ot/da.py @@ -1055,7 +1055,8 @@ class SinkhornTransport(BaseTransport): The ground metric for the Wasserstein problem norm : string, optional (default=None) If given, normalize the ground metric to avoid numerical errors that - can occur with large metric values. + can occur with large metric values. Accepted values are 'median', + 'max', 'log' and 'loglog'. distribution_estimation : callable, optional (defaults to the uniform) The kind of distribution estimation to employ out_of_sample_map : string, optional (default="continuous") diff --git a/ot/gnn/__init__.py b/ot/gnn/__init__.py index 6a84100a1..af39db6d2 100644 --- a/ot/gnn/__init__.py +++ b/ot/gnn/__init__.py @@ -17,8 +17,8 @@ # All submodules and packages -from ._utils import (FGW_distance_to_templates,wasserstein_distance_to_templates) +from ._utils import (FGW_distance_to_templates, wasserstein_distance_to_templates) -from ._layers import (TFGWPooling,TWPooling) +from ._layers import (TFGWPooling, TWPooling) -__all__ = [ 'FGW_distance_to_templates', 'wasserstein_distance_to_templates','TFGWPooling','TWPooling'] \ No newline at end of file +__all__ = ['FGW_distance_to_templates', 'wasserstein_distance_to_templates', 'TFGWPooling', 'TWPooling'] diff --git a/ot/lp/__init__.py b/ot/lp/__init__.py index 93316a6c1..752c5d2d7 100644 --- a/ot/lp/__init__.py +++ b/ot/lp/__init__.py @@ -21,8 +21,8 @@ # import compiled emd from .emd_wrap import emd_c, check_result, emd_1d_sorted -from .solver_1d import (emd_1d, emd2_1d, wasserstein_1d, - binary_search_circle, wasserstein_circle, +from .solver_1d import (emd_1d, emd2_1d, wasserstein_1d, + binary_search_circle, wasserstein_circle, semidiscrete_wasserstein2_unif_circle) from ..utils import dist, list_to_array @@ -262,7 +262,7 @@ def emd(a, b, M, numItermax=100000, log=False, center_dual=True, numThreads=1, c check_marginals: bool, optional (default=True) If True, checks that the marginals mass are equal. If False, skips the check. - + Returns ------- @@ -341,8 +341,8 @@ def emd(a, b, M, numItermax=100000, log=False, center_dual=True, numThreads=1, c # ensure that same mass if check_marginals: np.testing.assert_almost_equal(a.sum(0), - b.sum(0), err_msg='a and b vector must have the same sum', - decimal=6) + b.sum(0), err_msg='a and b vector must have the same sum', + decimal=6) b = b * a.sum() / b.sum() asel = a != 0 @@ -440,8 +440,8 @@ def emd2(a, b, M, processes=1, check_marginals: bool, optional (default=True) If True, checks that the marginals mass are equal. If False, skips the check. - - + + Returns ------- W: float, array-like @@ -506,16 +506,15 @@ def emd2(a, b, M, processes=1, b = np.asarray(b, dtype=np.float64) M = np.asarray(M, dtype=np.float64, order='C') - assert (a.shape[0] == M.shape[0] and b.shape[0] == M.shape[1]), \ "Dimension mismatch, check dimensions of M with a and b" # ensure that same mass if check_marginals: np.testing.assert_almost_equal(a.sum(0), - b.sum(0,keepdims=True), err_msg='a and b vector must have the same sum', - decimal=6) - b = b * a.sum(0) / b.sum(0,keepdims=True) + b.sum(0, keepdims=True), err_msg='a and b vector must have the same sum', + decimal=6) + b = b * a.sum(0) / b.sum(0, keepdims=True) asel = a != 0 diff --git a/setup.py b/setup.py index 201e89c65..b80124e31 100644 --- a/setup.py +++ b/setup.py @@ -1,5 +1,6 @@ #!/usr/bin/env python +from openmp_helpers import check_openmp_support import os import re import subprocess @@ -12,7 +13,6 @@ from Cython.Build import cythonize sys.path.append(os.path.join("ot", "helpers")) -from openmp_helpers import check_openmp_support # dirty but working __version__ = re.search( From ebcf7960e6cd63c9dbfa4c940551b6201c320172 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?R=C3=A9mi=20Flamary?= Date: Fri, 29 Mar 2024 16:47:00 +0100 Subject: [PATCH 7/7] move it back --- setup.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/setup.py b/setup.py index b80124e31..72b1488b2 100644 --- a/setup.py +++ b/setup.py @@ -1,6 +1,6 @@ #!/usr/bin/env python -from openmp_helpers import check_openmp_support + import os import re import subprocess @@ -13,6 +13,7 @@ from Cython.Build import cythonize sys.path.append(os.path.join("ot", "helpers")) +from openmp_helpers import check_openmp_support # dirty but working __version__ = re.search(