Self distances with ot.bregman.empirical_sinkhorn2 higher than expected #667

Open
@Jabbath

Description

I am computing matrices of W_2 distances with ot.bregman.empirical_sinkhorn2 between point clouds centered along a curve. I expect the distance from a point cloud to itself to be zero or close to zero. However, the self-distances are nonzero and in fact larger than the distances to some neighboring point clouds. This seems like unexpected behavior, and I am wondering whether there is an underlying issue causing it. The code snippet below reproduces the problem. Any input would be greatly appreciated.

To Reproduce

import ot
import numpy as np
import matplotlib.pyplot as plt

T = 40 # Number of point clouds
N = 25 # Number of points in each cloud
max_x = 4*np.pi # Cosine is evaluated on [0, max_x]
variance = 0.4 # The variance of the normal at each point cloud
epsilon = 0.5

def generate_cos(T=100, N=50, max_x=2*np.pi, variance=1):
    """
    Generates a dataset which follows a cosine wave. Point clouds are 2-D gaussians which
    are centered at a point on the cosine wave.
    :param T: Number of point clouds to generate
    :param N: Number of samples to take at each timepoint
    :param max_x: The right end of the cosine wave
    :param variance: The variance of the gaussian distributions
    :return: The data matrix of shape (T, N, 2)
    """
    # Form the matrix [[x, cos(x)], ...]
    span = np.linspace(0, max_x, T, endpoint=True)
    y_vals = np.cos(span)
    means = np.zeros((T, 2))
    means[:, 0] = span
    means[:, 1] = y_vals

    # Sample a normal dist centered at each point
    x = np.zeros((T, N, 2))
    for i, mean in enumerate(means):
        dist = np.random.multivariate_normal(mean, np.eye(2)*variance, N)
        x[i, :, :] = dist

    return x

x = generate_cos(T=T, N=N, max_x=max_x, variance=variance)

dists = np.zeros(shape=(T, T))

for i in range(T):
    for j in range(i, T):
        d = np.sqrt(ot.bregman.empirical_sinkhorn2(x[i], x[j], epsilon, 
                                                   a=ot.unif(x[i].shape[0]), 
                                                   b=ot.unif(x[j].shape[0])))
        dists[i, j] = d

dists = dists + dists.T

plt.figure(figsize=(10, 10))
plt.matshow(dists[10:20, 10:20], fignum=1)
plt.title('$W_2$ distance matrix')
plt.colorbar()
plt.show()
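
One detail of the snippet above is worth checking: the inner loop starts at j = i, so the diagonal is already filled in the upper triangle, and dists + dists.T then adds every diagonal entry to itself. The following is a minimal NumPy sketch of that effect and of one way to symmetrize without doubling. The matrix values are made up for illustration and are not taken from the run reported here.

```python
import numpy as np

# Toy upper-triangular matrix (diagonal included), standing in for the
# output of the double loop. Values are illustrative, not from the issue.
upper = np.array([[1.0, 2.0, 3.0],
                  [0.0, 1.5, 2.5],
                  [0.0, 0.0, 1.2]])

# Symmetrizing as in the snippet counts each diagonal entry twice.
doubled = upper + upper.T

# Subtracting the diagonal once keeps each self-distance at its computed value.
symmetric = upper + upper.T - np.diag(np.diag(upper))

print(np.diag(doubled))    # twice the computed self-distances
print(np.diag(symmetric))  # the computed self-distances
```

If this is what happens here, the plotted diagonal would be twice the value returned by empirical_sinkhorn2, which could account for part of the gap to the neighboring entries.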

Screenshots

[Image: dist_mat]

Expected Behavior

I expect the diagonal to be close to zero.
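
A possible reason the diagonal is not zero, sketched under assumptions: with entropic regularization the transport plan is blurred, so even between a cloud and itself some mass is moved between distinct points and the transport cost is strictly positive. The sketch below uses a plain NumPy Sinkhorn loop rather than POT; entropic_cost and debiased_cost are hypothetical helpers written for illustration, and the assumption that the reported value is the transport cost of the entropic plan has not been checked against the POT source. POT also provides ot.bregman.empirical_sinkhorn_divergence, which subtracts the two self-terms and may be closer to the expected behavior; it is not called here.

```python
import numpy as np

def entropic_cost(x, y, reg, n_iter=500):
    """Transport cost <P, M> of the entropic plan between two uniform clouds.

    Hypothetical helper written for this sketch; it is not a POT function.
    """
    a = np.full(x.shape[0], 1.0 / x.shape[0])
    b = np.full(y.shape[0], 1.0 / y.shape[0])
    # Squared Euclidean cost matrix between all pairs of points.
    M = ((x[:, None, :] - y[None, :, :]) ** 2).sum(axis=-1)
    K = np.exp(-M / reg)  # Gibbs kernel
    u = np.ones_like(a)
    for _ in range(n_iter):  # plain Sinkhorn scaling iterations
        v = b / (K.T @ u)
        u = a / (K @ v)
    P = u[:, None] * K * v[None, :]  # entropic transport plan
    return float((P * M).sum())

def debiased_cost(x, y, reg):
    """Entropic cost with the two self-terms subtracted (hypothetical helper)."""
    return (entropic_cost(x, y, reg)
            - 0.5 * entropic_cost(x, x, reg)
            - 0.5 * entropic_cost(y, y, reg))

rng = np.random.default_rng(0)
# Same N and variance as in the reproduction snippet.
cloud = rng.normal(scale=np.sqrt(0.4), size=(25, 2))

self_cost = entropic_cost(cloud, cloud, reg=0.5)
self_debiased = debiased_cost(cloud, cloud, reg=0.5)

print(self_cost)      # strictly positive: the blurred plan moves mass off the diagonal
print(self_debiased)  # zero by construction when both clouds are identical
```

Lowering the regularization strength should shrink the self-cost, at the price of slower convergence and possible underflow in the kernel for very small values.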

Environment (please complete the following information):

  • OS: Linux
  • Python version: 3.11.5
  • How was POT installed (source, pip, conda): pip
  • POT version: 0.9.4

Output of code snippet:

Linux-4.18.0-513.5.1.el8_9.x86_64-x86_64-with-glibc2.28
Python 3.11.5 (main, Sep 22 2023, 15:34:29) [GCC 8.5.0 20210514 (Red Hat 8.5.0-20)]
NumPy 1.26.0
SciPy 1.11.3
[KeOps] Warning : Cuda libraries were not detected on the system or could not be loaded ; using cpu only mode
POT 0.9.4
