Skip to content

Commit 7dc0579

Browse files
committed
Avoid changing precision in the backend
1 parent a56e1b2 commit 7dc0579

File tree

2 files changed

+2
-2
lines changed

2 files changed

+2
-2
lines changed

ot/backend.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1941,7 +1941,7 @@ def power(self, a, exponents):
19411941
return torch.pow(a, exponents)
19421942

19431943
def norm(self, a, axis=None, keepdims=False):
1944-
return torch.linalg.norm(a.double(), dim=axis, keepdims=keepdims)
1944+
return torch.linalg.norm(a, dim=axis, keepdims=keepdims)
19451945

19461946
def any(self, a):
19471947
return torch.any(a)

test/test_gaussian.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -81,7 +81,7 @@ def test_empirical_bures_wasserstein_mapping_numerical_error_warning():
8181

8282

8383
def test_bures_wasserstein_distance(nx):
84-
ms, mt = np.array([0]), np.array([10])
84+
ms, mt = np.array([0]).astype(np.float32), np.array([10]).astype(np.float32)
8585
Cs, Ct = np.array([[1]]).astype(np.float32), np.array([[1]]).astype(np.float32)
8686
msb, mtb, Csb, Ctb = nx.from_numpy(ms, mt, Cs, Ct)
8787
Wb_log, log = ot.gaussian.bures_wasserstein_distance(msb, mtb, Csb, Ctb, log=True)

0 commit comments

Comments
 (0)