Skip to content

Fix DA cost correction when cost limit is set to Inf #593

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 6 commits into from
Jan 12, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions RELEASES.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,11 @@
# Releases

## 0.9.3

#### Closed issues
- Fixed an issue with cost correction for mismatched labels in `ot.da.BaseTransport` fit methods. This fix addresses the original issue introduced by PR #587 (PR #593)


## 0.9.2
*December 2023*

Expand Down
28 changes: 22 additions & 6 deletions ot/da.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
# License: MIT License

import numpy as np
import warnings

from .backend import get_backend
from .bregman import sinkhorn, jcpot_barycenter
Expand Down Expand Up @@ -499,12 +500,27 @@ class label
if self.limit_max != np.infty:
self.limit_max = self.limit_max * nx.max(self.cost_)

# zeros where source label is missing (masked with -1)
missing_labels = ys + nx.ones(ys.shape, type_as=ys)
missing_labels = nx.repeat(missing_labels[:, None], ys.shape[0], 1)
# zeros where labels match
label_match = ys[:, None] - yt[None, :]
self.cost_ = nx.maximum(self.cost_, nx.abs(label_match) * nx.abs(missing_labels) * self.limit_max)
# missing_labels is a (ns, nt) matrix of {0, 1} such that
# the cells (i, j) has 0 iff either ys[i] or yt[j] is masked
missing_ys = (ys == -1) + nx.zeros(ys.shape, type_as=ys)
missing_yt = (yt == -1) + nx.zeros(yt.shape, type_as=yt)
missing_labels = missing_ys[:, None] @ missing_yt[None, :]
# label_match is a (ns, nt) matrix of {True, False} such that
# the cell (i, j) is True iff ys[i] != yt[j]
label_match = (ys[:, None] - yt[None, :]) != 0
# cost_correction is a (ns, nt) matrix of {-Inf, float, Inf} such
# that the cell (i, j) has -Inf where no correction is necessary;
# by 'correction' we mean setting cost to a large value when
# labels do not match
# we suppress potential RuntimeWarning caused by Inf multiplication
# (as we explicitly cover potential NANs later)
with warnings.catch_warnings():
warnings.simplefilter('ignore', category=RuntimeWarning)
cost_correction = label_match * missing_labels * self.limit_max
# this operation is necessary because 0 * Inf = NaN;
# it has no effect when limit_max is finite
cost_correction = nx.nan_to_num(cost_correction, -np.infty)
self.cost_ = nx.maximum(self.cost_, cost_correction)

# distribution estimation
self.mu_s = self.distribution_estimation(Xs)
Expand Down
28 changes: 23 additions & 5 deletions test/test_da.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,9 @@ def test_sinkhorn_lpl1_transport_class(nx):
# test its computed
otda.fit(Xs=Xs, ys=ys, Xt=Xt)
assert hasattr(otda, "cost_")
assert not np.any(np.isnan(nx.to_numpy(otda.cost_))), "cost is finite"
assert hasattr(otda, "coupling_")
assert np.all(np.isfinite(nx.to_numpy(otda.coupling_))), "coupling is finite"

# test dimensions of coupling
assert_equal(otda.cost_.shape, ((Xs.shape[0], Xt.shape[0])))
Expand Down Expand Up @@ -148,7 +150,7 @@ def test_sinkhorn_lpl1_transport_class(nx):
n_semisup = nx.sum(otda_semi.cost_)

# check that the cost matrix norms are indeed different
assert n_unsup != n_semisup, "semisupervised mode not working"
assert np.allclose(n_unsup, n_semisup, atol=1e-7), "semisupervised mode is not working"

# check that the coupling forbids mass transport between labeled source
# and labeled target samples
Expand Down Expand Up @@ -238,7 +240,7 @@ def test_sinkhorn_l1l2_transport_class(nx):
n_semisup = nx.sum(otda_semi.cost_)

# check that the cost matrix norms are indeed different
assert n_unsup != n_semisup, "semisupervised mode not working"
assert np.allclose(n_unsup, n_semisup, atol=1e-7), "semisupervised mode is not working"

# check that the coupling forbids mass transport between labeled source
# and labeled target samples
Expand Down Expand Up @@ -331,7 +333,7 @@ def test_sinkhorn_transport_class(nx):
n_semisup = nx.sum(otda_semi.cost_)

# check that the cost matrix norms are indeed different
assert n_unsup != n_semisup, "semisupervised mode not working"
assert np.allclose(n_unsup, n_semisup, atol=1e-7), "semisupervised mode is not working"

