Skip to content

Commit 55c8d2b

Browse files
committed
fix pep 8 tests
1 parent b40705c commit 55c8d2b

File tree

5 files changed

+10
-12
lines changed

5 files changed

+10
-12
lines changed

ot/lowrank.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88

99

1010
import warnings
11-
from .utils import unif, list_to_array, get_lowrank_lazytensor
11+
from .utils import unif, get_lowrank_lazytensor
1212
from .backend import get_backend
1313

1414

@@ -287,4 +287,3 @@ def lowrank_sinkhorn(X_s, X_t, a=None, b=None, reg=0, rank="auto", alpha="auto",
287287
return Q, R, g, dict_log
288288

289289
return Q, R, g
290-

ot/solvers.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1248,27 +1248,26 @@ def solve_sample(X_a, X_b, a=None, b=None, metric='sqeuclidean', reg=None, reg_t
12481248
lazy_plan = log['lazy_plan']
12491249
if not lazy0: # store plan if not lazy
12501250
plan = lazy_plan[:]
1251-
1251+
12521252
elif method == "lowrank":
12531253

12541254
if not metric.lower() in ['sqeuclidean']:
12551255
raise (NotImplementedError('Not implemented metric="{}"'.format(metric)))
1256-
1256+
12571257
if max_iter is None:
12581258
max_iter = 1000
12591259
if tol is None:
12601260
tol = 1e-9
12611261
if reg is None:
12621262
reg = 0
1263-
1263+
12641264
Q, R, g, log = lowrank_sinkhorn(X_a, X_b, reg=reg, a=a, b=b, numItermax=max_iter, stopThr=tol, log=True)
12651265
value = log['value']
12661266
value_linear = log['value_linear']
12671267
lazy_plan = log['lazy_plan']
12681268
if not lazy0: # store plan if not lazy
12691269
plan = lazy_plan[:]
12701270

1271-
12721271
elif method.startswith('geomloss'): # Geomloss solver for entropi OT
12731272

12741273
split_method = method.split('_')

ot/utils.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1176,6 +1176,7 @@ def citation(self):
11761176
}
11771177
"""
11781178

1179+
11791180
class LazyTensor(object):
11801181
""" A lazy tensor is a tensor that is not stored in memory. Instead, it is
11811182
defined by a function that computes its values on the fly from slices.
@@ -1240,4 +1241,4 @@ def __getitem__(self, key):
12401241
return self._getitem(*k, **self.kwargs)
12411242

12421243
def __repr__(self):
1243-
return "LazyTensor(shape={},attributes=({}))".format(self.shape, ','.join(self.kwargs.keys()))
1244+
return "LazyTensor(shape={},attributes=({}))".format(self.shape, ','.join(self.kwargs.keys()))

test/test_solvers.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -30,14 +30,14 @@
3030
{'method': 'gaussian'},
3131
{'method': 'gaussian', 'reg': 1},
3232
{'method': 'factored', 'rank': 10},
33-
{'method': 'lowrank', 'reg':0.1}
33+
{'method': 'lowrank', 'reg': 0.1}
3434
]
3535

3636
lst_parameters_solve_sample_NotImplemented = [
3737
{'method': '1d', 'metric': 'any other one'}, # fail 1d on weird metrics
3838
{'method': 'gaussian', 'metric': 'euclidean'}, # fail gaussian on metric not euclidean
39-
{'method': 'factored', 'metric': 'euclidean'}, # fail factored on metric not euclidean
40-
{"method": 'lowrank', 'metric':'euclidean'}, # fail lowrank on metric not euclidean
39+
{'method': 'factored', 'metric': 'euclidean'}, # fail factored on metric not euclidean
40+
{"method": 'lowrank', 'metric': 'euclidean'}, # fail lowrank on metric not euclidean
4141
{'lazy': True}, # fail lazy for non regularized
4242
{'lazy': True, 'unbalanced': 1}, # fail lazy for non regularized unbalanced
4343
{'lazy': True, 'reg': 1, 'unbalanced': 1}, # fail lazy for unbalanced and regularized
@@ -415,7 +415,7 @@ def test_solve_sample_methods(nx, method_params):
415415
assert_allclose_sol(sol, solb)
416416

417417
sol2 = ot.solve_sample(x, x, **method_params)
418-
if method_params['method'] not in ['factored','lowrank']:
418+
if method_params['method'] not in ['factored', 'lowrank']:
419419
np.testing.assert_allclose(sol2.value, 0)
420420

421421

test/test_utils.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -583,4 +583,3 @@ def test_lowrank_LazyTensor(nx):
583583
T = ot.utils.get_lowrank_lazytensor(X1, X2, diag_d, nx=nx)
584584

585585
np.testing.assert_allclose(nx.to_numpy(T[:]), nx.to_numpy(T0))
586-

0 commit comments

Comments
 (0)