diff --git a/CONTRIBUTORS.md b/CONTRIBUTORS.md
index 6f6a72737..e038b49a1 100644
--- a/CONTRIBUTORS.md
+++ b/CONTRIBUTORS.md
@@ -2,12 +2,10 @@
## Creators and Maintainers
-This toolbox has been created by
+This toolbox has been created by [Rémi Flamary](https://remi.flamary.com/)
+and [Nicolas Courty](http://people.irisa.fr/Nicolas.Courty/).
-* [Rémi Flamary](https://remi.flamary.com/)
-* [Nicolas Courty](http://people.irisa.fr/Nicolas.Courty/)
-
-It is currently maintained by
+It is currently maintained by:
* [Rémi Flamary](https://remi.flamary.com/)
* [Cédric Vincent-Cuaz](https://cedricvincentcuaz.github.io/)
diff --git a/README.md b/README.md
index cbc1c9e8d..8b4cca7f7 100644
--- a/README.md
+++ b/README.md
@@ -202,12 +202,9 @@ The examples folder contain several examples and use case for the library. The f
## Acknowledgements
-This toolbox has been created by
+This toolbox has been created by [Rémi Flamary](https://remi.flamary.com/) and [Nicolas Courty](http://people.irisa.fr/Nicolas.Courty/).
-* [Rémi Flamary](https://remi.flamary.com/)
-* [Nicolas Courty](http://people.irisa.fr/Nicolas.Courty/)
-
-It is currently maintained by
+It is currently maintained by:
* [Rémi Flamary](https://remi.flamary.com/)
* [Cédric Vincent-Cuaz](https://cedricvincentcuaz.github.io/)
@@ -218,8 +215,6 @@ POT has benefited from the financing or manpower from the following partners:



-
-
## Contributions and code of conduct
Every contribution is welcome and should respect the [contribution guidelines](https://pythonot.github.io/master/contributing.html). Each member of the project is expected to follow the [code of conduct](https://pythonot.github.io/master/code_of_conduct.html).
diff --git a/RELEASES.md b/RELEASES.md
index 357f9cd5a..3fb6c1f14 100644
--- a/RELEASES.md
+++ b/RELEASES.md
@@ -14,6 +14,7 @@
- Added `ot.gaussian.bures_barycenter_gradient_descent` (PR #680)
- Added `ot.gaussian.bures_wasserstein_distance` (PR #680)
- `ot.gaussian.bures_wasserstein_distance` can be batched (PR #680)
+- Backend implementation of `ot.dist` for the 'cityblock', 'minkowski', 'cosine' and 'correlation' metrics (PR #701)
#### Closed issues
- Fixed `ot.mapping` solvers which depended on deprecated `cvxpy` `ECOS` solver (PR #692, Issue #668)
diff --git a/ot/utils.py b/ot/utils.py
index 431226910..1f24fa33f 100644
--- a/ot/utils.py
+++ b/ot/utils.py
@@ -17,25 +17,25 @@
from inspect import signature
from .backend import get_backend, Backend, NumpyBackend, JaxBackend
-__time_tic_toc = time.time()
+__time_tic_toc = time.perf_counter()
def tic():
r"""Python implementation of Matlab tic() function"""
global __time_tic_toc
- __time_tic_toc = time.time()
+ __time_tic_toc = time.perf_counter()
def toc(message="Elapsed time : {} s"):
r"""Python implementation of Matlab toc() function"""
- t = time.time()
+ t = time.perf_counter()
print(message.format(t - __time_tic_toc))
return t - __time_tic_toc
def toq():
r"""Python implementation of Julia toc() function"""
- t = time.time()
+ t = time.perf_counter()
return t - __time_tic_toc
@@ -251,7 +251,7 @@ def clean_zeros(a, b, M):
return a2, b2, M2
-def euclidean_distances(X, Y, squared=False):
+def euclidean_distances(X, Y, squared=False, nx=None):
r"""
Considering the rows of :math:`\mathbf{X}` (and :math:`\mathbf{Y} = \mathbf{X}`) as vectors, compute the
distance matrix between each pair of vectors.
@@ -270,13 +270,13 @@ def euclidean_distances(X, Y, squared=False):
-------
distances : array-like, shape (`n_samples_1`, `n_samples_2`)
"""
-
- nx = get_backend(X, Y)
+ if nx is None:
+ nx = get_backend(X, Y)
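+    # pairwise squared distances via ||x - y||^2 = ||x||^2 + ||y||^2 - 2 <x, y>,
+    # computed term by term to avoid an (n_samples_1, n_samples_2, n_features) tensor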
a2 = nx.einsum("ij,ij->i", X, X)
b2 = nx.einsum("ij,ij->i", Y, Y)
- c = -2 * nx.dot(X, Y.T)
+ c = -2 * nx.dot(X, nx.transpose(Y))
c += a2[:, None]
c += b2[None, :]
@@ -291,11 +291,21 @@ def euclidean_distances(X, Y, squared=False):
return c
-def dist(x1, x2=None, metric="sqeuclidean", p=2, w=None):
+def dist(
+ x1,
+ x2=None,
+ metric="sqeuclidean",
+ p=2,
+ w=None,
+ backend="auto",
+ nx=None,
+ use_tensor=False,
+):
r"""Compute distance between samples in :math:`\mathbf{x_1}` and :math:`\mathbf{x_2}`
.. note:: This function is backend-compatible and will work on arrays
- from all compatible backends.
+ from all compatible backends for the following metrics:
+ 'sqeuclidean', 'euclidean', 'cityblock', 'minkowski', 'cosine', 'correlation'.
Parameters
----------
@@ -315,7 +325,17 @@ def dist(x1, x2=None, metric="sqeuclidean", p=2, w=None):
p-norm for the Minkowski and the Weighted Minkowski metrics. Default value is 2.
w : array-like, rank 1
Weights for the weighted metrics.
-
+    backend : str, optional
+        Backend to use for the computation. If 'auto', the backend is
+        automatically selected based on the input data. If 'scipy',
+        the ``scipy.spatial.distance.cdist`` function is used (and gradients are
+        detached).
+    nx : Backend, optional
+        Backend to perform computations on. If omitted, the backend is inferred
+        from `x1` and `x2`.
+    use_tensor : bool, optional
+        If True, use tensorized computation for the distance matrix, which can
+        cause memory issues for large datasets. Default is False; the parameter
+        is only used for the 'cityblock' and 'minkowski' metrics.
+
Returns
-------
@@ -324,12 +344,69 @@ def dist(x1, x2=None, metric="sqeuclidean", p=2, w=None):
distance matrix computed with given metric
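+
+    Examples
+    --------
+    A minimal example with the 'cityblock' metric on numpy arrays:
+
+    >>> import numpy as np
+    >>> x = np.array([[0.0, 0.0], [1.0, 0.0]])
+    >>> dist(x, metric="cityblock")
+    array([[0., 1.],
+           [1., 0.]])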
"""
+ if nx is None:
+ nx = get_backend(x1, x2)
if x2 is None:
x2 = x1
- if metric == "sqeuclidean":
- return euclidean_distances(x1, x2, squared=True)
+ if backend == "scipy": # force scipy backend with cdist function
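+        # converting to numpy leaves the original backend, so gradients are detached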
+ x1 = nx.to_numpy(x1)
+ x2 = nx.to_numpy(x2)
+ if isinstance(metric, str) and metric.endswith("minkowski"):
+ return nx.from_numpy(cdist(x1, x2, metric=metric, p=p, w=w))
+ if w is not None:
+ return nx.from_numpy(cdist(x1, x2, metric=metric, w=w))
+ return nx.from_numpy(cdist(x1, x2, metric=metric))
+ elif metric == "sqeuclidean":
+ return euclidean_distances(x1, x2, squared=True, nx=nx)
elif metric == "euclidean":
- return euclidean_distances(x1, x2, squared=False)
+ return euclidean_distances(x1, x2, squared=False, nx=nx)
+ elif metric == "cityblock":
+ if use_tensor:
+ return nx.sum(nx.abs(x1[:, None, :] - x2[None, :, :]), axis=2)
+ else:
+ M = 0.0
+ for i in range(x1.shape[1]):
+ M += nx.abs(x1[:, i][:, None] - x2[:, i][None, :])
+ return M
+ elif metric == "minkowski":
+ if w is None:
+ if use_tensor:
+ return nx.power(
+ nx.sum(
+ nx.power(nx.abs(x1[:, None, :] - x2[None, :, :]), p), axis=2
+ ),
+ 1 / p,
+ )
+ else:
+ M = 0.0
+ for i in range(x1.shape[1]):
+ M += nx.abs(x1[:, i][:, None] - x2[:, i][None, :]) ** p
+ return M ** (1 / p)
+ else:
+ if use_tensor:
+ return nx.power(
+ nx.sum(
+ w[None, None, :]
+ * nx.power(nx.abs(x1[:, None, :] - x2[None, :, :]), p),
+ axis=2,
+ ),
+ 1 / p,
+ )
+ else:
+ M = 0.0
+ for i in range(x1.shape[1]):
+ M += w[i] * nx.abs(x1[:, i][:, None] - x2[:, i][None, :]) ** p
+ return M ** (1 / p)
+ elif metric == "cosine":
+ nx1 = nx.sqrt(nx.einsum("ij,ij->i", x1, x1))
+ nx2 = nx.sqrt(nx.einsum("ij,ij->i", x2, x2))
+ return 1.0 - (nx.dot(x1, nx.transpose(x2)) / nx1[:, None] / nx2[None, :])
+ elif metric == "correlation":
+ x1 = x1 - nx.mean(x1, axis=1)[:, None]
+ x2 = x2 - nx.mean(x2, axis=1)[:, None]
+ nx1 = nx.sqrt(nx.einsum("ij,ij->i", x1, x1))
+ nx2 = nx.sqrt(nx.einsum("ij,ij->i", x2, x2))
+ return 1.0 - (nx.dot(x1, nx.transpose(x2)) / nx1[:, None] / nx2[None, :])
else:
if not get_backend(x1, x2).__name__ == "numpy":
raise NotImplementedError()
diff --git a/test/test_utils.py b/test/test_utils.py
index 3f5f9ec65..938fd6058 100644
--- a/test/test_utils.py
+++ b/test/test_utils.py
@@ -8,6 +8,31 @@
import numpy as np
import sys
import pytest
+import scipy.spatial.distance
+
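+# metrics with a dedicated backend implementation in ot.dist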
+lst_metrics = [
+ "euclidean",
+ "sqeuclidean",
+ "cityblock",
+ "cosine",
+ "minkowski",
+ "correlation",
+]
+
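+# additional metrics exercised only through the scipy/numpy code path of ot.dist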
+lst_all_metrics = lst_metrics + [
+ "braycurtis",
+ "canberra",
+ "chebyshev",
+ "dice",
+ "hamming",
+ "jaccard",
+ "matching",
+ "rogerstanimoto",
+ "russellrao",
+ "sokalmichener",
+ "sokalsneath",
+ "yule",
+]
def get_LazyTensor(nx):
@@ -185,7 +210,7 @@ def test_dist():
assert D4[0, 1] == D4[1, 0]
- # dist shoul return squared euclidean
+ # dist should return squared euclidean
np.testing.assert_allclose(D, D2, atol=1e-14)
np.testing.assert_allclose(D, D3, atol=1e-14)
@@ -229,21 +254,45 @@ def test_dist():
with pytest.raises(ValueError):
ot.dist(x, x, metric="wminkowski")
+ with pytest.raises(ValueError):
+ ot.dist(x, x, metric="fakeone")
+
-def test_dist_backends(nx):
+@pytest.mark.parametrize("metric", lst_metrics)
+def test_dist_backends(nx, metric):
n = 100
rng = np.random.RandomState(0)
x = rng.randn(n, 2)
x1 = nx.from_numpy(x)
- lst_metric = ["euclidean", "sqeuclidean"]
+ # force numpy backend
+ D0 = ot.dist(x, x, metric=metric, backend="numpy")
+
+ # default backend
+ D = ot.dist(x, x, metric=metric)
+
+ # force nx arrays
+ D1 = ot.dist(x1, x1, metric=metric)
+
+ # low atol because jax forces float32
+ np.testing.assert_allclose(D, nx.to_numpy(D1), atol=1e-5)
+ np.testing.assert_allclose(D, D0, atol=1e-5)
+
+
+@pytest.mark.parametrize("metric", lst_all_metrics)
+def test_dist_vs_cdist(metric):
+ n = 10
+
+ rng = np.random.RandomState(0)
+ x = rng.randn(n, 2)
+ y = rng.randn(n + 1, 2)
- for metric in lst_metric:
- D = ot.dist(x, x, metric=metric)
- D1 = ot.dist(x1, x1, metric=metric)
+ D = ot.dist(x, y, metric=metric)
+ Dt = ot.dist(x, y, metric=metric, use_tensor=True)
+ D2 = scipy.spatial.distance.cdist(x, y, metric=metric)
- # low atol because jax forces float32
- np.testing.assert_allclose(D, nx.to_numpy(D1), atol=1e-5)
+ np.testing.assert_allclose(D, D2, atol=1e-15)
+ np.testing.assert_allclose(D, Dt, atol=1e-15)
def test_dist0():