Skip to content

Commit 2ccb2aa

Browse files
committed
add test and rearrange arguments
1 parent 2ad05ab commit 2ccb2aa

File tree

2 files changed

+71
-25
lines changed

2 files changed

+71
-25
lines changed

ot/unbalanced.py

Lines changed: 19 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -19,8 +19,8 @@
1919
from .utils import list_to_array, get_parameter_pair
2020

2121

22-
def sinkhorn_unbalanced(a, b, M, reg, reg_m, reg_type="entropy", warmstart=None,
23-
method='sinkhorn', numItermax=1000,
22+
def sinkhorn_unbalanced(a, b, M, reg, reg_m, method='sinkhorn',
23+
reg_type="entropy", warmstart=None, numItermax=1000,
2424
stopThr=1e-6, verbose=False, log=False, **kwargs):
2525
r"""
2626
Solve the unbalanced entropic regularization optimal transport problem
@@ -67,6 +67,9 @@ def sinkhorn_unbalanced(a, b, M, reg, reg_m, reg_type="entropy", warmstart=None,
6767
For semi-relaxed case, use either
6868
`reg_m=(float("inf"), scalar)` or `reg_m=(scalar, float("inf"))`.
6969
If reg_m is an array, it must have the same backend as input arrays (a, b, M).
70+
method : str
71+
method used for the solver either 'sinkhorn', 'sinkhorn_stabilized' or
72+
'sinkhorn_reg_scaling', see those function for specific parameters
7073
reg_type : string, optional
7174
Regularizer term. Can take two values:
7275
'entropy' (negative entropy)
@@ -75,10 +78,7 @@ def sinkhorn_unbalanced(a, b, M, reg, reg_m, reg_type="entropy", warmstart=None,
7578
:math:`\Omega(\gamma) = \text{KL}(\gamma, \mathbf{a} \mathbf{b}^T)`.
7679
warmstart: tuple of arrays, shape (dim_a, dim_b), optional
7780
Initialization of dual potentials. If provided, the dual potentials should be given
78-
(that is the logarithm of the u,v sinkhorn scaling vectors).s
79-
method : str
80-
method used for the solver either 'sinkhorn', 'sinkhorn_stabilized' or
81-
'sinkhorn_reg_scaling', see those function for specific parameters
81+
(that is the logarithm of the u,v sinkhorn scaling vectors).
8282
numItermax : int, optional
8383
Max number of iterations
8484
stopThr : float, optional
@@ -165,8 +165,8 @@ def sinkhorn_unbalanced(a, b, M, reg, reg_m, reg_type="entropy", warmstart=None,
165165
raise ValueError("Unknown method '%s'." % method)
166166

167167

168-
def sinkhorn_unbalanced2(a, b, M, reg, reg_m, reg_type="entropy", warmstart=None,
169-
method='sinkhorn', numItermax=1000,
168+
def sinkhorn_unbalanced2(a, b, M, reg, reg_m, method='sinkhorn',
169+
reg_type="entropy", warmstart=None, numItermax=1000,
170170
stopThr=1e-6, verbose=False, log=False, **kwargs):
171171
r"""
172172
Solve the entropic regularization unbalanced optimal transport problem and
@@ -212,6 +212,9 @@ def sinkhorn_unbalanced2(a, b, M, reg, reg_m, reg_type="entropy", warmstart=None
212212
For semi-relaxed case, use either
213213
`reg_m=(float("inf"), scalar)` or `reg_m=(scalar, float("inf"))`.
214214
If reg_m is an array, it must have the same backend as input arrays (a, b, M).
215+
method : str
216+
method used for the solver either 'sinkhorn', 'sinkhorn_stabilized' or
217+
'sinkhorn_reg_scaling', see those function for specific parameterss
215218
reg_type : string, optional
216219
Regularizer term. Can take two values:
217220
'entropy' (negative entropy)
@@ -221,9 +224,6 @@ def sinkhorn_unbalanced2(a, b, M, reg, reg_m, reg_type="entropy", warmstart=None
221224
warmstart: tuple of arrays, shape (dim_a, dim_b), optional
222225
Initialization of dual potentials. If provided, the dual potentials should be given
223226
(that is the logarithm of the u,v sinkhorn scaling vectors).
224-
method : str
225-
method used for the solver either 'sinkhorn', 'sinkhorn_stabilized' or
226-
'sinkhorn_reg_scaling', see those function for specific parameters
227227
numItermax : int, optional
228228
Max number of iterations
229229
stopThr : float, optional
@@ -435,12 +435,12 @@ def sinkhorn_knopp_unbalanced(a, b, M, reg, reg_m, reg_type="entropy",
435435
# distances
436436
if warmstart is None:
437437
if n_hists:
438-
u = nx.ones((dim_a, 1), type_as=M) / dim_a
439-
v = nx.ones((dim_b, n_hists), type_as=M) / dim_b
438+
u = nx.ones((dim_a, 1), type_as=M)
439+
v = nx.ones((dim_b, n_hists), type_as=M)
440440
a = a.reshape(dim_a, 1)
441441
else:
442-
u = nx.ones(dim_a, type_as=M) / dim_a
443-
v = nx.ones(dim_b, type_as=M) / dim_b
442+
u = nx.ones(dim_a, type_as=M)
443+
v = nx.ones(dim_b, type_as=M)
444444
else:
445445
u, v = nx.exp(warmstart[0]), nx.exp(warmstart[1])
446446

@@ -644,12 +644,12 @@ def sinkhorn_stabilized_unbalanced(a, b, M, reg, reg_m, reg_type="entropy",
644644
# distances
645645
if warmstart is None:
646646
if n_hists:
647-
u = nx.ones((dim_a, n_hists), type_as=M) / dim_a
648-
v = nx.ones((dim_b, n_hists), type_as=M) / dim_b
647+
u = nx.ones((dim_a, n_hists), type_as=M)
648+
v = nx.ones((dim_b, n_hists), type_as=M)
649649
a = a.reshape(dim_a, 1)
650650
else:
651-
u = nx.ones(dim_a, type_as=M) / dim_a
652-
v = nx.ones(dim_b, type_as=M) / dim_b
651+
u = nx.ones(dim_a, type_as=M)
652+
v = nx.ones(dim_b, type_as=M)
653653
else:
654654
u, v = nx.exp(warmstart[0]), nx.exp(warmstart[1])
655655

test/test_unbalanced.py

Lines changed: 52 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -33,12 +33,12 @@ def test_unbalanced_convergence(nx, method, reg_type):
3333
reg_m = 1.
3434

3535
G, log = ot.unbalanced.sinkhorn_unbalanced(
36-
a, b, M, reg=epsilon, reg_m=reg_m, reg_type=reg_type,
37-
method=method, log=True, verbose=True
36+
a, b, M, reg=epsilon, reg_m=reg_m, method=method,
37+
reg_type=reg_type, log=True, verbose=True
3838
)
3939
loss = nx.to_numpy(ot.unbalanced.sinkhorn_unbalanced2(
40-
a, b, M, reg=epsilon, reg_m=reg_m, reg_type=reg_type,
41-
method=method, verbose=True
40+
a, b, M, reg=epsilon, reg_m=reg_m, method=method,
41+
reg_type=reg_type, verbose=True
4242
))
4343
# check fixed point equations
4444
# in log-domain
@@ -70,15 +70,61 @@ def test_unbalanced_convergence(nx, method, reg_type):
7070

7171
G = ot.unbalanced.sinkhorn_unbalanced(
7272
a, b, M, reg=epsilon, reg_m=reg_m,
73-
reg_type=reg_type, method=method, verbose=True
73+
method=method, reg_type=reg_type, verbose=True
7474
)
7575
G_np = ot.unbalanced.sinkhorn_unbalanced(
7676
a_np, b_np, M_np, reg=epsilon, reg_m=reg_m,
77-
reg_type=reg_type, method=method, verbose=True
77+
method=method, reg_type=reg_type, verbose=True
7878
)
7979
np.testing.assert_allclose(G_np, nx.to_numpy(G))
8080

8181

82+
@pytest.mark.parametrize("method,reg_type", itertools.product(["sinkhorn", "sinkhorn_stabilized"], ["kl", "entropy"]))
83+
# @pytest.mark.parametrize("method", ["sinkhorn", "sinkhorn_stabilized"])
84+
def test_unbalanced_warmstart(nx, method, reg_type):
85+
# test generalized sinkhorn for unbalanced OT
86+
n = 100
87+
rng = np.random.RandomState(42)
88+
89+
x = rng.randn(n, 2)
90+
a = ot.utils.unif(n)
91+
92+
# make dists unbalanced
93+
b = ot.utils.unif(n) * 1.5
94+
M = ot.dist(x, x)
95+
a, b, M = nx.from_numpy(a, b, M)
96+
97+
epsilon = 1.
98+
reg_m = 1.
99+
100+
dim_a, dim_b = M.shape
101+
warmstart = (nx.zeros(dim_a, type_as=M), nx.zeros(dim_b, type_as=M))
102+
G, log = ot.unbalanced.sinkhorn_unbalanced(
103+
a, b, M, reg=epsilon, reg_m=reg_m, method=method,
104+
reg_type=reg_type, warmstart=warmstart, log=True, verbose=True
105+
)
106+
loss = nx.to_numpy(ot.unbalanced.sinkhorn_unbalanced2(
107+
a, b, M, reg=epsilon, reg_m=reg_m, method=method,
108+
reg_type=reg_type, warmstart=warmstart, verbose=True
109+
))
110+
111+
G0, log0 = ot.unbalanced.sinkhorn_unbalanced(
112+
a, b, M, reg=epsilon, reg_m=reg_m, method=method,
113+
reg_type=reg_type, warmstart=None, log=True, verbose=True
114+
)
115+
loss0 = nx.to_numpy(ot.unbalanced.sinkhorn_unbalanced2(
116+
a, b, M, reg=epsilon, reg_m=reg_m, method=method,
117+
reg_type=reg_type, warmstart=None, verbose=True
118+
))
119+
120+
np.testing.assert_allclose(loss, loss0, atol=1e-6)
121+
np.testing.assert_allclose(
122+
nx.to_numpy(log["logu"]), nx.to_numpy(log0["logu"]), atol=1e-05)
123+
np.testing.assert_allclose(
124+
nx.to_numpy(log["logv"]), nx.to_numpy(log0["logv"]), atol=1e-05)
125+
np.testing.assert_allclose(nx.to_numpy(G), nx.to_numpy(G0), atol=1e-05)
126+
127+
82128
@pytest.mark.parametrize("method,reg_m", itertools.product(["sinkhorn", "sinkhorn_stabilized"], [1, float("inf")]))
83129
def test_unbalanced_relaxation_parameters(nx, method, reg_m):
84130
# test generalized sinkhorn for unbalanced OT

0 commit comments

Comments
 (0)