Skip to content

Commit ea8cb64

Browse files
committed
comments cedric + pep8
1 parent 94333a0 commit ea8cb64

File tree

11 files changed

+60
-65
lines changed

11 files changed

+60
-65
lines changed

benchmarks/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,4 +2,4 @@
22
from . import sinkhorn_knopp
33
from . import emd
44

5-
__all__= ["benchmark", "sinkhorn_knopp", "emd"]
5+
__all__ = ["benchmark", "sinkhorn_knopp", "emd"]

benchmarks/emd.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ def setup(n_samples):
3434
warmup_runs=warmup_runs
3535
)
3636
print(convert_to_html_table(
37-
results,
37+
results,
3838
param_name="Sample size",
3939
main_title=f"EMD - Averaged on {n_runs} runs"
4040
))

benchmarks/sinkhorn_knopp.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ def setup(n_samples):
3636
warmup_runs=warmup_runs
3737
)
3838
print(convert_to_html_table(
39-
results,
39+
results,
4040
param_name="Sample size",
4141
main_title=f"Sinkhorn Knopp - Averaged on {n_runs} runs"
4242
))

docs/nb_run_conv

Lines changed: 31 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -17,22 +17,24 @@ import subprocess
1717

1818
import os
1919

20-
cache_file='cache_nbrun'
20+
cache_file = 'cache_nbrun'
21+
22+
path_doc = 'source/auto_examples/'
23+
path_nb = '../notebooks/'
2124

22-
path_doc='source/auto_examples/'
23-
path_nb='../notebooks/'
2425

2526
def load_json(fname):
2627
try:
27-
f=open(fname)
28-
nb=json.load(f)
28+
f = open(fname)
29+
nb = json.load(f)
2930
f.close()
30-
except (OSError, IOError) :
31-
nb={}
31+
except (OSError, IOError):
32+
nb = {}
3233
return nb
3334

34-
def save_json(fname,nb):
35-
f=open(fname,'w')
35+
36+
def save_json(fname, nb):
37+
f = open(fname, 'w')
3638
f.write(json.dumps(nb))
3739
f.close()
3840

@@ -44,39 +46,36 @@ def md5(fname):
4446
hash_md5.update(chunk)
4547
return hash_md5.hexdigest()
4648

47-
def to_update(fname,cache):
49+
50+
def to_update(fname, cache):
4851
if fname in cache:
49-
if md5(path_doc+fname)==cache[fname]:
50-
res=False
52+
if md5(path_doc + fname) == cache[fname]:
53+
res = False
5154
else:
52-
res=True
55+
res = True
5356
else:
54-
res=True
55-
57+
res = True
58+
5659
return res
5760

58-
def update(fname,cache):
59-
61+
62+
def update(fname, cache):
63+
6064
# jupyter nbconvert --to notebook --execute mynotebook.ipynb --output targte
61-
subprocess.check_call(['cp',path_doc+fname,path_nb])
62-
print(' '.join(['jupyter','nbconvert','--to','notebook','--ExecutePreprocessor.timeout=600','--execute',path_nb+fname,'--inplace']))
63-
subprocess.check_call(['jupyter','nbconvert','--to','notebook','--ExecutePreprocessor.timeout=600','--execute',path_nb+fname,'--inplace'])
64-
cache[fname]=md5(path_doc+fname)
65-
65+
subprocess.check_call(['cp', path_doc + fname, path_nb])
66+
print(' '.join(['jupyter', 'nbconvert', '--to', 'notebook', '--ExecutePreprocessor.timeout=600', '--execute', path_nb + fname, '--inplace']))
67+
subprocess.check_call(['jupyter', 'nbconvert', '--to', 'notebook', '--ExecutePreprocessor.timeout=600', '--execute', path_nb + fname, '--inplace'])
68+
cache[fname] = md5(path_doc + fname)
6669

6770

68-
cache=load_json(cache_file)
71+
cache = load_json(cache_file)
6972

