[DA] Sinkhorn LpL1 transport to work on JAX #592

Merged
merged 19 commits, Jun 24, 2024
1 change: 1 addition & 0 deletions RELEASES.md
@@ -21,6 +21,7 @@
 - Split `test/test_gromov.py` into `test/gromov/` (PR #619)
 - Fix (F)GW barycenter functions to support computing barycenter on 1 input + deprecate structures as lists (PR #628)
 - Fix line-search in partial GW and change default init to the interior of partial transport plans (PR #602)
+- Fix `ot.da.sinkhorn_lpl1_mm` compatibility with JAX (PR #592)
 
 ## 0.9.3
 *January 2024*
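For context on the changelog entry above: JAX arrays are immutable, so the in-place updates used by the old implementation (`W[indices_labels[i]] = majs`, `transp[~ nx.isfinite(transp)] = 0`) raise a `TypeError` on that backend. A minimal sketch of the failure mode and the functional alternatives, assuming only `jax` is installed (values are illustrative):

```python
import jax.numpy as jnp

x = jnp.array([1.0, jnp.nan, jnp.inf])

# NumPy-style in-place masked assignment fails on JAX's immutable arrays:
#     x[~jnp.isfinite(x)] = 0   # TypeError: ... does not support item assignment

# Functional replacements that leave the input untouched:
x_clean = jnp.nan_to_num(x, nan=0.0, posinf=0.0, neginf=0.0)  # [1., 0., 0.]
x_set = x.at[~jnp.isfinite(x)].set(0.0)                       # JAX out-of-place update
```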
33 changes: 15 additions & 18 deletions ot/da.py
@@ -122,28 +122,25 @@ def sinkhorn_lpl1_mm(a, labels_a, b, M, reg, eta=0.1, numItermax=10,
     p = 0.5
     epsilon = 1e-3
 
-    indices_labels = []
-    classes = nx.unique(labels_a)
-    for c in classes:
-        idxc, = nx.where(labels_a == c)
-        indices_labels.append(idxc)
+    labels_u, labels_idx = nx.unique(labels_a, return_inverse=True)
+    n_labels = labels_u.shape[0]
+    unroll_labels_idx = nx.eye(n_labels, type_as=M)[labels_idx]
 
     W = nx.zeros(M.shape, type_as=M)
-    for cpt in range(numItermax):
+    for _ in range(numItermax):
         Mreg = M + eta * W
         if log:
             transp, log = sinkhorn(a, b, Mreg, reg, numItermax=numInnerItermax,
                                    stopThr=stopInnerThr, log=True)
         else:
             transp = sinkhorn(a, b, Mreg, reg, numItermax=numInnerItermax,
                               stopThr=stopInnerThr)
-        # the transport has been computed. Check if classes are really
-        # separated
-        W = nx.ones(M.shape, type_as=M)
-        for (i, c) in enumerate(classes):
-            majs = nx.sum(transp[indices_labels[i]], axis=0)
-            majs = p * ((majs + epsilon) ** (p - 1))
-            W[indices_labels[i]] = majs
+        # the transport has been computed
+        # check if classes are really separated
+        W = nx.repeat(transp.T[:, :, None], n_labels, axis=2) * unroll_labels_idx[None, :, :]
+        W = nx.sum(W, axis=1)
+        W = nx.dot(W, unroll_labels_idx.T)
+        W = p * ((W.T + epsilon) ** (p - 1))
 
     if log:
         return transp, log
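The vectorized block above removes the per-class Python loop by encoding labels as a one-hot matrix: `unroll_labels_idx` has shape `(ns, n_labels)`, the repeat-multiply-sum is the contraction `transp.T @ one_hot` (mass received by each target point from each source class), and the final `dot` broadcasts each class total back to the samples of that class. A minimal NumPy sketch of the equivalence (names and shapes here are illustrative, not the library code):

```python
import numpy as np

ns, nt, n_labels = 5, 4, 3
p, epsilon = 0.5, 1e-3
rng = np.random.RandomState(0)
transp = rng.rand(ns, nt)                  # transport plan
labels_idx = rng.randint(n_labels, size=ns)

one_hot = np.eye(n_labels)[labels_idx]     # (ns, n_labels), like unroll_labels_idx

# class_sums[j, c] = total mass target j receives from source class c
class_sums = transp.T @ one_hot            # (nt, n_labels)
# broadcast each class total back to every source sample of that class
W_vec = p * ((class_sums @ one_hot.T).T + epsilon) ** (p - 1)   # (ns, nt)

# reference: the original per-class loop
W_loop = np.ones((ns, nt))
for c in range(n_labels):
    idx = labels_idx == c
    majs = transp[idx].sum(axis=0)
    W_loop[idx] = p * ((majs + epsilon) ** (p - 1))

assert np.allclose(W_vec, W_loop)
```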
@@ -1925,7 +1922,7 @@ def transform(self, Xs):
             transp = self.coupling_ / nx.sum(self.coupling_, 1)[:, None]
 
             # set nans to 0
-            transp[~ nx.isfinite(transp)] = 0
+            transp = nx.nan_to_num(transp, nan=0, posinf=0, neginf=0)
 
             # compute transported samples
             transp_Xs = nx.dot(transp, self.xt_)
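The non-finite entries guarded here come from the barycentric normalization: a coupling row with zero total mass yields `0 / 0 = nan`, and the old masked assignment again breaks on JAX. `nan_to_num` performs the same cleanup out of place. A small NumPy illustration:

```python
import numpy as np

coupling = np.array([[0.2, 0.3],
                     [0.0, 0.0]])    # second source sample received no mass
with np.errstate(invalid='ignore'):
    transp = coupling / coupling.sum(axis=1)[:, None]   # second row is nan (0/0)
transp = np.nan_to_num(transp, nan=0, posinf=0, neginf=0)
# transp is now [[0.4, 0.6], [0.0, 0.0]]
```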
@@ -2214,7 +2211,7 @@ class label
                 transp = coupling / nx.sum(coupling, 1)[:, None]
 
                 # set nans to 0
-                transp[~ nx.isfinite(transp)] = 0
+                transp = nx.nan_to_num(transp, nan=0, posinf=0, neginf=0)
 
                 # compute transported samples
                 transp_Xs.append(nx.dot(transp, self.xt_))
@@ -2238,7 +2235,7 @@ class label
             # transport the source samples
             for coupling in self.coupling_:
                 transp = coupling / nx.sum(coupling, 1)[:, None]
-                transp[~ nx.isfinite(transp)] = 0
+                transp = nx.nan_to_num(transp, nan=0, posinf=0, neginf=0)
                 transp_Xs_.append(nx.dot(transp, self.xt_))
 
             transp_Xs_ = nx.concatenate(transp_Xs_, axis=0)
@@ -2291,7 +2288,7 @@ def transform_labels(self, ys=None):
             transp = self.coupling_[i] / nx.sum(self.coupling_[i], 1)[:, None]
 
             # set nans to 0
-            transp[~ nx.isfinite(transp)] = 0
+            transp = nx.nan_to_num(transp, nan=0, posinf=0, neginf=0)
 
             if self.log:
                 D1 = self.log_['D1'][i]
@@ -2339,7 +2336,7 @@ def inverse_transform_labels(self, yt=None):
             transp = self.coupling_[i] / nx.sum(self.coupling_[i], 1)[:, None]
 
             # set nans to 0
-            transp[~ nx.isfinite(transp)] = 0
+            transp = nx.nan_to_num(transp, nan=0, posinf=0, neginf=0)
 
             # compute propagated labels
             transp_ys.append(nx.dot(D1, transp.T).T)
62 changes: 52 additions & 10 deletions test/test_da.py
@@ -28,10 +28,9 @@
 
 
 def test_class_jax_tf():
-    from ot.backend import tf
-
     backends = []
+    from ot.backend import jax, tf
+    if jax:
+        backends.append(ot.backend.JaxBackend())
     if tf:
         backends.append(ot.backend.TensorflowBackend())

@@ -70,7 +69,6 @@ def test_log_da(nx, class_to_test):
     assert hasattr(otda, "log_")
 
 
-@pytest.skip_backend("jax")
 @pytest.skip_backend("tf")
 def test_sinkhorn_lpl1_transport_class(nx):
     """test_sinkhorn_transport
@@ -79,10 +77,13 @@ def test_sinkhorn_lpl1_transport_class(nx):
     ns = 50
     nt = 50
 
-    Xs, ys = make_data_classif('3gauss', ns)
-    Xt, yt = make_data_classif('3gauss2', nt)
+    Xs, ys = make_data_classif('3gauss', ns, random_state=42)
+    Xt, yt = make_data_classif('3gauss2', nt, random_state=43)
+    # prepare semi-supervised labels
+    yt_semi = np.copy(yt)
+    yt_semi[np.arange(0, nt, 2)] = -1
 
-    Xs, ys, Xt, yt = nx.from_numpy(Xs, ys, Xt, yt)
+    Xs, ys, Xt, yt, yt_semi = nx.from_numpy(Xs, ys, Xt, yt, yt_semi)
 
     otda = ot.da.SinkhornLpl1Transport()

@@ -109,7 +110,7 @@ def test_sinkhorn_lpl1_transport_class(nx):
     transp_Xs = otda.transform(Xs=Xs)
     assert_equal(transp_Xs.shape, Xs.shape)
 
-    Xs_new = nx.from_numpy(make_data_classif('3gauss', ns + 1)[0])
+    Xs_new = nx.from_numpy(make_data_classif('3gauss', ns + 1, random_state=44)[0])
     transp_Xs_new = otda.transform(Xs_new)
 
     # check that the oos method is working
@@ -119,7 +120,7 @@ def test_sinkhorn_lpl1_transport_class(nx):
     transp_Xt = otda.inverse_transform(Xt=Xt)
     assert_equal(transp_Xt.shape, Xt.shape)
 
-    Xt_new = nx.from_numpy(make_data_classif('3gauss2', nt + 1)[0])
+    Xt_new = nx.from_numpy(make_data_classif('3gauss2', nt + 1, random_state=45)[0])
     transp_Xt_new = otda.inverse_transform(Xt=Xt_new)
 
     # check that the oos method is working
@@ -142,10 +143,12 @@ def test_sinkhorn_lpl1_transport_class(nx):
     # test unsupervised vs semi-supervised mode
     otda_unsup = ot.da.SinkhornLpl1Transport()
     otda_unsup.fit(Xs=Xs, ys=ys, Xt=Xt)
+    assert np.all(np.isfinite(nx.to_numpy(otda_unsup.coupling_))), "unsup coupling is finite"
     n_unsup = nx.sum(otda_unsup.cost_)
 
     otda_semi = ot.da.SinkhornLpl1Transport()
-    otda_semi.fit(Xs=Xs, ys=ys, Xt=Xt, yt=yt)
+    otda_semi.fit(Xs=Xs, ys=ys, Xt=Xt, yt=yt_semi)
+    assert np.all(np.isfinite(nx.to_numpy(otda_semi.coupling_))), "semi coupling is finite"
     assert_equal(otda_semi.cost_.shape, ((Xs.shape[0], Xt.shape[0])))
     n_semisup = nx.sum(otda_semi.cost_)

@@ -944,3 +947,42 @@ def df2(G):
 
     assert np.allclose(f(G), f2(G))
     assert np.allclose(df(G), df2(G))


@pytest.skip_backend("jax")
@pytest.skip_backend("tf")
def test_sinkhorn_lpl1_vectorization(nx):
n_samples, n_labels = 150, 3
rng = np.random.RandomState(42)
M = rng.rand(n_samples, n_samples)
labels_a = rng.randint(n_labels, size=(n_samples,))
M, labels_a = nx.from_numpy(M), nx.from_numpy(labels_a)

# hard-coded params from the original code
p, epsilon = 0.5, 1e-3
T = nx.from_numpy(rng.rand(n_samples, n_samples))

def unvectorized(transp):
indices_labels = []
classes = nx.unique(labels_a)
for c in classes:
idxc, = nx.where(labels_a == c)
indices_labels.append(idxc)
W = nx.ones(M.shape, type_as=M)
for (i, c) in enumerate(classes):
majs = nx.sum(transp[indices_labels[i]], axis=0)
majs = p * ((majs + epsilon) ** (p - 1))
W[indices_labels[i]] = majs
return W

def vectorized(transp):
labels_u, labels_idx = nx.unique(labels_a, return_inverse=True)
n_labels = labels_u.shape[0]
unroll_labels_idx = nx.eye(n_labels, type_as=transp)[labels_idx]
W = nx.repeat(transp.T[:, :, None], n_labels, axis=2) * unroll_labels_idx[None, :, :]
W = nx.sum(W, axis=1)
W = p * ((W + epsilon) ** (p - 1))
W = nx.dot(W, unroll_labels_idx.T)
return W.T

assert np.allclose(unvectorized(T), vectorized(T))
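With the `@pytest.skip_backend("jax")` decorator removed from the transport-class test, `ot.da.sinkhorn_lpl1_mm` is expected to run directly on JAX inputs. A minimal end-to-end sketch, assuming POT with the JAX backend installed (the data is random and purely illustrative):

```python
import numpy as np
import jax.numpy as jnp
import ot

ns, nt, n_labels = 20, 20, 3
rng = np.random.RandomState(0)

a = jnp.ones(ns) / ns                          # uniform source weights
b = jnp.ones(nt) / nt                          # uniform target weights
labels_a = jnp.array(rng.randint(n_labels, size=ns))
M = jnp.array(rng.rand(ns, nt))                # cost matrix

# group-lasso regularized coupling, computed on the JAX backend
G = ot.da.sinkhorn_lpl1_mm(a, labels_a, b, M, reg=1.0)
assert bool(jnp.all(jnp.isfinite(G))) and G.shape == (ns, nt)
```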