Describe the bug
It seems that the `cost_correction` matrix is computed incorrectly. This is the current code that can be found here:
# labels_match is a (ns, nt) matrix of {True, False} such that
# the cells (i, j) has False if ys[i] != yt[i]
label_match = (ys[:, None] - yt[None, :]) != 0
# cost correction is a (ns, nt) matrix of {-Inf, float, Inf} such
# that he cells (i, j) has -Inf where there's no correction necessary
# by 'correction' we mean setting cost to a large value when
# labels do not match
# ...
with warnings.catch_warnings():
    warnings.simplefilter('ignore', category=RuntimeWarning)
    cost_correction = label_match * missing_labels * self.limit_max
The issues:
- First, the comment says that `label_match` is False if `ys[i] != yt[j]` (the comment writes `yt[i]`, but the column index must be `j`). However, if `ys[i] != yt[j]`, then `(ys[:, None] - yt[None, :]) != 0` is True, hence `label_match` is True although the labels do not match (the naming is confusing in this case). Therefore, either
  - the variable should be named `label_mismatch` and the comment should be fixed, OR
  - we check for equality, `label_match = (ys[:, None] - yt[None, :]) == 0`, and flip the value in `cost_correction`, i.e. `cost_correction = (1 - label_match) * ...`
- Second, `cost_correction = label_match * missing_labels * self.limit_max` applies a cost correction only if `missing_labels` is True. However, the correction must not be applied where `missing_labels` is True, since a missing label says nothing about whether the pair matches; hence, we need to flip it to `... * (1 - missing_labels) * ...`.
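A minimal NumPy sketch of both issues on toy data. The labels and the `limit_max` value are made up, and `missing_labels` is built under the assumption that unlabeled samples are encoded as -1 and that cell (i, j) counts as missing when either `ys[i]` or `yt[j]` is -1. It illustrates the description above rather than reproducing the exact POT source.

```python
import numpy as np

# Toy labels, all given (no -1 entries)
ys = np.array([1, 2])
yt = np.array([1, 2, 1])
limit_max = 10.0  # arbitrary finite value for the demo

# Issue 1: the expression used for `label_match` is True on a MISMATCH
label_match = (ys[:, None] - yt[None, :]) != 0
# label_match.tolist() == [[False, True, False], [True, False, True]]

# ASSUMPTION: missing labels are encoded as -1, and missing_labels[i, j]
# is True when ys[i] or yt[j] is missing (all False here)
missing_labels = (ys[:, None] == -1) | (yt[None, :] == -1)

# Issue 2: the current formula only corrects cells with a missing label
current = label_match * missing_labels * limit_max
# Proposed formula: correct mismatching cells where both labels are known
proposed = label_match * (1 - missing_labels) * limit_max

print(current.tolist())   # [[0.0, 0.0, 0.0], [0.0, 0.0, 0.0]]
print(proposed.tolist())  # [[0.0, 10.0, 0.0], [10.0, 0.0, 10.0]]
```

Since every label is given, the current formula leaves the cost untouched everywhere, while the proposed one penalizes exactly the mismatching pairs.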
Therefore, I'd propose the following change:
# label_mismatch is a (ns, nt) matrix of {True, False} such that
# the cell (i, j) is True if ys[i] != yt[j]
label_mismatch = (ys[:, None] - yt[None, :]) != 0
# cost correction is a (ns, nt) matrix of {-Inf, float, Inf} such
# that the cell (i, j) is -Inf where there's no correction necessary
# by 'correction' we mean setting cost to a large value when
# labels do not match
# ...
with warnings.catch_warnings():
    warnings.simplefilter('ignore', category=RuntimeWarning)
    cost_correction = label_mismatch * (1 - missing_labels) * self.limit_max
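The proposed change can be exercised end to end with a small hypothetical helper, `apply_label_correction`, which is not part of POT. Besides the -1 encoding for unlabeled samples, it assumes that the correction is merged into the cost with an element-wise maximum after the NaNs produced by `0 * inf` are mapped to -Inf, which matches how the code comment describes the {-Inf, float, Inf} matrix; the elided `# ...` part of the real code may differ.

```python
import warnings

import numpy as np


def apply_label_correction(cost, ys, yt, limit_max):
    """Hypothetical helper applying the proposed label-based cost correction.

    ASSUMPTIONS (not taken from the POT source): unlabeled samples are -1,
    and the correction is merged into `cost` with an element-wise maximum.
    """
    missing_labels = (ys[:, None] == -1) | (yt[None, :] == -1)
    label_mismatch = (ys[:, None] - yt[None, :]) != 0
    with warnings.catch_warnings():
        # 0 * inf yields NaN (with a RuntimeWarning) when limit_max is inf
        warnings.simplefilter('ignore', category=RuntimeWarning)
        cost_correction = label_mismatch * (1 - missing_labels) * limit_max
    # Cells needing no correction become -inf, so the maximum keeps the cost
    cost_correction = np.where(np.isnan(cost_correction), -np.inf, cost_correction)
    return np.maximum(cost, cost_correction)


cost = np.ones((2, 3))
ys = np.array([1, 2])
yt = np.array([1, 2, -1])  # the last target sample is unlabeled
print(apply_label_correction(cost, ys, yt, np.inf))
```

Here the pairs (0, 1) and (1, 0) have known, different labels and become `inf`, while same-label pairs and the column of the unlabeled target sample keep their original cost. For boolean arrays, `~missing_labels` would be the more idiomatic NumPy spelling of `1 - missing_labels`.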
Happy to send the corresponding PR if you agree.
Screenshots
The following screenshots show the effect of flipping the `missing_labels` value. Here we map samples across multiple Gaussian distributions with 2 labels (p = 1 and p = 2). All labels are given. Without the fix, the transport plans are not computed correctly. With the fix, only samples from the same target class are linked.
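The screenshot experiment can be approximated by a toy, self-contained reconstruction (not the script used for the screenshots): 1-D Gaussian samples for two classes, all labels given, target class positions swapped relative to the source, and an exact transport plan between uniform weights found by brute force over permutations instead of a POT solver. `missing_labels` again assumes the -1 encoding, and `limit_max = 1e6` is an arbitrary finite penalty.

```python
import itertools

import numpy as np


def exact_ot_uniform(cost):
    """Exact OT between two uniform distributions with the same size n.

    With weights 1/n on both sides, some optimal plan is a permutation
    matrix scaled by 1/n (Birkhoff-von Neumann), so brute force over
    permutations is exact. Only practical for tiny n.
    """
    n = cost.shape[0]
    best = min(itertools.permutations(range(n)),
               key=lambda p: sum(cost[i, p[i]] for i in range(n)))
    return np.asarray(best)  # best[i] is the target linked to source i


rng = np.random.default_rng(0)
k = 3  # samples per class
# Source: class 1 around 0, class 2 around 10
xs = np.concatenate([rng.normal(0, 0.1, k), rng.normal(10, 0.1, k)])
ys = np.array([1] * k + [2] * k)
# Target: same classes, but the class positions are swapped
xt = np.concatenate([rng.normal(10, 0.1, k), rng.normal(0, 0.1, k)])
yt = np.array([1] * k + [2] * k)

cost = (xs[:, None] - xt[None, :]) ** 2  # squared Euclidean cost
limit_max = 1e6  # arbitrary large finite penalty
label_mismatch = (ys[:, None] - yt[None, :]) != 0
# ASSUMPTION: -1 marks a missing label (none are missing here)
missing_labels = (ys[:, None] == -1) | (yt[None, :] == -1)

current = np.maximum(cost, label_mismatch * missing_labels * limit_max)
proposed = np.maximum(cost, label_mismatch * (1 - missing_labels) * limit_max)

links_current = exact_ot_uniform(current)
links_proposed = exact_ot_uniform(proposed)
print("same-class links (current): ", (ys == yt[links_current]).mean())
print("same-class links (proposed):", (ys == yt[links_proposed]).mean())
```

With the current formula no correction is applied because no label is missing, so every source sample is linked to the nearby sample of the other class; with the proposed formula all links stay within a class.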
Environment (please complete the following information):
Linux-4.18.0-372.75.1.el8_6.x86_64-x86_64-with-glibc2.28
Python 3.12.4 | packaged by conda-forge | (main, Jun 17 2024, 10:23:07) [GCC 12.3.0]
NumPy 2.0.0
SciPy 1.14.0
POT 0.9.4 (pip installed)