
Commit f9f983c: Merge changes from main
2 parents 76288e5 + a450c83

25 files changed, +340 −210 lines

RELEASES.md
Lines changed: 5 additions & 2 deletions

@@ -1,9 +1,12 @@
 # Releases

-## 0.9.2
+## 0.9.2dev

 #### New features
-- Callbacks for generalized conditional gradient in `ot.da.sinkhorn_l1l2_gl` are now vectorized to improve performance (PR #507)
++ Tweaked `get_backend` to ignore `None` inputs (PR # 525)
++ Callbacks for generalized conditional gradient in `ot.da.sinkhorn_l1l2_gl` are now vectorized to improve performance (PR #507)
+
+#### Closed issues


 ## 0.9.1

ot/backend.py
Lines changed: 5 additions & 1 deletion

@@ -157,11 +157,15 @@ def _check_args_backend(backend, args):
 def get_backend(*args):
     """Returns the proper backend for a list of input arrays

+    Accepts None entries in the arguments, and ignores them
+
     Also raises TypeError if all arrays are not from the same backend
     """
+    args = [arg for arg in args if arg is not None]  # exclude None entries
+
     # check that some arrays given
     if not len(args) > 0:
-        raise ValueError(" The function takes at least one parameter")
+        raise ValueError(" The function takes at least one (non-None) parameter")

     for backend in _BACKENDS:
         if _check_args_backend(backend, args):
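For context, a minimal sketch of the new behaviour (the toy arrays below are illustrative, not part of the commit):

    import numpy as np
    from ot.backend import get_backend

    a = np.ones(3)
    nx = get_backend(a, None)   # the None entry is now ignored
    print(nx.__name__)          # 'numpy'

    try:
        get_backend(None, None)  # nothing left after filtering out None
    except ValueError as err:
        print(err)               # " The function takes at least one (non-None) parameter"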

ot/da.py
Lines changed: 2 additions & 2 deletions

@@ -1390,8 +1390,8 @@ class LinearGWTransport(LinearTransport):
     References
     ----------
     .. [57] Delon, J., Desolneux, A., & Salmona, A. (2022). Gromov–Wasserstein
-       distances between Gaussian distributions. Journal of Applied Probability,
-       59(4), 1178-1198.
+        distances between Gaussian distributions. Journal of Applied Probability,
+        59(4), 1178-1198.

     """

ot/datasets.py
Lines changed: 1 addition & 1 deletion

@@ -155,7 +155,7 @@ def make_data_classif(dataset, n, nz=.5, theta=0, p=.5, random_state=None, **kwa
     elif dataset.lower() == '2gauss_prop':

         y = np.concatenate((np.ones(int(p * n)), np.zeros(int((1 - p) * n))))
-        x = np.hstack((0 * y[:, None] - 0, 1 - 2 * y[:, None])) + nz * np.random.randn(len(y), 2)
+        x = np.hstack((0 * y[:, None] - 0, 1 - 2 * y[:, None])) + nz * generator.randn(len(y), 2)

         if ('bias' not in kwargs) and ('b' not in kwargs):
             kwargs['bias'] = np.array([0, 2])
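Since the noise is now drawn from the seeded generator rather than the global `np.random` state, passing `random_state` makes the '2gauss_prop' dataset reproducible; a short sketch of the expected behaviour:

    import numpy as np
    from ot.datasets import make_data_classif

    # same seed, same draw, now that generator.randn replaces np.random.randn
    x1, y1 = make_data_classif('2gauss_prop', n=100, random_state=42)
    x2, y2 = make_data_classif('2gauss_prop', n=100, random_state=42)
    assert np.allclose(x1, x2)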

ot/dr.py
Lines changed: 7 additions & 3 deletions

@@ -25,7 +25,7 @@
 import pymanopt.optimizers

 from .bregman import sinkhorn as sinkhorn_bregman
-from .utils import dist as dist_utils
+from .utils import dist as dist_utils, check_random_state


 def dist(x1, x2):
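`check_random_state` is the existing helper in `ot.utils` (mirroring the scikit-learn utility of the same name); a sketch of its contract, under that assumption:

    import numpy as np
    from ot.utils import check_random_state

    rng = check_random_state(42)        # int -> a freshly seeded RandomState
    rng2 = check_random_state(rng)      # RandomState instance -> passed through
    rng3 = check_random_state(None)     # None -> NumPy's global RandomState
    assert rng is rng2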
@@ -267,7 +267,7 @@ def proj(X):
     return Popt.point, proj


-def projection_robust_wasserstein(X, Y, a, b, tau, U0=None, reg=0.1, k=2, stopThr=1e-3, maxiter=100, verbose=0):
+def projection_robust_wasserstein(X, Y, a, b, tau, U0=None, reg=0.1, k=2, stopThr=1e-3, maxiter=100, verbose=0, random_state=None):
     r"""
     Projection Robust Wasserstein Distance :ref:`[32] <references-projection-robust-wasserstein>`

@@ -303,6 +303,9 @@ def projection_robust_wasserstein(X, Y, a, b, tau, U0=None, reg=0.1, k=2, stopTh
         Stop threshold on error (>0)
     verbose : int, optional
         Print information along iterations.
+    random_state : int, RandomState instance or None, default=None
+        Determines random number generation for initial value of projection
+        operator when U0 is not given.

     Returns
     -------

@@ -332,7 +335,8 @@ def projection_robust_wasserstein(X, Y, a, b, tau, U0=None, reg=0.1, k=2, stopTh
     assert d > k

     if U0 is None:
-        U = np.random.randn(d, k)
+        rng = check_random_state(random_state)
+        U = rng.randn(d, k)
         U, _ = np.linalg.qr(U)
     else:
         U = U0
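With the new argument, the random initialization of the projection `U` can be seeded, so repeated calls match. A hedged sketch (the toy data and the `tau`/`k` values are illustrative; `ot.dr` also needs the optional `pymanopt`/`autograd` dependencies):

    import numpy as np
    import ot

    rng = np.random.RandomState(0)
    X = rng.randn(20, 4)           # source samples
    Y = rng.randn(30, 4) + 1       # target samples
    a, b = ot.unif(20), ot.unif(30)

    pi1, U1 = ot.dr.projection_robust_wasserstein(X, Y, a, b, tau=0.002, k=2, random_state=42)
    pi2, U2 = ot.dr.projection_robust_wasserstein(X, Y, a, b, tau=0.002, k=2, random_state=42)
    assert np.allclose(U1, U2)     # seeded init, identical projections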

ot/gromov/_bregman.py
Lines changed: 27 additions & 18 deletions

@@ -81,7 +81,7 @@ def entropic_gromov_wasserstein(
     q : array-like, shape (nt,), optional
         Distribution in the target space.
         If let to its default value None, uniform distribution is taken.
-    loss_fun : string, optional
+    loss_fun : string, optional (default='square_loss')
         Loss function used for the solver either 'square_loss' or 'kl_loss'
     epsilon : float, optional
         Regularization term >0

@@ -92,8 +92,8 @@
     G0: array-like, shape (ns,nt), optional
         If None the initial transport plan of the solver is pq^T.
         Otherwise G0 will be used as initial transport of the solver. G0 is not
-        required to satisfy marginal constraints but we strongly recommand it
-        to correcly estimate the GW distance.
+        required to satisfy marginal constraints but we strongly recommend it
+        to correctly estimate the GW distance.
     max_iter : int, optional
         Max number of iterations
     tol : float, optional

@@ -135,6 +135,9 @@
     if solver not in ['PGD', 'PPA']:
         raise ValueError("Unknown solver '%s'. Pick one in ['PGD', 'PPA']." % solver)

+    if loss_fun not in ('square_loss', 'kl_loss'):
+        raise ValueError(f"Unknown `loss_fun='{loss_fun}'`. Use one of: {'square_loss', 'kl_loss'}.")
+
     C1, C2 = list_to_array(C1, C2)
     arr = [C1, C2]
     if p is not None:
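The new guard turns a typo in `loss_fun` into an immediate, explicit error instead of a silent fall-through; sketched below with tiny made-up cost matrices:

    import numpy as np
    import ot

    rng = np.random.RandomState(0)
    C1 = rng.rand(5, 5); C1 = (C1 + C1.T) / 2   # toy symmetric structure matrices
    C2 = rng.rand(6, 6); C2 = (C2 + C2.T) / 2

    try:
        ot.gromov.entropic_gromov_wasserstein(C1, C2, loss_fun='absolute_loss')
    except ValueError as err:
        print(err)  # Unknown `loss_fun='absolute_loss'`. Use one of: ...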
@@ -280,7 +283,7 @@ def entropic_gromov_wasserstein2(
     q : array-like, shape (nt,), optional
         Distribution in the target space.
         If let to its default value None, uniform distribution is taken.
-    loss_fun : string, optional
+    loss_fun : string, optional (default='square_loss')
         Loss function used for the solver either 'square_loss' or 'kl_loss'
     epsilon : float, optional
         Regularization term >0

@@ -373,8 +376,8 @@ def entropic_gromov_barycenters(
     lambdas : list of float, optional
         List of the `S` spaces' weights.
         If let to its default value None, uniform weights are taken.
-    loss_fun : callable, optional
-        tensor-matrix multiplication function based on specific loss function
+    loss_fun : string, optional (default='square_loss')
+        Loss function used for the solver either 'square_loss' or 'kl_loss'
     epsilon : float, optional
         Regularization term >0
     symmetric : bool, optional.

@@ -411,6 +414,9 @@
         "Gromov-Wasserstein averaging of kernel and distance matrices."
         International Conference on Machine Learning (ICML). 2016.
     """
+    if loss_fun not in ('square_loss', 'kl_loss'):
+        raise ValueError(f"Unknown `loss_fun='{loss_fun}'`. Use one of: {'square_loss', 'kl_loss'}.")
+
     Cs = list_to_array(*Cs)
     arr = [*Cs]
     if ps is not None:

@@ -459,7 +465,6 @@

         if loss_fun == 'square_loss':
             C = update_square_loss(p, lambdas, T, Cs)
-
         elif loss_fun == 'kl_loss':
             C = update_kl_loss(p, lambdas, T, Cs)

@@ -550,21 +555,21 @@ def entropic_fused_gromov_wasserstein(
     q : array-like, shape (nt,), optional
         Distribution in the target space.
         If let to its default value None, uniform distribution is taken.
-    loss_fun : string, optional
+    loss_fun : string, optional (default='square_loss')
         Loss function used for the solver either 'square_loss' or 'kl_loss'
     epsilon : float, optional
         Regularization term >0
     symmetric : bool, optional
         Either C1 and C2 are to be assumed symmetric or not.
         If let to its default None value, a symmetry test will be conducted.
-        Else if set to True (resp. False), C1 and C2 will be assumed symmetric (resp. asymetric).
+        Else if set to True (resp. False), C1 and C2 will be assumed symmetric (resp. asymmetric).
     alpha : float, optional
         Trade-off parameter (0 < alpha < 1)
     G0: array-like, shape (ns,nt), optional
         If None the initial transport plan of the solver is pq^T.
         Otherwise G0 will be used as initial transport of the solver. G0 is not
-        required to satisfy marginal constraints but we strongly recommand it
-        to correcly estimate the GW distance.
+        required to satisfy marginal constraints but we strongly recommend it
+        to correctly estimate the GW distance.
     max_iter : int, optional
         Max number of iterations
     tol : float, optional

@@ -611,6 +616,9 @@ def entropic_fused_gromov_wasserstein(
     if solver not in ['PGD', 'PPA']:
         raise ValueError("Unknown solver '%s'. Pick one in ['PGD', 'PPA']." % solver)

+    if loss_fun not in ('square_loss', 'kl_loss'):
+        raise ValueError(f"Unknown `loss_fun='{loss_fun}'`. Use one of: {'square_loss', 'kl_loss'}.")
+
     M, C1, C2 = list_to_array(M, C1, C2)
     arr = [M, C1, C2]
     if p is not None:

@@ -762,7 +770,7 @@ def entropic_fused_gromov_wasserstein2(
     q : array-like, shape (nt,), optional
         Distribution in the target space.
         If let to its default value None, uniform distribution is taken.
-    loss_fun : string, optional
+    loss_fun : string, optional (default='square_loss')
         Loss function used for the solver either 'square_loss' or 'kl_loss'
     epsilon : float, optional
         Regularization term >0

@@ -775,8 +783,8 @@ def entropic_fused_gromov_wasserstein2(
     G0: array-like, shape (ns,nt), optional
         If None the initial transport plan of the solver is pq^T.
         Otherwise G0 will be used as initial transport of the solver. G0 is not
-        required to satisfy marginal constraints but we strongly recommand it
-        to correcly estimate the GW distance.
+        required to satisfy marginal constraints but we strongly recommend it
+        to correctly estimate the GW distance.
     max_iter : int, optional
         Max number of iterations
     tol : float, optional

@@ -857,8 +865,8 @@ def entropic_fused_gromov_barycenters(
     lambdas : list of float, optional
         List of the `S` spaces' weights.
         If let to its default value None, uniform weights are taken.
-    loss_fun : callable, optional
-        tensor-matrix multiplication function based on specific loss function
+    loss_fun : string, optional (default='square_loss')
+        Loss function used for the solver either 'square_loss' or 'kl_loss'
     epsilon : float, optional
         Regularization term >0
     symmetric : bool, optional.

@@ -907,6 +915,9 @@ def entropic_fused_gromov_barycenters(
         "Optimal Transport for structured data with application on graphs"
         International Conference on Machine Learning (ICML). 2019.
     """
+    if loss_fun not in ('square_loss', 'kl_loss'):
+        raise ValueError(f"Unknown `loss_fun='{loss_fun}'`. Use one of: {'square_loss', 'kl_loss'}.")
+
     Cs = list_to_array(*Cs)
     Ys = list_to_array(*Ys)
     arr = [*Cs, *Ys]

@@ -977,7 +988,6 @@ def entropic_fused_gromov_barycenters(

         if loss_fun == 'square_loss':
             C = update_square_loss(p, lambdas, T, Cs)
-
         elif loss_fun == 'kl_loss':
             C = update_kl_loss(p, lambdas, T, Cs)

@@ -1004,7 +1014,6 @@
                 print('{:5d}|{:8e}|'.format(cpt, err_feature))

         cpt += 1
-    print('Y type:', type(Y))
     if log:
         log_['T'] = T  # from target to Ys
         log_['p'] = p
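Per the corrected docstrings, the barycenter solvers take `loss_fun` as a string rather than a callable. A sketch under that reading (the inputs, `N`, and solver settings are illustrative assumptions):

    import numpy as np
    import ot

    rng = np.random.RandomState(1)
    Cs = []
    for _ in range(3):
        M = rng.rand(8, 8)
        Cs.append((M + M.T) / 2)   # toy symmetric structure matrices

    C = ot.gromov.entropic_gromov_barycenters(
        N=5, Cs=Cs, loss_fun='square_loss', epsilon=0.05, max_iter=100)
    print(C.shape)  # (5, 5)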

ot/gromov/_dictionary.py
Lines changed: 19 additions & 8 deletions

@@ -11,13 +11,13 @@
 import numpy as np


-from ..utils import unif
+from ..utils import unif, check_random_state
 from ..backend import get_backend
 from ._gw import gromov_wasserstein, fused_gromov_wasserstein


 def gromov_wasserstein_dictionary_learning(Cs, D, nt, reg=0., ps=None, q=None, epochs=20, batch_size=32, learning_rate=1., Cdict_init=None, projection='nonnegative_symmetric', use_log=True,
-                                           tol_outer=10**(-5), tol_inner=10**(-5), max_iter_outer=20, max_iter_inner=200, use_adam_optimizer=True, verbose=False, **kwargs):
+                                           tol_outer=10**(-5), tol_inner=10**(-5), max_iter_outer=20, max_iter_inner=200, use_adam_optimizer=True, verbose=False, random_state=None, **kwargs):
     r"""
     Infer Gromov-Wasserstein linear dictionary :math:`\{ (\mathbf{C_{dict}[d]}, q) \}_{d \in [D]}` from the list of structures :math:`\{ (\mathbf{C_s},\mathbf{p_s}) \}_s`

@@ -81,6 +81,9 @@ def gromov_wasserstein_dictionary_learning(Cs, D, nt, reg=0., ps=None, q=None, e
         Maximum number of iterations for the Conjugate Gradient. Default is 200.
     verbose : bool, optional
         Print the reconstruction loss every epoch. Default is False.
+    random_state : int, RandomState instance or None, default=None
+        Determines random number generation. Pass an int for reproducible
+        output across multiple function calls.

     Returns
     -------

@@ -90,6 +93,7 @@ def gromov_wasserstein_dictionary_learning(Cs, D, nt, reg=0., ps=None, q=None, e
         The dictionary leading to the best loss over an epoch is saved and returned.
     log: dict
         If use_log is True, contains loss evolutions by batches and epochs.
+
     References
     -------
     .. [38] C. Vincent-Cuaz, T. Vayer, R. Flamary, M. Corneli, N. Courty, Online

@@ -110,10 +114,11 @@ def gromov_wasserstein_dictionary_learning(Cs, D, nt, reg=0., ps=None, q=None, e
         q = unif(nt)
     else:
         q = nx.to_numpy(q)
+    rng = check_random_state(random_state)
     if Cdict_init is None:
         # Initialize randomly structures of dictionary atoms based on samples
         dataset_means = [C.mean() for C in Cs]
-        Cdict = np.random.normal(loc=np.mean(dataset_means), scale=np.std(dataset_means), size=(D, nt, nt))
+        Cdict = rng.normal(loc=np.mean(dataset_means), scale=np.std(dataset_means), size=(D, nt, nt))
     else:
         Cdict = nx.to_numpy(Cdict_init).copy()
         assert Cdict.shape == (D, nt, nt)

@@ -141,7 +146,7 @@ def gromov_wasserstein_dictionary_learning(Cs, D, nt, reg=0., ps=None, q=None, e

         for _ in range(iter_by_epoch):
             # batch sampling
-            batch = np.random.choice(range(dataset_size), size=batch_size, replace=False)
+            batch = rng.choice(range(dataset_size), size=batch_size, replace=False)
             cumulated_loss_over_batch = 0.
             unmixings = np.zeros((batch_size, D))
             Cs_embedded = np.zeros((batch_size, nt, nt))

@@ -469,7 +474,8 @@ def _linesearch_gromov_wasserstein_unmixing(w, grad_w, x, Cdict, Cembedded, cons

 def fused_gromov_wasserstein_dictionary_learning(Cs, Ys, D, nt, alpha, reg=0., ps=None, q=None, epochs=20, batch_size=32, learning_rate_C=1., learning_rate_Y=1.,
                                                  Cdict_init=None, Ydict_init=None, projection='nonnegative_symmetric', use_log=False,
-                                                 tol_outer=10**(-5), tol_inner=10**(-5), max_iter_outer=20, max_iter_inner=200, use_adam_optimizer=True, verbose=False, **kwargs):
+                                                 tol_outer=10**(-5), tol_inner=10**(-5), max_iter_outer=20, max_iter_inner=200, use_adam_optimizer=True, verbose=False,
+                                                 random_state=None, **kwargs):
     r"""
     Infer Fused Gromov-Wasserstein linear dictionary :math:`\{ (\mathbf{C_{dict}[d]}, \mathbf{Y_{dict}[d]}, \mathbf{q}) \}_{d \in [D]}` from the list of S attributed structures :math:`\{ (\mathbf{C_s}, \mathbf{Y_s},\mathbf{p_s}) \}_s`

@@ -548,6 +554,9 @@ def fused_gromov_wasserstein_dictionary_learning(Cs, Ys, D, nt, alpha, reg=0., p
         Maximum number of iterations for the Conjugate Gradient. Default is 200.
     verbose : bool, optional
         Print the reconstruction loss every epoch. Default is False.
+    random_state : int, RandomState instance or None, default=None
+        Determines random number generation. Pass an int for reproducible
+        output across multiple function calls.

     Returns
     -------

@@ -560,6 +569,7 @@ def fused_gromov_wasserstein_dictionary_learning(Cs, Ys, D, nt, alpha, reg=0., p
         The dictionary leading to the best loss over an epoch is saved and returned.
     log: dict
         If use_log is True, contains loss evolutions by batches and epochs.
+
     References
     -------
     .. [38] C. Vincent-Cuaz, T. Vayer, R. Flamary, M. Corneli, N. Courty, Online

@@ -583,17 +593,18 @@ def fused_gromov_wasserstein_dictionary_learning(Cs, Ys, D, nt, alpha, reg=0., p
     else:
         q = nx.to_numpy(q)

+    rng = check_random_state(random_state)
     if Cdict_init is None:
         # Initialize randomly structures of dictionary atoms based on samples
         dataset_means = [C.mean() for C in Cs]
-        Cdict = np.random.normal(loc=np.mean(dataset_means), scale=np.std(dataset_means), size=(D, nt, nt))
+        Cdict = rng.normal(loc=np.mean(dataset_means), scale=np.std(dataset_means), size=(D, nt, nt))
     else:
         Cdict = nx.to_numpy(Cdict_init).copy()
         assert Cdict.shape == (D, nt, nt)
     if Ydict_init is None:
         # Initialize randomly features of dictionary atoms based on samples distribution by feature component
         dataset_feature_means = np.stack([F.mean(axis=0) for F in Ys])
-        Ydict = np.random.normal(loc=dataset_feature_means.mean(axis=0), scale=dataset_feature_means.std(axis=0), size=(D, nt, d))
+        Ydict = rng.normal(loc=dataset_feature_means.mean(axis=0), scale=dataset_feature_means.std(axis=0), size=(D, nt, d))
     else:
         Ydict = nx.to_numpy(Ydict_init).copy()
         assert Ydict.shape == (D, nt, d)

@@ -626,7 +637,7 @@ def fused_gromov_wasserstein_dictionary_learning(Cs, Ys, D, nt, alpha, reg=0., p
         for _ in range(iter_by_epoch):

             # Batch iterations
-            batch = np.random.choice(range(dataset_size), size=batch_size, replace=False)
+            batch = rng.choice(range(dataset_size), size=batch_size, replace=False)
             cumulated_loss_over_batch = 0.
             unmixings = np.zeros((batch_size, D))
             Cs_embedded = np.zeros((batch_size, nt, nt))
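With `random_state` threaded through both the dictionary initialization and the batch sampling, two seeded runs should coincide. A reproducibility sketch (the sizes and epoch counts are illustrative):

    import numpy as np
    import ot

    rng = np.random.RandomState(0)
    Cs = []
    for _ in range(6):
        M = rng.rand(10, 10)
        Cs.append((M + M.T) / 2)   # toy structure matrices to factorize

    Cdict1, _ = ot.gromov.gromov_wasserstein_dictionary_learning(
        Cs, D=3, nt=6, epochs=2, batch_size=3, random_state=0)
    Cdict2, _ = ot.gromov.gromov_wasserstein_dictionary_learning(
        Cs, D=3, nt=6, epochs=2, batch_size=3, random_state=0)
    assert np.allclose(Cdict1, Cdict2)  # seeded init and batches match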
