Incorrect computation of cost_correction matrix in ot.da.EMDTransport #664

Open
@martinrohbeck

Describe the bug

It seems that the cost_correction matrix in ot.da.EMDTransport is computed incorrectly. The current code reads:

# labels_match is a (ns, nt) matrix of {True, False} such that
# the cells (i, j) has False if ys[i] != yt[i]
label_match = (ys[:, None] - yt[None, :]) != 0 
# cost correction is a (ns, nt) matrix of {-Inf, float, Inf} such
# that he cells (i, j) has -Inf where there's no correction necessary
# by 'correction' we mean setting cost to a large value when
# labels do not match
# ...
with warnings.catch_warnings():
    warnings.simplefilter('ignore', category=RuntimeWarning)
    cost_correction = label_match * missing_labels * self.limit_max

The issues:

  • First, the comment says that label_match is False if ys[i] != yt[i] (presumably meaning yt[j]).
    However, if ys[i] != yt[j], then (ys[:, None] - yt[None, :]) != 0 evaluates to True, so label_match is True exactly where the labels do not match, which makes the name misleading. Therefore, either
    • the variable should be renamed to label_mismatch and the comment fixed, OR
    • we check for equality, label_match = (ys[:, None] - yt[None, :]) == 0, and flip the value in cost_correction, i.e. cost_correction = (1 - label_match) * ...
  • Second, cost_correction = label_match * missing_labels * self.limit_max applies a cost correction only where missing_labels is True. However, it must not correct where missing_labels is True, so the factor needs to be flipped to ... * (1 - missing_labels) * ...
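
To make the first point concrete, here is a minimal NumPy sketch. The toy labels are made up purely for illustration and are not taken from POT; only the expression assigned to label_match is copied from the code quoted above.

```python
import numpy as np

# Hypothetical toy labels, chosen only for illustration.
ys = np.array([0, 1, 1])
yt = np.array([0, 1])

# Expression used in the current code for `label_match`.
label_match = (ys[:, None] - yt[None, :]) != 0

# True marks the pairs (i, j) whose labels differ, i.e. a mismatch.
print(label_match)
```

On these labels, the entry for the pair (ys[0], yt[0]) is False even though both labels are 0, which is the opposite of what the name label_match suggests.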

Therefore, I'd propose the following change

# label_mismatch is a (ns, nt) matrix of {True, False} such that
# the cell (i, j) is True if ys[i] != yt[j]
label_mismatch = (ys[:, None] - yt[None, :]) != 0
# cost correction is a (ns, nt) matrix of {-Inf, float, Inf} such
# that the cell (i, j) has -Inf where there's no correction necessary
# by 'correction' we mean setting cost to a large value when
# labels do not match
# ...
with warnings.catch_warnings():
    warnings.simplefilter('ignore', category=RuntimeWarning)
    cost_correction = label_mismatch * (1 - missing_labels) * self.limit_max
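
The effect of flipping missing_labels can be checked on a toy example. This is only a sketch under stated assumptions: the labels are made up, -1 is assumed to mark a missing label, missing_labels is assumed to be 1 where either label of a pair is missing, and limit_max is a finite stand-in for self.limit_max so that no Inf or NaN arithmetic is involved. It does not reproduce the rest of the fitting code.

```python
import numpy as np

# Hypothetical toy labels; -1 is assumed here to mark a missing label.
ys = np.array([0, 1, 1, -1])
yt = np.array([0, 1, -1])
limit_max = 10.0  # finite stand-in for self.limit_max

# True where source and target labels differ.
label_mismatch = (ys[:, None] - yt[None, :]) != 0
# Assumed definition: 1 where either label of the pair is missing, else 0.
missing_labels = ((ys[:, None] == -1) | (yt[None, :] == -1)).astype(int)

# Expression from the current code.
current = label_mismatch * missing_labels * limit_max
# Expression from the proposed change.
proposed = label_mismatch * (1 - missing_labels) * limit_max

print(current)
print(proposed)
```

In this toy setup the current expression penalizes only pairs that involve a missing label, while the proposed one penalizes exactly the fully labeled pairs whose labels differ.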

Happy to send the corresponding PR if you agree.

Screenshots

The following screenshots show the effect of flipping the missing_labels value. Here we map samples across multiple Gaussian distributions with 2 labels (p = 1 and p = 2). All labels are given. Without the fix, the transport plans are not computed correctly. With the fix, only samples from the same target class are linked.

image

image

Environment (please complete the following information):

Linux-4.18.0-372.75.1.el8_6.x86_64-x86_64-with-glibc2.28
Python 3.12.4 | packaged by conda-forge | (main, Jun 17 2024, 10:23:07) [GCC 12.3.0]
NumPy 2.0.0
SciPy 1.14.0
POT 0.9.4 (pip installed)