70-
lst_file=glob.glob(path_doc+'*.ipynb')
73+
lst_file = glob.glob(path_doc + '*.ipynb')
7174

72-
lst_file=[os.path.basename(name) for name in lst_file]
75+
lst_file = [os.path.basename(name) for name in lst_file]
7376

7477
for fname in lst_file:
75-
if to_update(fname,cache):
78+
if to_update(fname, cache):
7679
print('Updating file: {}'.format(fname))
77-
update(fname,cache)
78-
save_json(cache_file,cache)
79-
80-
81-
82-
80+
update(fname, cache)
81+
save_json(cache_file, cache)

docs/rtd/conf.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,4 +3,4 @@
33
source_parsers = {'.md': CommonMarkParser}
44

55
source_suffix = ['.md']
6-
master_doc = 'index'
6+
master_doc = 'index'

docs/source/conf.py

Lines changed: 8 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -22,16 +22,12 @@
2222
print("warning sphinx-gallery not installed")
2323

2424

25-
26-
27-
28-
2925
# !!!! allow readthedoc compilation
3026
try:
3127
from unittest.mock import MagicMock
3228
except ImportError:
3329
from mock import Mock as MagicMock
34-
## check whether in the source directory...
30+
# check whether in the source directory...
3531
#
3632

3733

@@ -42,7 +38,7 @@ def __getattr__(cls, name):
4238
return MagicMock()
4339

4440

45-
MOCK_MODULES = [ 'cupy']
41+
MOCK_MODULES = ['cupy']
4642
# 'autograd.numpy','pymanopt.manifolds','pymanopt.solvers',
4743
sys.modules.update((mod_name, Mock()) for mod_name in MOCK_MODULES)
4844
# !!!!
@@ -357,12 +353,12 @@ def __getattr__(cls, name):
357353
sphinx_gallery_conf = {
358354
'examples_dirs': ['../../examples', '../../examples/da'],
359355
'gallery_dirs': 'auto_examples',
360-
'filename_pattern': 'plot_', #(?!barycenter_fgw)
361-
'nested_sections' : False,
362-
'backreferences_dir': 'gen_modules/backreferences',
363-
'inspect_global_variables' : True,
364-
'doc_module' : ('ot','numpy','scipy','pylab'),
356+
'filename_pattern': 'plot_', # (?!barycenter_fgw)
357+
'nested_sections': False,
358+
'backreferences_dir': 'gen_modules/backreferences',
359+
'inspect_global_variables': True,
360+
'doc_module': ('ot', 'numpy', 'scipy', 'pylab'),
365361
'matplotlib_animations': True,
366362
'reference_url': {
367-
'ot': None}
363+
'ot': None}
368364
}

ot/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,7 @@
6868
'sinkhorn_unbalanced2', 'sliced_wasserstein_distance', 'sliced_wasserstein_sphere',
6969
'gromov_wasserstein', 'gromov_wasserstein2', 'gromov_barycenters', 'fused_gromov_wasserstein',
7070
'fused_gromov_wasserstein2', 'max_sliced_wasserstein_distance', 'weak_optimal_transport',
71-
'factored_optimal_transport', 'solve', 'solve_gromov','solve_sample',
71+
'factored_optimal_transport', 'solve', 'solve_gromov', 'solve_sample',
7272
'smooth', 'stochastic', 'unbalanced', 'partial', 'regpath', 'solvers',
7373
'binary_search_circle', 'wasserstein_circle',
7474
'semidiscrete_wasserstein2_unif_circle', 'sliced_wasserstein_sphere_unif',

ot/da.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1055,7 +1055,8 @@ class SinkhornTransport(BaseTransport):
10551055
The ground metric for the Wasserstein problem
10561056
norm : string, optional (default=None)
10571057
If given, normalize the ground metric to avoid numerical errors that
1058-
can occur with large metric values.
1058+
can occur with large metric values. Accepted values are 'median',
1059+
'max', 'log' and 'loglog'.
10591060
distribution_estimation : callable, optional (defaults to the uniform)
10601061
The kind of distribution estimation to employ
10611062
out_of_sample_map : string, optional (default="continuous")

