Skip to content

Commit 3f9677e

Browse files
committed
debug barycenter again
1 parent 9fc4ab8 commit 3f9677e

File tree

2 files changed

+5
-5
lines changed

2 files changed

+5
-5
lines changed

RELEASES.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
+ Add `stop_criterion` feature to (un)regularized (f)gw barycenter solvers (PR #578)
2222
+ Add `fixed_structure` and `fixed_features` to entropic fgw barycenter solver (PR #578)
2323
+ Add new BAPG solvers with KL projections for GW and FGW (PR #581)
24-
+ Add Bures-Wasserstein barycenter in `ot.gaussian` (PR #582)
24+
+ Add Bures-Wasserstein barycenter in `ot.gaussian` and example (PR #582, PR #584)
2525

2626

2727
#### Closed issues

ot/gaussian.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -399,15 +399,15 @@ def bures_wasserstein_barycenter(m, C, weights=None, num_iter=1000, eps=1e-7, lo
399399
"""
400400
nx = get_backend(*C, *m,)
401401

402+
if weights is None:
403+
weights = nx.ones(C.shape[0], type_as=C[0]) / C.shape[0]
404+
402405
# Compute the mean barycenter
403-
mb = nx.dot(weights, m)
406+
mb = nx.sum(m * weights[:, None], axis=0)
404407

405408
# Init the covariance barycenter
406409
Cb = nx.mean(C * weights[:, None, None], axis=0)
407410

408-
if weights is None:
409-
weights = nx.ones(len(C), type_as=C[0]) / len(C)
410-
411411
for it in range(num_iter):
412412
# fixed point update
413413
Cb12 = nx.sqrtm(Cb)

0 commit comments

Comments
 (0)