From 7dc057905fa49fd59bd908d9c10a6b3fb622e2b8 Mon Sep 17 00:00:00 2001 From: Oleksii Kachaiev Date: Fri, 10 Nov 2023 13:10:37 +0100 Subject: [PATCH 1/2] Avoid changing precision in the backend --- ot/backend.py | 2 +- test/test_gaussian.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/ot/backend.py b/ot/backend.py index 80eb7604e..36ea51373 100644 --- a/ot/backend.py +++ b/ot/backend.py @@ -1941,7 +1941,7 @@ def power(self, a, exponents): return torch.pow(a, exponents) def norm(self, a, axis=None, keepdims=False): - return torch.linalg.norm(a.double(), dim=axis, keepdims=keepdims) + return torch.linalg.norm(a, dim=axis, keepdims=keepdims) def any(self, a): return torch.any(a) diff --git a/test/test_gaussian.py b/test/test_gaussian.py index 4e3c2df7b..02b5bbe86 100644 --- a/test/test_gaussian.py +++ b/test/test_gaussian.py @@ -81,7 +81,7 @@ def test_empirical_bures_wasserstein_mapping_numerical_error_warning(): def test_bures_wasserstein_distance(nx): - ms, mt = np.array([0]), np.array([10]) + ms, mt = np.array([0]).astype(np.float32), np.array([10]).astype(np.float32) Cs, Ct = np.array([[1]]).astype(np.float32), np.array([[1]]).astype(np.float32) msb, mtb, Csb, Ctb = nx.from_numpy(ms, mt, Cs, Ct) Wb_log, log = ot.gaussian.bures_wasserstein_distance(msb, mtb, Csb, Ctb, log=True) From c6f90e304c99fa4ec2091a631b0bbb91db11c399 Mon Sep 17 00:00:00 2001 From: Oleksii Kachaiev Date: Fri, 10 Nov 2023 13:41:09 +0100 Subject: [PATCH 2/2] Update RELEASES.md --- RELEASES.md | 1 + 1 file changed, 1 insertion(+) diff --git a/RELEASES.md b/RELEASES.md index c5d63750f..974dbbf73 100644 --- a/RELEASES.md +++ b/RELEASES.md @@ -22,6 +22,7 @@ - Lazily instantiate backends to avoid unnecessary GPU memory pre-allocations on package import (Issue #516, PR #520) - Handle documentation and warnings when integers are provided to (f)gw solvers based on cg (Issue #530, PR #559) - Correct independence of `fgw_barycenters` to `init_C` and `init_X` (Issue #547, PR #566) +- Avoid precision change when computing norm using PyTorch backend (Discussion #570, PR #572) ## 0.9.1 *August 2023*