diff --git a/RELEASES.md b/RELEASES.md
index ebfa07a06..55f5b0b17 100644
--- a/RELEASES.md
+++ b/RELEASES.md
@@ -21,6 +21,7 @@
 - Split `test/test_gromov.py` into `test/gromov/` (PR #619)
 - Fix (F)GW barycenter functions to support computing barycenter on 1 input + deprecate structures as lists (PR #628)
 - Fix line-search in partial GW and change default init to the interior of partial transport plans (PR #602)
+- Fix `ot.da.sinkhorn_lpl1_mm` compatibility with JAX (PR #592)
 
 ## 0.9.3
 *January 2024*
diff --git a/ot/da.py b/ot/da.py
index d6c55b6c2..b51b08b3a 100644
--- a/ot/da.py
+++ b/ot/da.py
@@ -122,14 +122,12 @@ def sinkhorn_lpl1_mm(a, labels_a, b, M, reg, eta=0.1, numItermax=10,
     p = 0.5
     epsilon = 1e-3
 
-    indices_labels = []
-    classes = nx.unique(labels_a)
-    for c in classes:
-        idxc, = nx.where(labels_a == c)
-        indices_labels.append(idxc)
+    labels_u, labels_idx = nx.unique(labels_a, return_inverse=True)
+    n_labels = labels_u.shape[0]
+    unroll_labels_idx = nx.eye(n_labels, type_as=M)[labels_idx]
 
     W = nx.zeros(M.shape, type_as=M)
-    for cpt in range(numItermax):
+    for _ in range(numItermax):
         Mreg = M + eta * W
         if log:
             transp, log = sinkhorn(a, b, Mreg, reg, numItermax=numInnerItermax,
@@ -137,13 +135,12 @@ def sinkhorn_lpl1_mm(a, labels_a, b, M, reg, eta=0.1, numItermax=10,
         else:
             transp = sinkhorn(a, b, Mreg, reg, numItermax=numInnerItermax,
                               stopThr=stopInnerThr)
-        # the transport has been computed. Check if classes are really
-        # separated
-        W = nx.ones(M.shape, type_as=M)
-        for (i, c) in enumerate(classes):
-            majs = nx.sum(transp[indices_labels[i]], axis=0)
-            majs = p * ((majs + epsilon) ** (p - 1))
-            W[indices_labels[i]] = majs
+        # the transport has been computed
+        # check if classes are really separated
+        # per-class mass reaching each target sample: transp.T @ one-hot labels
+        W = nx.dot(transp.T, unroll_labels_idx)
+        W = nx.dot(W, unroll_labels_idx.T)
+        W = p * ((W.T + epsilon) ** (p - 1))
 
     if log:
         return transp, log
@@ -1925,7 +1922,7 @@ def transform(self, Xs):
                 transp = self.coupling_ / nx.sum(self.coupling_, 1)[:, None]
 
                 # set nans to 0
-                transp[~ nx.isfinite(transp)] = 0
+                transp = nx.nan_to_num(transp, nan=0, posinf=0, neginf=0)
 
                 # compute transported samples
                 transp_Xs = nx.dot(transp, self.xt_)
@@ -2214,7 +2211,7 @@ class label
                     transp = coupling / nx.sum(coupling, 1)[:, None]
 
                     # set nans to 0
-                    transp[~ nx.isfinite(transp)] = 0
+                    transp = nx.nan_to_num(transp, nan=0, posinf=0, neginf=0)
 
                     # compute transported samples
                     transp_Xs.append(nx.dot(transp, self.xt_))
@@ -2238,7 +2235,7 @@ class label
                     # transport the source samples
                     for coupling in self.coupling_:
                         transp = coupling / nx.sum(coupling, 1)[:, None]
-                        transp[~ nx.isfinite(transp)] = 0
+                        transp = nx.nan_to_num(transp, nan=0, posinf=0, neginf=0)
                         transp_Xs_.append(nx.dot(transp, self.xt_))
 
                     transp_Xs_ = nx.concatenate(transp_Xs_, axis=0)
@@ -2291,7 +2288,7 @@ def transform_labels(self, ys=None):
                 transp = self.coupling_[i] / nx.sum(self.coupling_[i], 1)[:, None]
 
                 # set nans to 0
-                transp[~ nx.isfinite(transp)] = 0
+                transp = nx.nan_to_num(transp, nan=0, posinf=0, neginf=0)
 
                 if self.log:
                     D1 = self.log_['D1'][i]
@@ -2339,7 +2336,7 @@ def inverse_transform_labels(self, yt=None):
                 transp = self.coupling_[i] / nx.sum(self.coupling_[i], 1)[:, None]
 
                 # set nans to 0
-                transp[~ nx.isfinite(transp)] = 0
+                transp = nx.nan_to_num(transp, nan=0, posinf=0, neginf=0)
 
                 # compute propagated labels
                 transp_ys.append(nx.dot(D1, transp.T).T)
diff --git a/test/test_da.py b/test/test_da.py
index 0e51bda22..d3c343242 100644
--- a/test/test_da.py
+++ b/test/test_da.py
@@ -28,10 +28,9 @@
 
 
 def test_class_jax_tf():
+    from ot.backend import tf
+
     backends = []
-    from ot.backend import jax, tf
-    if jax:
-        backends.append(ot.backend.JaxBackend())
     if tf:
         backends.append(ot.backend.TensorflowBackend())
 
@@ -70,7 +69,6 @@ def test_log_da(nx, class_to_test):
     assert hasattr(otda, "log_")
 
 
-@pytest.skip_backend("jax")
 @pytest.skip_backend("tf")
 def test_sinkhorn_lpl1_transport_class(nx):
     """test_sinkhorn_transport
@@ -79,10 +77,13 @@ def test_sinkhorn_lpl1_transport_class(nx):
     ns = 50
     nt = 50
 
-    Xs, ys = make_data_classif('3gauss', ns)
-    Xt, yt = make_data_classif('3gauss2', nt)
+    Xs, ys = make_data_classif('3gauss', ns, random_state=42)
+    Xt, yt = make_data_classif('3gauss2', nt, random_state=43)
+    # prepare semi-supervised labels
+    yt_semi = np.copy(yt)
+    yt_semi[np.arange(0, nt, 2)] = -1
 
-    Xs, ys, Xt, yt = nx.from_numpy(Xs, ys, Xt, yt)
+    Xs, ys, Xt, yt, yt_semi = nx.from_numpy(Xs, ys, Xt, yt, yt_semi)
 
     otda = ot.da.SinkhornLpl1Transport()
 
@@ -109,7 +110,7 @@ def test_sinkhorn_lpl1_transport_class(nx):
     transp_Xs = otda.transform(Xs=Xs)
     assert_equal(transp_Xs.shape, Xs.shape)
 
-    Xs_new = nx.from_numpy(make_data_classif('3gauss', ns + 1)[0])
+    Xs_new = nx.from_numpy(make_data_classif('3gauss', ns + 1, random_state=44)[0])
     transp_Xs_new = otda.transform(Xs_new)
 
     # check that the oos method is working
@@ -119,7 +120,7 @@ def test_sinkhorn_lpl1_transport_class(nx):
     transp_Xt = otda.inverse_transform(Xt=Xt)
     assert_equal(transp_Xt.shape, Xt.shape)
 
-    Xt_new = nx.from_numpy(make_data_classif('3gauss2', nt + 1)[0])
+    Xt_new = nx.from_numpy(make_data_classif('3gauss2', nt + 1, random_state=45)[0])
     transp_Xt_new = otda.inverse_transform(Xt=Xt_new)
 
     # check that the oos method is working
@@ -142,10 +143,12 @@ def test_sinkhorn_lpl1_transport_class(nx):
     # test unsupervised vs semi-supervised mode
     otda_unsup = ot.da.SinkhornLpl1Transport()
     otda_unsup.fit(Xs=Xs, ys=ys, Xt=Xt)
+    assert np.all(np.isfinite(nx.to_numpy(otda_unsup.coupling_))), "unsup coupling is not finite"
     n_unsup = nx.sum(otda_unsup.cost_)
 
     otda_semi = ot.da.SinkhornLpl1Transport()
-    otda_semi.fit(Xs=Xs, ys=ys, Xt=Xt, yt=yt)
+    otda_semi.fit(Xs=Xs, ys=ys, Xt=Xt, yt=yt_semi)
+    assert np.all(np.isfinite(nx.to_numpy(otda_semi.coupling_))), "semi coupling is not finite"
     assert_equal(otda_semi.cost_.shape, ((Xs.shape[0], Xt.shape[0])))
     n_semisup = nx.sum(otda_semi.cost_)
 
@@ -944,3 +947,42 @@ def df2(G):
 
     assert np.allclose(f(G), f2(G))
     assert np.allclose(df(G), df2(G))
+
+
+@pytest.skip_backend("jax")
+@pytest.skip_backend("tf")
+def test_sinkhorn_lpl1_vectorization(nx):
+    n_samples, n_labels = 150, 3
+    rng = np.random.RandomState(42)
+    M = rng.rand(n_samples, n_samples)
+    labels_a = rng.randint(n_labels, size=(n_samples,))
+    M, labels_a = nx.from_numpy(M), nx.from_numpy(labels_a)
+
+    # hard-coded params from the original code
+    p, epsilon = 0.5, 1e-3
+    T = nx.from_numpy(rng.rand(n_samples, n_samples))
+
+    def unvectorized(transp):
+        indices_labels = []
+        classes = nx.unique(labels_a)
+        for c in classes:
+            idxc, = nx.where(labels_a == c)
+            indices_labels.append(idxc)
+        W = nx.ones(M.shape, type_as=M)
+        for (i, c) in enumerate(classes):
+            majs = nx.sum(transp[indices_labels[i]], axis=0)
+            majs = p * ((majs + epsilon) ** (p - 1))
+            W[indices_labels[i]] = majs
+        return W
+
+    def vectorized(transp):
+        labels_u, labels_idx = nx.unique(labels_a, return_inverse=True)
+        n_labels = labels_u.shape[0]
+        unroll_labels_idx = nx.eye(n_labels, type_as=transp)[labels_idx]
+        # same computation as the loop body of ot.da.sinkhorn_lpl1_mm
+        W = nx.dot(transp.T, unroll_labels_idx)
+        W = nx.dot(W, unroll_labels_idx.T)
+        W = p * ((W.T + epsilon) ** (p - 1))
+        return W
+
+    assert np.allclose(nx.to_numpy(unvectorized(T)), nx.to_numpy(vectorized(T)))