From 3bea76cf99f2c0cf7656a7ea9423a2fa105dfd2e Mon Sep 17 00:00:00 2001 From: Oleksii Kachaiev Date: Fri, 22 Dec 2023 14:09:33 +0100 Subject: [PATCH 1/8] Draft implementation for per-class regularization in lpl1 --- ot/da.py | 2 +- test/test_da.py | 46 +++++++++++++++++++++++++++++++++++++++++++--- 2 files changed, 44 insertions(+), 4 deletions(-) diff --git a/ot/da.py b/ot/da.py index bb43623c4..4bf2c9e32 100644 --- a/ot/da.py +++ b/ot/da.py @@ -128,7 +128,7 @@ def sinkhorn_lpl1_mm(a, labels_a, b, M, reg, eta=0.1, numItermax=10, indices_labels.append(idxc) W = nx.zeros(M.shape, type_as=M) - for cpt in range(numItermax): + for _ in range(numItermax): Mreg = M + eta * W if log: transp, log = sinkhorn(a, b, Mreg, reg, numItermax=numInnerItermax, diff --git a/test/test_da.py b/test/test_da.py index 0ef5db79e..4a657b456 100644 --- a/test/test_da.py +++ b/test/test_da.py @@ -70,7 +70,6 @@ def test_log_da(nx, class_to_test): assert hasattr(otda, "log_") -@pytest.skip_backend("jax") @pytest.skip_backend("tf") def test_sinkhorn_lpl1_transport_class(nx): """test_sinkhorn_transport @@ -81,8 +80,11 @@ def test_sinkhorn_lpl1_transport_class(nx): Xs, ys = make_data_classif('3gauss', ns) Xt, yt = make_data_classif('3gauss2', nt) + # prepare semi-supervised labels + yt_semi = np.copy(yt) + yt_semi[np.arange(0, nt, 2)] = -1 - Xs, ys, Xt, yt = nx.from_numpy(Xs, ys, Xt, yt) + Xs, ys, Xt, yt, yt_semi = nx.from_numpy(Xs, ys, Xt, yt, yt_semi) otda = ot.da.SinkhornLpl1Transport() @@ -143,7 +145,7 @@ def test_sinkhorn_lpl1_transport_class(nx): n_unsup = nx.sum(otda_unsup.cost_) otda_semi = ot.da.SinkhornLpl1Transport() - otda_semi.fit(Xs=Xs, ys=ys, Xt=Xt, yt=yt) + otda_semi.fit(Xs=Xs, ys=ys, Xt=Xt, yt=yt_semi) assert_equal(otda_semi.cost_.shape, ((Xs.shape[0], Xt.shape[0]))) n_semisup = nx.sum(otda_semi.cost_) @@ -906,3 +908,41 @@ def df2(G): assert np.allclose(f(G), f2(G)) assert np.allclose(df(G), df2(G)) + + +@pytest.skip_backend("jax") +@pytest.skip_backend("tf") +def test_sinkhorn_lpl1_vectorization(nx): + n_samples, n_labels = 150, 3 + rng = np.random.RandomState(42) + M = rng.rand(n_samples, n_samples) + labels_a = rng.randint(n_labels, size=(n_samples,)) + M, labels_a = nx.from_numpy(M), nx.from_numpy(labels_a) + + # hard-coded params from the original code + p, epsilon = 0.5, 1e-3 + T = nx.from_numpy(rng.rand(n_samples, n_samples)) + + def unvectorized(transp): + indices_labels = [] + classes = nx.unique(labels_a) + for c in classes: + idxc, = nx.where(labels_a == c) + indices_labels.append(idxc) + W = nx.ones(M.shape, type_as=M) + for (i, c) in enumerate(classes): + majs = nx.sum(transp[indices_labels[i]], axis=0) + majs = p * ((majs + epsilon) ** (p - 1)) + W[indices_labels[i]] = majs + return W + + def vectorized(transp): + labels_u, labels_idx = nx.unique(labels_a, return_inverse=True) + n_labels = labels_u.shape[0] + unroll_labels_idx = nx.eye(n_labels, type_as=labels_u)[None, labels_idx] + W = nx.repeat(transp.T[:, :, None], n_labels, axis=2) * unroll_labels_idx + W = nx.sum(W, axis=2).T + W = p * ((W + epsilon) ** (p - 1)) + return W + + assert np.allclose(unvectorized(T), vectorized(T)) From cd75c27a83682d7dfd5889f1f8c0ddf0fd67243c Mon Sep 17 00:00:00 2001 From: Oleksii Kachaiev Date: Fri, 22 Dec 2023 14:12:06 +0100 Subject: [PATCH 2/8] Do not use assignment to replace non finite elements --- ot/da.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/ot/da.py b/ot/da.py index 4bf2c9e32..3d05e108f 100644 --- a/ot/da.py +++ b/ot/da.py @@ -1769,7 +1769,7 @@ def 
transform(self, Xs): transp = self.coupling_ / nx.sum(self.coupling_, 1)[:, None] # set nans to 0 - transp[~ nx.isfinite(transp)] = 0 + transp = nx.nan_to_num(transp, nan=0, posinf=0, neginf=0) # compute transported samples transp_Xs = nx.dot(transp, self.xt_) @@ -2058,7 +2058,7 @@ class label transp = coupling / nx.sum(coupling, 1)[:, None] # set nans to 0 - transp[~ nx.isfinite(transp)] = 0 + transp = nx.nan_to_num(transp, nan=0, posinf=0, neginf=0) # compute transported samples transp_Xs.append(nx.dot(transp, self.xt_)) @@ -2082,7 +2082,7 @@ class label # transport the source samples for coupling in self.coupling_: transp = coupling / nx.sum(coupling, 1)[:, None] - transp[~ nx.isfinite(transp)] = 0 + transp = nx.nan_to_num(transp, nan=0, posinf=0, neginf=0) transp_Xs_.append(nx.dot(transp, self.xt_)) transp_Xs_ = nx.concatenate(transp_Xs_, axis=0) @@ -2135,7 +2135,7 @@ def transform_labels(self, ys=None): transp = self.coupling_[i] / nx.sum(self.coupling_[i], 1)[:, None] # set nans to 0 - transp[~ nx.isfinite(transp)] = 0 + transp = nx.nan_to_num(transp, nan=0, posinf=0, neginf=0) if self.log: D1 = self.log_['D1'][i] @@ -2183,7 +2183,7 @@ def inverse_transform_labels(self, yt=None): transp = self.coupling_[i] / nx.sum(self.coupling_[i], 1)[:, None] # set nans to 0 - transp[~ nx.isfinite(transp)] = 0 + transp = nx.nan_to_num(transp, nan=0, posinf=0, neginf=0) # compute propagated labels transp_ys.append(nx.dot(D1, transp.T).T) From 2bdcdc9d7f42cf968f0b9655c6a56ed5ad4061bc Mon Sep 17 00:00:00 2001 From: Oleksii Kachaiev Date: Fri, 22 Dec 2023 15:14:26 +0100 Subject: [PATCH 3/8] Make vectorize version of lpl1 work --- test/test_da.py | 22 ++++++++++++++++++---- 1 file changed, 18 insertions(+), 4 deletions(-) diff --git a/test/test_da.py b/test/test_da.py index 4a657b456..b65760ccf 100644 --- a/test/test_da.py +++ b/test/test_da.py @@ -936,13 +936,27 @@ def unvectorized(transp): W[indices_labels[i]] = majs return W + def unvectorized_v2(transp): + indices_labels = [] + classes = nx.unique(labels_a) + for c in classes: + idxc, = nx.where(labels_a == c) + indices_labels.append(idxc) + W = nx.zeros(M.shape, type_as=M) + for (i, c) in enumerate(classes): + W[indices_labels[i]] = nx.sum(transp[indices_labels[i]], axis=0) + W = p * ((W + epsilon) ** (p - 1)) + return W-1 + + def vectorized(transp): labels_u, labels_idx = nx.unique(labels_a, return_inverse=True) n_labels = labels_u.shape[0] - unroll_labels_idx = nx.eye(n_labels, type_as=labels_u)[None, labels_idx] - W = nx.repeat(transp.T[:, :, None], n_labels, axis=2) * unroll_labels_idx - W = nx.sum(W, axis=2).T + unroll_labels_idx = nx.eye(n_labels, type_as=transp)[labels_idx] + W = nx.repeat(transp.T[:, :, None], n_labels, axis=2) * unroll_labels_idx[None, :, :] + W = nx.sum(W, axis=1) W = p * ((W + epsilon) ** (p - 1)) - return W + W = W @ unroll_labels_idx.T + return W.T assert np.allclose(unvectorized(T), vectorized(T)) From 0c412c456f8e4ee3d2216b5e900e5f4eed84d5d2 Mon Sep 17 00:00:00 2001 From: Oleksii Kachaiev Date: Fri, 22 Dec 2023 15:21:11 +0100 Subject: [PATCH 4/8] Proper lpl1 vectorization --- ot/da.py | 21 +++++++++------------ test/test_da.py | 13 ------------- 2 files changed, 9 insertions(+), 25 deletions(-) diff --git a/ot/da.py b/ot/da.py index 3d05e108f..434ca72b5 100644 --- a/ot/da.py +++ b/ot/da.py @@ -121,11 +121,9 @@ def sinkhorn_lpl1_mm(a, labels_a, b, M, reg, eta=0.1, numItermax=10, p = 0.5 epsilon = 1e-3 - indices_labels = [] - classes = nx.unique(labels_a) - for c in classes: - idxc, = nx.where(labels_a == c) - 
indices_labels.append(idxc) + labels_u, labels_idx = nx.unique(labels_a, return_inverse=True) + n_labels = labels_u.shape[0] + unroll_labels_idx = nx.eye(n_labels, type_as=M)[labels_idx] W = nx.zeros(M.shape, type_as=M) for _ in range(numItermax): @@ -136,13 +134,12 @@ def sinkhorn_lpl1_mm(a, labels_a, b, M, reg, eta=0.1, numItermax=10, else: transp = sinkhorn(a, b, Mreg, reg, numItermax=numInnerItermax, stopThr=stopInnerThr) - # the transport has been computed. Check if classes are really - # separated - W = nx.ones(M.shape, type_as=M) - for (i, c) in enumerate(classes): - majs = nx.sum(transp[indices_labels[i]], axis=0) - majs = p * ((majs + epsilon) ** (p - 1)) - W[indices_labels[i]] = majs + # the transport has been computed + # check if classes are really separated + W = nx.repeat(transp.T[:, :, None], n_labels, axis=2) * unroll_labels_idx[None, :, :] + W = nx.sum(W, axis=1) + W = W @ unroll_labels_idx.T + W = p * ((W.T + epsilon) ** (p - 1)) if log: return transp, log diff --git a/test/test_da.py b/test/test_da.py index b65760ccf..0a4389c55 100644 --- a/test/test_da.py +++ b/test/test_da.py @@ -936,19 +936,6 @@ def unvectorized(transp): W[indices_labels[i]] = majs return W - def unvectorized_v2(transp): - indices_labels = [] - classes = nx.unique(labels_a) - for c in classes: - idxc, = nx.where(labels_a == c) - indices_labels.append(idxc) - W = nx.zeros(M.shape, type_as=M) - for (i, c) in enumerate(classes): - W[indices_labels[i]] = nx.sum(transp[indices_labels[i]], axis=0) - W = p * ((W + epsilon) ** (p - 1)) - return W-1 - - def vectorized(transp): labels_u, labels_idx = nx.unique(labels_a, return_inverse=True) n_labels = labels_u.shape[0] From a5ea296d963da4aa38601b6643dc731338a335b2 Mon Sep 17 00:00:00 2001 From: Oleksii Kachaiev Date: Fri, 22 Dec 2023 15:22:10 +0100 Subject: [PATCH 5/8] Remove type error test for JAX (should work now) --- test/test_da.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/test/test_da.py b/test/test_da.py index 0a4389c55..ad11c6143 100644 --- a/test/test_da.py +++ b/test/test_da.py @@ -28,10 +28,9 @@ def test_class_jax_tf(): + from ot.backend import tf + backends = [] - from ot.backend import jax, tf - if jax: - backends.append(ot.backend.JaxBackend()) if tf: backends.append(ot.backend.TensorflowBackend()) From b5079f40a8260c14703e6c35814af0121c27b720 Mon Sep 17 00:00:00 2001 From: Oleksii Kachaiev Date: Fri, 22 Dec 2023 15:36:12 +0100 Subject: [PATCH 6/8] Update test, coupling still has nans --- test/test_da.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/test/test_da.py b/test/test_da.py index ad11c6143..a1dc93f11 100644 --- a/test/test_da.py +++ b/test/test_da.py @@ -77,8 +77,8 @@ def test_sinkhorn_lpl1_transport_class(nx): ns = 50 nt = 50 - Xs, ys = make_data_classif('3gauss', ns) - Xt, yt = make_data_classif('3gauss2', nt) + Xs, ys = make_data_classif('3gauss', ns, random_state=42) + Xt, yt = make_data_classif('3gauss2', nt, random_state=43) # prepare semi-supervised labels yt_semi = np.copy(yt) yt_semi[np.arange(0, nt, 2)] = -1 @@ -91,6 +91,8 @@ def test_sinkhorn_lpl1_transport_class(nx): otda.fit(Xs=Xs, ys=ys, Xt=Xt) assert hasattr(otda, "cost_") assert hasattr(otda, "coupling_") + assert np.all(np.isfinite(nx.to_numpy(otda.cost_))), "cost is finite" + assert np.all(np.isfinite(nx.to_numpy(otda.coupling_))), "coupling is finite" # test dimensions of coupling assert_equal(otda.cost_.shape, ((Xs.shape[0], Xt.shape[0]))) @@ -108,7 +110,7 @@ def test_sinkhorn_lpl1_transport_class(nx): 
transp_Xs = otda.transform(Xs=Xs) assert_equal(transp_Xs.shape, Xs.shape) - Xs_new = nx.from_numpy(make_data_classif('3gauss', ns + 1)[0]) + Xs_new = nx.from_numpy(make_data_classif('3gauss', ns + 1, random_state=44)[0]) transp_Xs_new = otda.transform(Xs_new) # check that the oos method is working @@ -118,7 +120,7 @@ def test_sinkhorn_lpl1_transport_class(nx): transp_Xt = otda.inverse_transform(Xt=Xt) assert_equal(transp_Xt.shape, Xt.shape) - Xt_new = nx.from_numpy(make_data_classif('3gauss2', nt + 1)[0]) + Xt_new = nx.from_numpy(make_data_classif('3gauss2', nt + 1, random_state=45)[0]) transp_Xt_new = otda.inverse_transform(Xt=Xt_new) # check that the oos method is working From ccca6ac1b881abf561d5a398179ddfd6a52a7aa7 Mon Sep 17 00:00:00 2001 From: Oleksii Kachaiev Date: Fri, 22 Dec 2023 15:40:58 +0100 Subject: [PATCH 7/8] Explicitly check for nans in the coupling return from sinkhorn --- test/test_da.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/test/test_da.py b/test/test_da.py index a1dc93f11..5e8161c71 100644 --- a/test/test_da.py +++ b/test/test_da.py @@ -143,10 +143,12 @@ def test_sinkhorn_lpl1_transport_class(nx): # test unsupervised vs semi-supervised mode otda_unsup = ot.da.SinkhornLpl1Transport() otda_unsup.fit(Xs=Xs, ys=ys, Xt=Xt) + assert np.all(np.isfinite(nx.to_numpy(otda_unsup.coupling_))), "unsup coupling is finite" n_unsup = nx.sum(otda_unsup.cost_) otda_semi = ot.da.SinkhornLpl1Transport() otda_semi.fit(Xs=Xs, ys=ys, Xt=Xt, yt=yt_semi) + assert np.all(np.isfinite(nx.to_numpy(otda_semi.coupling_))), "semi coupling is finite" assert_equal(otda_semi.cost_.shape, ((Xs.shape[0], Xt.shape[0]))) n_semisup = nx.sum(otda_semi.cost_) From ca005401215a8b605fd6ddc83a62c2e31fb48fc9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?C=C3=A9dric=20Vincent-Cuaz?= Date: Mon, 24 Jun 2024 16:28:39 +0200 Subject: [PATCH 8/8] fix small comments --- RELEASES.md | 1 + ot/da.py | 2 +- test/test_da.py | 2 +- 3 files changed, 3 insertions(+), 2 deletions(-) diff --git a/RELEASES.md b/RELEASES.md index ebfa07a06..55f5b0b17 100644 --- a/RELEASES.md +++ b/RELEASES.md @@ -21,6 +21,7 @@ - Split `test/test_gromov.py` into `test/gromov/` (PR #619) - Fix (F)GW barycenter functions to support computing barycenter on 1 input + deprecate structures as lists (PR #628) - Fix line-search in partial GW and change default init to the interior of partial transport plans (PR #602) +- Fix `ot.da.sinkhorn_lpl1_mm` compatibility with JAX (PR #592) ## 0.9.3 *January 2024* diff --git a/ot/da.py b/ot/da.py index b586735ee..b51b08b3a 100644 --- a/ot/da.py +++ b/ot/da.py @@ -139,7 +139,7 @@ def sinkhorn_lpl1_mm(a, labels_a, b, M, reg, eta=0.1, numItermax=10, # check if classes are really separated W = nx.repeat(transp.T[:, :, None], n_labels, axis=2) * unroll_labels_idx[None, :, :] W = nx.sum(W, axis=1) - W = W @ unroll_labels_idx.T + W = nx.dot(W, unroll_labels_idx.T) W = p * ((W.T + epsilon) ** (p - 1)) if log: diff --git a/test/test_da.py b/test/test_da.py index b28f7d48e..d3c343242 100644 --- a/test/test_da.py +++ b/test/test_da.py @@ -982,7 +982,7 @@ def vectorized(transp): W = nx.repeat(transp.T[:, :, None], n_labels, axis=2) * unroll_labels_idx[None, :, :] W = nx.sum(W, axis=1) W = p * ((W + epsilon) ** (p - 1)) - W = W @ unroll_labels_idx.T + W = nx.dot(W, unroll_labels_idx.T) return W.T assert np.allclose(unvectorized(T), vectorized(T))
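
The core change (PATCH 4) replaces the per-class Python loop in sinkhorn_lpl1_mm with a
one-hot label matrix, so the weight matrix W is built from reductions and matrix products
rather than item assignment (which JAX's immutable arrays do not allow). Below is a minimal
NumPy sketch of the same equivalence that test_sinkhorn_lpl1_vectorization checks; it is
illustrative only, is not part of the patches, and the names onehot, W_loop and W_vec are
made up for this example.

import numpy as np

rng = np.random.RandomState(0)
ns, nt, n_labels = 6, 5, 3
transp = rng.rand(ns, nt)                  # transport plan
labels_a = rng.randint(n_labels, size=ns)  # source class labels
p, epsilon = 0.5, 1e-3                     # hard-coded constants in sinkhorn_lpl1_mm

# loop version: one row block of W per class
W_loop = np.zeros((ns, nt))
for c in range(n_labels):
    idx = labels_a == c
    majs = transp[idx].sum(axis=0)         # per-class column sums of the plan
    W_loop[idx] = p * ((majs + epsilon) ** (p - 1))

# vectorized version: a one-hot label matrix turns the loop into two matrix products
onehot = np.eye(n_labels)[labels_a]        # shape (ns, n_labels)
class_sums = onehot.T @ transp             # shape (n_labels, nt)
W_vec = onehot @ (p * ((class_sums + epsilon) ** (p - 1)))

assert np.allclose(W_loop, W_vec)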