diff --git a/RELEASES.md b/RELEASES.md index 60d574439..488366ae3 100644 --- a/RELEASES.md +++ b/RELEASES.md @@ -1,5 +1,11 @@ # Releases +## 0.9.3 + +#### Closed issues +- Fixed an issue with cost correction for mismatched labels in `ot.da.BaseTransport` fit methods. This fix addresses the original issue introduced in PR #587 (PR #593) + + ## 0.9.2 *December 2023* diff --git a/ot/da.py b/ot/da.py index bb43623c4..4f3d3bb96 100644 --- a/ot/da.py +++ b/ot/da.py @@ -13,6 +13,7 @@ # License: MIT License import numpy as np +import warnings from .backend import get_backend from .bregman import sinkhorn, jcpot_barycenter @@ -499,12 +500,27 @@ class label if self.limit_max != np.infty: self.limit_max = self.limit_max * nx.max(self.cost_) - # zeros where source label is missing (masked with -1) - missing_labels = ys + nx.ones(ys.shape, type_as=ys) - missing_labels = nx.repeat(missing_labels[:, None], ys.shape[0], 1) - # zeros where labels match - label_match = ys[:, None] - yt[None, :] - self.cost_ = nx.maximum(self.cost_, nx.abs(label_match) * nx.abs(missing_labels) * self.limit_max) + # missing_labels is a (ns, nt) matrix of {0, 1} such that + # the cells (i, j) has 0 iff either ys[i] or yt[j] is masked + missing_ys = (ys == -1) + nx.zeros(ys.shape, type_as=ys) + missing_yt = (yt == -1) + nx.zeros(yt.shape, type_as=yt) + missing_labels = missing_ys[:, None] @ missing_yt[None, :] + # label_match is a (ns, nt) matrix of {True, False} such that + # the cells (i, j) has True iff ys[i] != yt[j] + label_match = (ys[:, None] - yt[None, :]) != 0 + # cost correction is a (ns, nt) matrix of {-Inf, float, Inf} such + # that the cells (i, j) has -Inf where there's no correction necessary + # by 'correction' we mean setting cost to a large value when + # labels do not match + # we suppress potential RuntimeWarning caused by Inf multiplication + # (as we explicitly cover potential NANs later) + with warnings.catch_warnings(): + warnings.simplefilter('ignore', category=RuntimeWarning) + 
cost_correction = label_match * missing_labels * self.limit_max + # this operation is necessary because 0 * Inf = NAN + # thus is irrelevant when limit_max is finite + cost_correction = nx.nan_to_num(cost_correction, -np.infty) + self.cost_ = nx.maximum(self.cost_, cost_correction) # distribution estimation self.mu_s = self.distribution_estimation(Xs) diff --git a/test/test_da.py b/test/test_da.py index 0ef5db79e..37b709473 100644 --- a/test/test_da.py +++ b/test/test_da.py @@ -89,7 +89,9 @@ def test_sinkhorn_lpl1_transport_class(nx): # test its computed otda.fit(Xs=Xs, ys=ys, Xt=Xt) assert hasattr(otda, "cost_") + assert not np.any(np.isnan(nx.to_numpy(otda.cost_))), "cost is finite" assert hasattr(otda, "coupling_") + assert np.all(np.isfinite(nx.to_numpy(otda.coupling_))), "coupling is finite" # test dimensions of coupling assert_equal(otda.cost_.shape, ((Xs.shape[0], Xt.shape[0]))) @@ -148,7 +150,7 @@ def test_sinkhorn_lpl1_transport_class(nx): n_semisup = nx.sum(otda_semi.cost_) # check that the cost matrix norms are indeed different - assert n_unsup != n_semisup, "semisupervised mode not working" + assert np.allclose(n_unsup, n_semisup, atol=1e-7), "semisupervised mode is not working" # check that the coupling forbids mass transport between labeled source # and labeled target samples @@ -238,7 +240,7 @@ def test_sinkhorn_l1l2_transport_class(nx): n_semisup = nx.sum(otda_semi.cost_) # check that the cost matrix norms are indeed different - assert n_unsup != n_semisup, "semisupervised mode not working" + assert np.allclose(n_unsup, n_semisup, atol=1e-7), "semisupervised mode is not working" # check that the coupling forbids mass transport between labeled source # and labeled target samples @@ -331,7 +333,7 @@ def test_sinkhorn_transport_class(nx): n_semisup = nx.sum(otda_semi.cost_) # check that the cost matrix norms are indeed different - assert n_unsup != n_semisup, "semisupervised mode not working" + assert np.allclose(n_unsup, n_semisup, atol=1e-7), 
"semisupervised mode is not working" # check that the coupling forbids mass transport between labeled source # and labeled target samples @@ -371,6 +373,10 @@ def test_unbalanced_sinkhorn_transport_class(nx): # test dimensions of coupling assert_equal(otda.cost_.shape, ((Xs.shape[0], Xt.shape[0]))) assert_equal(otda.coupling_.shape, ((Xs.shape[0], Xt.shape[0]))) + assert not np.any(np.isnan(nx.to_numpy(otda.cost_))), "cost is finite" + + # test coupling + assert np.all(np.isfinite(nx.to_numpy(otda.coupling_))), "coupling is finite" # test transform transp_Xs = otda.transform(Xs=Xs) @@ -409,19 +415,22 @@ def test_unbalanced_sinkhorn_transport_class(nx): # test unsupervised vs semi-supervised mode otda_unsup = ot.da.SinkhornTransport() otda_unsup.fit(Xs=Xs, Xt=Xt) + assert not np.any(np.isnan(nx.to_numpy(otda_unsup.cost_))), "cost is finite" n_unsup = nx.sum(otda_unsup.cost_) otda_semi = ot.da.SinkhornTransport() otda_semi.fit(Xs=Xs, ys=ys, Xt=Xt, yt=yt) + assert not np.any(np.isnan(nx.to_numpy(otda_semi.cost_))), "cost is finite" assert_equal(otda_semi.cost_.shape, ((Xs.shape[0], Xt.shape[0]))) n_semisup = nx.sum(otda_semi.cost_) # check that the cost matrix norms are indeed different - assert n_unsup != n_semisup, "semisupervised mode not working" + assert np.allclose(n_unsup, n_semisup, atol=1e-7), "semisupervised mode is not working" # check everything runs well with log=True otda = ot.da.SinkhornTransport(log=True) otda.fit(Xs=Xs, ys=ys, Xt=Xt) + assert not np.any(np.isnan(nx.to_numpy(otda.cost_))), "cost is finite" assert len(otda.log_.keys()) != 0 @@ -448,7 +457,9 @@ def test_emd_transport_class(nx): # test dimensions of coupling assert_equal(otda.cost_.shape, ((Xs.shape[0], Xt.shape[0]))) + assert not np.any(np.isnan(nx.to_numpy(otda.cost_))), "cost is finite" assert_equal(otda.coupling_.shape, ((Xs.shape[0], Xt.shape[0]))) + assert np.all(np.isfinite(nx.to_numpy(otda.coupling_))), "coupling is finite" # test margin constraints mu_s = unif(ns) @@ -495,15 
+506,22 @@ def test_emd_transport_class(nx): # test unsupervised vs semi-supervised mode otda_unsup = ot.da.EMDTransport() otda_unsup.fit(Xs=Xs, ys=ys, Xt=Xt) + assert_equal(otda_unsup.cost_.shape, ((Xs.shape[0], Xt.shape[0]))) + assert not np.any(np.isnan(nx.to_numpy(otda_unsup.cost_))), "cost is finite" + assert_equal(otda_unsup.coupling_.shape, ((Xs.shape[0], Xt.shape[0]))) + assert np.all(np.isfinite(nx.to_numpy(otda_unsup.coupling_))), "coupling is finite" n_unsup = nx.sum(otda_unsup.cost_) otda_semi = ot.da.EMDTransport() otda_semi.fit(Xs=Xs, ys=ys, Xt=Xt, yt=yt) assert_equal(otda_semi.cost_.shape, ((Xs.shape[0], Xt.shape[0]))) + assert not np.any(np.isnan(nx.to_numpy(otda_semi.cost_))), "cost is finite" + assert_equal(otda_semi.coupling_.shape, ((Xs.shape[0], Xt.shape[0]))) + assert np.all(np.isfinite(nx.to_numpy(otda_semi.coupling_))), "coupling is finite" n_semisup = nx.sum(otda_semi.cost_) # check that the cost matrix norms are indeed different - assert n_unsup != n_semisup, "semisupervised mode not working" + assert np.allclose(n_unsup, n_semisup, atol=1e-7), "semisupervised mode is not working" # check that the coupling forbids mass transport between labeled source # and labeled target samples