# check that the coupling forbids mass transport between labeled source
# and labeled target samples
Expand Down Expand Up @@ -371,6 +373,10 @@ def test_unbalanced_sinkhorn_transport_class(nx):
# test dimensions of coupling
assert_equal(otda.cost_.shape, ((Xs.shape[0], Xt.shape[0])))
assert_equal(otda.coupling_.shape, ((Xs.shape[0], Xt.shape[0])))
assert not np.any(np.isnan(nx.to_numpy(otda.cost_))), "cost is finite"

# test coupling
assert np.all(np.isfinite(nx.to_numpy(otda.coupling_))), "coupling is finite"

# test transform
transp_Xs = otda.transform(Xs=Xs)
Expand Down Expand Up @@ -409,19 +415,22 @@ def test_unbalanced_sinkhorn_transport_class(nx):
# test unsupervised vs semi-supervised mode
otda_unsup = ot.da.SinkhornTransport()
otda_unsup.fit(Xs=Xs, Xt=Xt)
assert not np.any(np.isnan(nx.to_numpy(otda_unsup.cost_))), "cost is finite"
n_unsup = nx.sum(otda_unsup.cost_)

otda_semi = ot.da.SinkhornTransport()
otda_semi.fit(Xs=Xs, ys=ys, Xt=Xt, yt=yt)
assert not np.any(np.isnan(nx.to_numpy(otda_semi.cost_))), "cost is finite"
assert_equal(otda_semi.cost_.shape, ((Xs.shape[0], Xt.shape[0])))
n_semisup = nx.sum(otda_semi.cost_)

# check that the cost matrix norms are indeed different
assert n_unsup != n_semisup, "semisupervised mode not working"
assert np.allclose(n_unsup, n_semisup, atol=1e-7), "semisupervised mode is not working"

# check everything runs well with log=True
otda = ot.da.SinkhornTransport(log=True)
otda.fit(Xs=Xs, ys=ys, Xt=Xt)
assert not np.any(np.isnan(nx.to_numpy(otda.cost_))), "cost is finite"
assert len(otda.log_.keys()) != 0


Expand All @@ -448,7 +457,9 @@ def test_emd_transport_class(nx):

# test dimensions of coupling
assert_equal(otda.cost_.shape, ((Xs.shape[0], Xt.shape[0])))
assert not np.any(np.isnan(nx.to_numpy(otda.cost_))), "cost is finite"
assert_equal(otda.coupling_.shape, ((Xs.shape[0], Xt.shape[0])))
assert np.all(np.isfinite(nx.to_numpy(otda.coupling_))), "coupling is finite"

# test margin constraints
mu_s = unif(ns)
Expand Down Expand Up @@ -495,15 +506,22 @@ def test_emd_transport_class(nx):
# test unsupervised vs semi-supervised mode
otda_unsup = ot.da.EMDTransport()
otda_unsup.fit(Xs=Xs, ys=ys, Xt=Xt)
assert_equal(otda_unsup.cost_.shape, ((Xs.shape[0], Xt.shape[0])))
assert not np.any(np.isnan(nx.to_numpy(otda_unsup.cost_))), "cost is finite"
assert_equal(otda_unsup.coupling_.shape, ((Xs.shape[0], Xt.shape[0])))
assert np.all(np.isfinite(nx.to_numpy(otda_unsup.coupling_))), "coupling is finite"
n_unsup = nx.sum(otda_unsup.cost_)

otda_semi = ot.da.EMDTransport()
otda_semi.fit(Xs=Xs, ys=ys, Xt=Xt, yt=yt)
assert_equal(otda_semi.cost_.shape, ((Xs.shape[0], Xt.shape[0])))
assert not np.any(np.isnan(nx.to_numpy(otda_semi.cost_))), "cost is finite"
assert_equal(otda_semi.coupling_.shape, ((Xs.shape[0], Xt.shape[0])))
assert np.all(np.isfinite(nx.to_numpy(otda_semi.coupling_))), "coupling is finite"
n_semisup = nx.sum(otda_semi.cost_)

# check that the cost matrix norms are indeed different
assert n_unsup != n_semisup, "semisupervised mode not working"
assert np.allclose(n_unsup, n_semisup, atol=1e-7), "semisupervised mode is not working"

# check that the coupling forbids mass transport between labeled source
# and labeled target samples
Expand Down