
Wasserstein Circle distance doesn't seem correct? #738

Open
@ckp95

Describe the bug

Forgive me if I've misunderstood what the wasserstein_circle function is supposed to do, but I would have thought that if we have

d1 = wasserstein_circle(arr1, arr2)

and then add the same amount to both arrays (i.e. rotate both samples by the same angle around the circle), we should get the same answer:

d2 = wasserstein_circle(arr1 + delta, arr2 + delta)
assert d1 == d2

But the first example I tried fails.
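For what it's worth, here is why I'd expect invariance, at least for p = 1. This sketch assumes the standard characterization of W1 on the circle [0, 1) cut open at 0, written with the CDFs F of the two measures; I haven't checked whether POT computes it exactly this way. Here mu_delta is mu rotated by delta in [0, 1), F(x^-) is the left limit of F at x, and the last step substitutes s = (t - delta) mod 1, which preserves length on [0, 1), with beta = alpha + c ranging over all reals.

```math
\begin{aligned}
W_1(\mu,\nu) &= \inf_{\alpha \in \mathbb{R}} \int_0^1 \bigl| F_\mu(t) - F_\nu(t) - \alpha \bigr| \, dt, \\
F_{\mu_\delta}(t) &= F_\mu\bigl((t-\delta) \bmod 1\bigr) - F_\mu\bigl((1-\delta)^-\bigr) + \mathbf{1}[t \ge \delta], \\
F_{\mu_\delta}(t) - F_{\nu_\delta}(t) &= h\bigl((t-\delta) \bmod 1\bigr) - c, \qquad h = F_\mu - F_\nu, \quad c = h\bigl((1-\delta)^-\bigr), \\
W_1(\mu_\delta,\nu_\delta) &= \inf_{\alpha} \int_0^1 \bigl| h\bigl((t-\delta) \bmod 1\bigr) - (\alpha + c) \bigr| \, dt = \inf_{\beta} \int_0^1 \bigl| h(s) - \beta \bigr| \, ds = W_1(\mu,\nu).
\end{aligned}
```

The indicator term is the same for both measures, so it cancels in the difference, and the leftover constant c is absorbed by the infimum.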

To Reproduce

import numpy as np
import ot

sample1 = np.array([0.1, 0.11, 0.4, 0.6])
sample2 = np.array([0.21, 0.15, 0.7, 0.95])

d1 = ot.wasserstein_circle(sample1, sample2)

delta = 0.02

d2 = ot.wasserstein_circle(sample1 + delta, sample2 + delta)

assert d1 == d2 # fails
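To have a value to compare d1 against, here is a small pure-Python reference for the p = 1 case, written independently of POT (so it is not POT's implementation). It assumes uniform weights, samples given as fractions of a turn in [0, 1), and the min-over-alpha formula for W1 on the circle; the name circular_w1 is just mine.

```python
from bisect import bisect_right


def circular_w1(xs, ys):
    """Reference W1 on the circle [0, 1) between two uniformly weighted samples.

    Uses W1 = min over alpha of the integral over [0, 1) of
    |F(t) - G(t) - alpha| dt, where F and G are the empirical CDFs of xs
    and ys with the circle cut at 0. This is a sketch, not POT's code.
    """
    xs = sorted(x % 1.0 for x in xs)  # wrap onto [0, 1)
    ys = sorted(y % 1.0 for y in ys)
    # F - G can only jump at sample points, so these knots (plus the cut at 0)
    # split [0, 1) into intervals on which F - G is constant.
    knots = sorted(set([0.0] + xs + ys)) + [1.0]
    heights, lengths = [], []
    for left, right in zip(knots[:-1], knots[1:]):
        f = bisect_right(xs, left) / len(xs)  # F(left) = share of xs <= left
        g = bisect_right(ys, left) / len(ys)
        heights.append(f - g)
        lengths.append(right - left)
    # The objective is convex and piecewise linear in alpha, so its minimum
    # is attained at one of the heights (a length-weighted median).
    return min(
        sum(w * abs(h - alpha) for h, w in zip(heights, lengths))
        for alpha in heights
    )


sample1 = [0.1, 0.11, 0.4, 0.6]
sample2 = [0.21, 0.15, 0.7, 0.95]
delta = 0.02

d1 = circular_w1(sample1, sample2)
d2 = circular_w1([x + delta for x in sample1], [y + delta for y in sample2])
print(d1, d2, abs(d1 - d2))
```

By hand, the optimal circular matching for these samples pairs 0.1 with 0.95, 0.11 with 0.15, 0.4 with 0.21 and 0.6 with 0.7, for a total arc length of 0.48 over 4 points, so both d1 and d2 should come out at about 0.12 here. Comparing that against POT's d1 on the same samples should show whether the difference is more than floating-point noise.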

Expected behavior

wasserstein_circle should be rotation-invariant, i.e. it should satisfy

d1 = wasserstein_circle(arr1, arr2)
d2 = wasserstein_circle((arr1 + delta) % 1, (arr2 + delta) % 1)
assert d1 == d2

for all real delta (up to floating-point error), because rotating both samples by the same angle amounts to just turning your head to the side.

Or am I just misunderstanding how the input is supposed to be represented here?
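Since the property only has to hold up to floating-point error, a check with a tolerance (and with the % 1 wrap-around from the snippet above) is probably fairer than ==. A minimal sketch; rotate and rotation_violations are hypothetical helpers of mine, not POT functions:

```python
import math


def rotate(values, delta):
    """Rotate points on the circle [0, 1) by delta, wrapping with % 1."""
    return [(v + delta) % 1.0 for v in values]


def rotation_violations(dist, a, b, deltas, rel_tol=1e-9, abs_tol=1e-12):
    """Return (delta, value) pairs where dist(a, b) changes under a common rotation.

    Uses math.isclose instead of ==, since the expected property only holds
    up to floating-point error.
    """
    reference = dist(a, b)
    violations = []
    for delta in deltas:
        value = dist(rotate(a, delta), rotate(b, delta))
        if not math.isclose(value, reference, rel_tol=rel_tol, abs_tol=abs_tol):
            violations.append((delta, value))
    return violations


# Rotations that include wrap-around past 1.0 and a negative angle.
DELTAS = [0.02, 0.25, 0.5, 0.77, -0.3]
```

To run it against POT I would pass something like `lambda a, b: float(np.squeeze(ot.wasserstein_circle(np.asarray(a), np.asarray(b))))` as `dist`; the `np.squeeze` is only there in case the function returns a one-element array for 1-D input, which I haven't checked. An empty list means no rotation changed the distance beyond the tolerance.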

Environment (please complete the following information):

  • OS (e.g. MacOS, Windows, Linux): NixOS 24.11
  • Python version: 3.12.7
  • How was POT installed (source, pip, conda): Nix

Output of the following code snippet:

import platform; print(platform.platform())
import sys; print("Python", sys.version)
import numpy; print("NumPy", numpy.__version__)
import scipy; print("SciPy", scipy.__version__)
import ot; print("POT", ot.__version__)

Linux-6.6.67-x86_64-with-glibc2.40
Python 3.12.7 (main, Oct  1 2024, 02:05:46) [GCC 13.3.0]
NumPy 1.26.4
SciPy 1.14.1
POT 0.9.4

Additional context
