Skip to content

Commit 2941ed3

Browse files
kachayev, rflamary, and cedricvincentcuaz
authored
[DA] Sinkhorn LpL1 transport to work on JAX (#592)
* Draft implementation for per-class regularization in lpl1
* Do not use assignment to replace non-finite elements
* Make vectorized version of lpl1 work
* Proper lpl1 vectorization
* Remove type error test for JAX (should work now)
* Update test; coupling still has NaNs
* Explicitly check for NaNs in the coupling returned from sinkhorn
* Fix small comments

---------

Co-authored-by: Rémi Flamary <remi.flamary@gmail.com>
Co-authored-by: Cédric Vincent-Cuaz <cedvincentcuaz@gmail.com>
1 parent a8f0ed5 commit 2941ed3

File tree

3 files changed

+68
-28
lines changed

3 files changed

+68
-28
lines changed

RELEASES.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
- Split `test/test_gromov.py` into `test/gromov/` (PR #619)
2222
- Fix (F)GW barycenter functions to support computing barycenter on 1 input + deprecate structures as lists (PR #628)
2323
- Fix line-search in partial GW and change default init to the interior of partial transport plans (PR #602)
24+
- Fix `ot.da.sinkhorn_lpl1_mm` compatibility with JAX (PR #592)
2425

2526
## 0.9.3
2627
*January 2024*

ot/da.py

Lines changed: 15 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -122,28 +122,25 @@ def sinkhorn_lpl1_mm(a, labels_a, b, M, reg, eta=0.1, numItermax=10,
122122
p = 0.5
123123
epsilon = 1e-3
124124

125-
indices_labels = []
126-
classes = nx.unique(labels_a)
127-
for c in classes:
128-
idxc, = nx.where(labels_a == c)
129-
indices_labels.append(idxc)
125+
labels_u, labels_idx = nx.unique(labels_a, return_inverse=True)
126+
n_labels = labels_u.shape[0]
127+
unroll_labels_idx = nx.eye(n_labels, type_as=M)[labels_idx]
130128

131129
W = nx.zeros(M.shape, type_as=M)
132-
for cpt in range(numItermax):
130+
for _ in range(numItermax):
133131
Mreg = M + eta * W
134132
if log:
135133
transp, log = sinkhorn(a, b, Mreg, reg, numItermax=numInnerItermax,
136134
stopThr=stopInnerThr, log=True)
137135
else:
138136
transp = sinkhorn(a, b, Mreg, reg, numItermax=numInnerItermax,
139137
stopThr=stopInnerThr)
140-
# the transport has been computed. Check if classes are really
141-
# separated
142-
W = nx.ones(M.shape, type_as=M)
143-
for (i, c) in enumerate(classes):
144-
majs = nx.sum(transp[indices_labels[i]], axis=0)
145-
majs = p * ((majs + epsilon) ** (p - 1))
146-
W[indices_labels[i]] = majs
138+
# the transport has been computed
139+
# check if classes are really separated
140+
W = nx.repeat(transp.T[:, :, None], n_labels, axis=2) * unroll_labels_idx[None, :, :]
141+
W = nx.sum(W, axis=1)
142+
W = nx.dot(W, unroll_labels_idx.T)
143+
W = p * ((W.T + epsilon) ** (p - 1))
147144

148145
if log:
149146
return transp, log
@@ -1925,7 +1922,7 @@ def transform(self, Xs):
19251922
transp = self.coupling_ / nx.sum(self.coupling_, 1)[:, None]
19261923

19271924
# set nans to 0
1928-
transp[~ nx.isfinite(transp)] = 0
1925+
transp = nx.nan_to_num(transp, nan=0, posinf=0, neginf=0)
19291926

19301927
# compute transported samples
19311928
transp_Xs = nx.dot(transp, self.xt_)
@@ -2214,7 +2211,7 @@ class label
22142211
transp = coupling / nx.sum(coupling, 1)[:, None]
22152212

22162213
# set nans to 0
2217-
transp[~ nx.isfinite(transp)] = 0
2214+
transp = nx.nan_to_num(transp, nan=0, posinf=0, neginf=0)
22182215

22192216
# compute transported samples
22202217
transp_Xs.append(nx.dot(transp, self.xt_))
@@ -2238,7 +2235,7 @@ class label
22382235
# transport the source samples
22392236
for coupling in self.coupling_:
22402237
transp = coupling / nx.sum(coupling, 1)[:, None]
2241-
transp[~ nx.isfinite(transp)] = 0
2238+
transp = nx.nan_to_num(transp, nan=0, posinf=0, neginf=0)
22422239
transp_Xs_.append(nx.dot(transp, self.xt_))
22432240

22442241
transp_Xs_ = nx.concatenate(transp_Xs_, axis=0)
@@ -2291,7 +2288,7 @@ def transform_labels(self, ys=None):
22912288
transp = self.coupling_[i] / nx.sum(self.coupling_[i], 1)[:, None]
22922289

22932290
# set nans to 0
2294-
transp[~ nx.isfinite(transp)] = 0
2291+
transp = nx.nan_to_num(transp, nan=0, posinf=0, neginf=0)
22952292

22962293
if self.log:
22972294
D1 = self.log_['D1'][i]
@@ -2339,7 +2336,7 @@ def inverse_transform_labels(self, yt=None):
23392336
transp = self.coupling_[i] / nx.sum(self.coupling_[i], 1)[:, None]
23402337

23412338
# set nans to 0
2342-
transp[~ nx.isfinite(transp)] = 0
2339+
transp = nx.nan_to_num(transp, nan=0, posinf=0, neginf=0)
23432340

23442341
# compute propagated labels
23452342
transp_ys.append(nx.dot(D1, transp.T).T)

test/test_da.py

Lines changed: 52 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -28,10 +28,9 @@
2828

2929

3030
def test_class_jax_tf():
31+
from ot.backend import tf
32+
3133
backends = []
32-
from ot.backend import jax, tf
33-
if jax:
34-
backends.append(ot.backend.JaxBackend())
3534
if tf:
3635
backends.append(ot.backend.TensorflowBackend())
3736

@@ -70,7 +69,6 @@ def test_log_da(nx, class_to_test):
7069
assert hasattr(otda, "log_")
7170

7271

73-
@pytest.skip_backend("jax")
7472
@pytest.skip_backend("tf")
7573
def test_sinkhorn_lpl1_transport_class(nx):
7674
"""test_sinkhorn_transport
@@ -79,10 +77,13 @@ def test_sinkhorn_lpl1_transport_class(nx):
7977
ns = 50
8078
nt = 50
8179

82-
Xs, ys = make_data_classif('3gauss', ns)
83-
Xt, yt = make_data_classif('3gauss2', nt)
80+
Xs, ys = make_data_classif('3gauss', ns, random_state=42)
81+
Xt, yt = make_data_classif('3gauss2', nt, random_state=43)
82+
# prepare semi-supervised labels
83+
yt_semi = np.copy(yt)
84+
yt_semi[np.arange(0, nt, 2)] = -1
8485

85-
Xs, ys, Xt, yt = nx.from_numpy(Xs, ys, Xt, yt)
86+
Xs, ys, Xt, yt, yt_semi = nx.from_numpy(Xs, ys, Xt, yt, yt_semi)
8687

8788
otda = ot.da.SinkhornLpl1Transport()
8889

@@ -109,7 +110,7 @@ def test_sinkhorn_lpl1_transport_class(nx):
109110
transp_Xs = otda.transform(Xs=Xs)
110111
assert_equal(transp_Xs.shape, Xs.shape)
111112

112-
Xs_new = nx.from_numpy(make_data_classif('3gauss', ns + 1)[0])
113+
Xs_new = nx.from_numpy(make_data_classif('3gauss', ns + 1, random_state=44)[0])
113114
transp_Xs_new = otda.transform(Xs_new)
114115

115116
# check that the oos method is working
@@ -119,7 +120,7 @@ def test_sinkhorn_lpl1_transport_class(nx):
119120
transp_Xt = otda.inverse_transform(Xt=Xt)
120121
assert_equal(transp_Xt.shape, Xt.shape)
121122

122-
Xt_new = nx.from_numpy(make_data_classif('3gauss2', nt + 1)[0])
123+
Xt_new = nx.from_numpy(make_data_classif('3gauss2', nt + 1, random_state=45)[0])
123124
transp_Xt_new = otda.inverse_transform(Xt=Xt_new)
124125

125126
# check that the oos method is working
@@ -142,10 +143,12 @@ def test_sinkhorn_lpl1_transport_class(nx):
142143
# test unsupervised vs semi-supervised mode
143144
otda_unsup = ot.da.SinkhornLpl1Transport()
144145
otda_unsup.fit(Xs=Xs, ys=ys, Xt=Xt)
146+
assert np.all(np.isfinite(nx.to_numpy(otda_unsup.coupling_))), "unsup coupling is finite"
145147
n_unsup = nx.sum(otda_unsup.cost_)
146148

147149
otda_semi = ot.da.SinkhornLpl1Transport()
148-
otda_semi.fit(Xs=Xs, ys=ys, Xt=Xt, yt=yt)
150+
otda_semi.fit(Xs=Xs, ys=ys, Xt=Xt, yt=yt_semi)
151+
assert np.all(np.isfinite(nx.to_numpy(otda_semi.coupling_))), "semi coupling is finite"
149152
assert_equal(otda_semi.cost_.shape, ((Xs.shape[0], Xt.shape[0])))
150153
n_semisup = nx.sum(otda_semi.cost_)
151154

@@ -944,3 +947,42 @@ def df2(G):
944947

945948
assert np.allclose(f(G), f2(G))
946949
assert np.allclose(df(G), df2(G))
950+
951+
952+
@pytest.skip_backend("jax")
953+
@pytest.skip_backend("tf")
954+
def test_sinkhorn_lpl1_vectorization(nx):
955+
n_samples, n_labels = 150, 3
956+
rng = np.random.RandomState(42)
957+
M = rng.rand(n_samples, n_samples)
958+
labels_a = rng.randint(n_labels, size=(n_samples,))
959+
M, labels_a = nx.from_numpy(M), nx.from_numpy(labels_a)
960+
961+
# hard-coded params from the original code
962+
p, epsilon = 0.5, 1e-3
963+
T = nx.from_numpy(rng.rand(n_samples, n_samples))
964+
965+
def unvectorized(transp):
966+
indices_labels = []
967+
classes = nx.unique(labels_a)
968+
for c in classes:
969+
idxc, = nx.where(labels_a == c)
970+
indices_labels.append(idxc)
971+
W = nx.ones(M.shape, type_as=M)
972+
for (i, c) in enumerate(classes):
973+
majs = nx.sum(transp[indices_labels[i]], axis=0)
974+
majs = p * ((majs + epsilon) ** (p - 1))
975+
W[indices_labels[i]] = majs
976+
return W
977+
978+
def vectorized(transp):
979+
labels_u, labels_idx = nx.unique(labels_a, return_inverse=True)
980+
n_labels = labels_u.shape[0]
981+
unroll_labels_idx = nx.eye(n_labels, type_as=transp)[labels_idx]
982+
W = nx.repeat(transp.T[:, :, None], n_labels, axis=2) * unroll_labels_idx[None, :, :]
983+
W = nx.sum(W, axis=1)
984+
W = p * ((W + epsilon) ** (p - 1))
985+
W = nx.dot(W, unroll_labels_idx.T)
986+
return W.T
987+
988+
assert np.allclose(unvectorized(T), vectorized(T))

0 commit comments

Comments (0)