ot/gnn/__init__.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,8 @@
1717
# All submodules and packages
1818

1919

20-
from ._utils import (FGW_distance_to_templates,wasserstein_distance_to_templates)
20+
from ._utils import (FGW_distance_to_templates, wasserstein_distance_to_templates)
2121

22-
from ._layers import (TFGWPooling,TWPooling)
22+
from ._layers import (TFGWPooling, TWPooling)
2323

24-
__all__ = [ 'FGW_distance_to_templates', 'wasserstein_distance_to_templates','TFGWPooling','TWPooling']
24+
__all__ = ['FGW_distance_to_templates', 'wasserstein_distance_to_templates', 'TFGWPooling', 'TWPooling']

ot/lp/__init__.py

Lines changed: 10 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -21,8 +21,8 @@
2121

2222
# import compiled emd
2323
from .emd_wrap import emd_c, check_result, emd_1d_sorted
24-
from .solver_1d import (emd_1d, emd2_1d, wasserstein_1d,
25-
binary_search_circle, wasserstein_circle,
24+
from .solver_1d import (emd_1d, emd2_1d, wasserstein_1d,
25+
binary_search_circle, wasserstein_circle,
2626
semidiscrete_wasserstein2_unif_circle)
2727

2828
from ..utils import dist, list_to_array
@@ -262,7 +262,7 @@ def emd(a, b, M, numItermax=100000, log=False, center_dual=True, numThreads=1, c
262262
check_marginals: bool, optional (default=True)
263263
If True, checks that the marginals mass are equal. If False, skips the
264264
check.
265-
265+
266266
267267
Returns
268268
-------
@@ -341,8 +341,8 @@ def emd(a, b, M, numItermax=100000, log=False, center_dual=True, numThreads=1, c
341341
# ensure that same mass
342342
if check_marginals:
343343
np.testing.assert_almost_equal(a.sum(0),
344-
b.sum(0), err_msg='a and b vector must have the same sum',
345-
decimal=6)
344+
b.sum(0), err_msg='a and b vector must have the same sum',
345+
decimal=6)
346346
b = b * a.sum() / b.sum()
347347

348348
asel = a != 0
@@ -440,8 +440,8 @@ def emd2(a, b, M, processes=1,
440440
check_marginals: bool, optional (default=True)
441441
If True, checks that the marginals mass are equal. If False, skips the
442442
check.
443-
444-
443+
444+
445445
Returns
446446
-------
447447
W: float, array-like
@@ -506,16 +506,15 @@ def emd2(a, b, M, processes=1,
506506
b = np.asarray(b, dtype=np.float64)
507507
M = np.asarray(M, dtype=np.float64, order='C')
508508

509-
510509
assert (a.shape[0] == M.shape[0] and b.shape[0] == M.shape[1]), \
511510
"Dimension mismatch, check dimensions of M with a and b"
512511

513512
# ensure that same mass
514513
if check_marginals:
515514
np.testing.assert_almost_equal(a.sum(0),
516-
b.sum(0,keepdims=True), err_msg='a and b vector must have the same sum',
517-
decimal=6)
518-
b = b * a.sum(0) / b.sum(0,keepdims=True)
515+
b.sum(0, keepdims=True), err_msg='a and b vector must have the same sum',
516+
decimal=6)
517+
b = b * a.sum(0) / b.sum(0, keepdims=True)
519518

520519
asel = a != 0
521520

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
#!/usr/bin/env python
22

3+
from openmp_helpers import check_openmp_support
34
import os
45
import re
56
import subprocess
@@ -12,7 +13,6 @@
1213
from Cython.Build import cythonize
1314

1415
sys.path.append(os.path.join("ot", "helpers"))
15-
from openmp_helpers import check_openmp_support
1616

1717
# dirty but working
1818
__version__ = re.search(

0 commit comments

Comments
 (0)