
Commit 20cc202

Move from random seeding to local generators (#512)

1 parent 98e3187

18 files changed: +206 -169 lines
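
The change threads a random_state argument through every routine that previously drew from NumPy's global generator, resolves it once with check_random_state from ot.utils, and samples from the resulting local generator. As a sketch of the convention (the helper body below follows the scikit-learn pattern and is an illustrative assumption, not POT's exact implementation):

    import numpy as np

    def check_random_state(seed):
        # Normalize `seed` into a RandomState without touching global state.
        if seed is None or seed is np.random:
            return np.random.mtrand._rand       # module-level generator
        if isinstance(seed, (int, np.integer)):
            return np.random.RandomState(seed)  # fresh, reproducible stream
        if isinstance(seed, np.random.RandomState):
            return seed                         # caller-managed generator
        raise ValueError("%r cannot be used to seed a RandomState" % seed)

An int yields reproducible output across calls, a RandomState instance lets callers share one stream across several functions, and None preserves the old unseeded behavior.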

ot/dr.py

Lines changed: 7 additions & 3 deletions

@@ -25,7 +25,7 @@
 import pymanopt.optimizers

 from .bregman import sinkhorn as sinkhorn_bregman
-from .utils import dist as dist_utils
+from .utils import dist as dist_utils, check_random_state


 def dist(x1, x2):
@@ -267,7 +267,7 @@ def proj(X):
     return Popt.point, proj


-def projection_robust_wasserstein(X, Y, a, b, tau, U0=None, reg=0.1, k=2, stopThr=1e-3, maxiter=100, verbose=0):
+def projection_robust_wasserstein(X, Y, a, b, tau, U0=None, reg=0.1, k=2, stopThr=1e-3, maxiter=100, verbose=0, random_state=None):
     r"""
     Projection Robust Wasserstein Distance :ref:`[32] <references-projection-robust-wasserstein>`

@@ -303,6 +303,9 @@ def projection_robust_wasserstein(X, Y, a, b, tau, U0=None, reg=0.1, k=2, stopTh
         Stop threshold on error (>0)
     verbose : int, optional
         Print information along iterations.
+    random_state : int, RandomState instance or None, default=None
+        Determines random number generation for initial value of projection
+        operator when U0 is not given.

     Returns
     -------
@@ -332,7 +335,8 @@ def projection_robust_wasserstein(X, Y, a, b, tau, U0=None, reg=0.1, k=2, stopTh
     assert d > k

     if U0 is None:
-        U = np.random.randn(d, k)
+        rng = check_random_state(random_state)
+        U = rng.randn(d, k)
         U, _ = np.linalg.qr(U)
     else:
         U = U0
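
Together these hunks make the random initialization of the projection reproducible. A hypothetical usage sketch (assumes POT with its pymanopt/autograd extras installed; data shapes and parameter values are illustrative):

    import numpy as np
    import ot

    rng = np.random.RandomState(0)
    X = rng.randn(20, 5)             # source samples
    Y = rng.randn(30, 5) + 1.0       # shifted target samples
    a, b = ot.unif(20), ot.unif(30)

    # Same int seed -> same random initial U -> same optimization path,
    # with no effect on NumPy's global generator.
    res1 = ot.dr.projection_robust_wasserstein(
        X, Y, a, b, tau=0.002, reg=0.5, k=2, random_state=42)
    res2 = ot.dr.projection_robust_wasserstein(
        X, Y, a, b, tau=0.002, reg=0.5, k=2, random_state=42)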

ot/gromov/_dictionary.py

Lines changed: 19 additions & 8 deletions

@@ -11,13 +11,13 @@
 import numpy as np


-from ..utils import unif
+from ..utils import unif, check_random_state
 from ..backend import get_backend
 from ._gw import gromov_wasserstein, fused_gromov_wasserstein


 def gromov_wasserstein_dictionary_learning(Cs, D, nt, reg=0., ps=None, q=None, epochs=20, batch_size=32, learning_rate=1., Cdict_init=None, projection='nonnegative_symmetric', use_log=True,
-                                           tol_outer=10**(-5), tol_inner=10**(-5), max_iter_outer=20, max_iter_inner=200, use_adam_optimizer=True, verbose=False, **kwargs):
+                                           tol_outer=10**(-5), tol_inner=10**(-5), max_iter_outer=20, max_iter_inner=200, use_adam_optimizer=True, verbose=False, random_state=None, **kwargs):
     r"""
     Infer Gromov-Wasserstein linear dictionary :math:`\{ (\mathbf{C_{dict}[d]}, q) \}_{d \in [D]}` from the list of structures :math:`\{ (\mathbf{C_s},\mathbf{p_s}) \}_s`

@@ -81,6 +81,9 @@ def gromov_wasserstein_dictionary_learning(Cs, D, nt, reg=0., ps=None, q=None, e
         Maximum number of iterations for the Conjugate Gradient. Default is 200.
     verbose : bool, optional
         Print the reconstruction loss every epoch. Default is False.
+    random_state : int, RandomState instance or None, default=None
+        Determines random number generation. Pass an int for reproducible
+        output across multiple function calls.

     Returns
     -------
@@ -90,6 +93,7 @@ def gromov_wasserstein_dictionary_learning(Cs, D, nt, reg=0., ps=None, q=None, e
         The dictionary leading to the best loss over an epoch is saved and returned.
     log: dict
         If use_log is True, contains loss evolutions by batches and epochs.
+
     References
     -------
     .. [38] C. Vincent-Cuaz, T. Vayer, R. Flamary, M. Corneli, N. Courty, Online
@@ -110,10 +114,11 @@ def gromov_wasserstein_dictionary_learning(Cs, D, nt, reg=0., ps=None, q=None, e
         q = unif(nt)
     else:
         q = nx.to_numpy(q)
+    rng = check_random_state(random_state)
     if Cdict_init is None:
         # Initialize randomly structures of dictionary atoms based on samples
         dataset_means = [C.mean() for C in Cs]
-        Cdict = np.random.normal(loc=np.mean(dataset_means), scale=np.std(dataset_means), size=(D, nt, nt))
+        Cdict = rng.normal(loc=np.mean(dataset_means), scale=np.std(dataset_means), size=(D, nt, nt))
     else:
         Cdict = nx.to_numpy(Cdict_init).copy()
         assert Cdict.shape == (D, nt, nt)
@@ -141,7 +146,7 @@ def gromov_wasserstein_dictionary_learning(Cs, D, nt, reg=0., ps=None, q=None, e

         for _ in range(iter_by_epoch):
             # batch sampling
-            batch = np.random.choice(range(dataset_size), size=batch_size, replace=False)
+            batch = rng.choice(range(dataset_size), size=batch_size, replace=False)
             cumulated_loss_over_batch = 0.
             unmixings = np.zeros((batch_size, D))
             Cs_embedded = np.zeros((batch_size, nt, nt))
@@ -469,7 +474,8 @@ def _linesearch_gromov_wasserstein_unmixing(w, grad_w, x, Cdict, Cembedded, cons

 def fused_gromov_wasserstein_dictionary_learning(Cs, Ys, D, nt, alpha, reg=0., ps=None, q=None, epochs=20, batch_size=32, learning_rate_C=1., learning_rate_Y=1.,
                                                  Cdict_init=None, Ydict_init=None, projection='nonnegative_symmetric', use_log=False,
-                                                 tol_outer=10**(-5), tol_inner=10**(-5), max_iter_outer=20, max_iter_inner=200, use_adam_optimizer=True, verbose=False, **kwargs):
+                                                 tol_outer=10**(-5), tol_inner=10**(-5), max_iter_outer=20, max_iter_inner=200, use_adam_optimizer=True, verbose=False,
+                                                 random_state=None, **kwargs):
     r"""
     Infer Fused Gromov-Wasserstein linear dictionary :math:`\{ (\mathbf{C_{dict}[d]}, \mathbf{Y_{dict}[d]}, \mathbf{q}) \}_{d \in [D]}` from the list of S attributed structures :math:`\{ (\mathbf{C_s}, \mathbf{Y_s},\mathbf{p_s}) \}_s`

@@ -548,6 +554,9 @@ def fused_gromov_wasserstein_dictionary_learning(Cs, Ys, D, nt, alpha, reg=0., p
         Maximum number of iterations for the Conjugate Gradient. Default is 200.
     verbose : bool, optional
         Print the reconstruction loss every epoch. Default is False.
+    random_state : int, RandomState instance or None, default=None
+        Determines random number generation. Pass an int for reproducible
+        output across multiple function calls.

     Returns
     -------
@@ -560,6 +569,7 @@ def fused_gromov_wasserstein_dictionary_learning(Cs, Ys, D, nt, alpha, reg=0., p
         The dictionary leading to the best loss over an epoch is saved and returned.
     log: dict
         If use_log is True, contains loss evolutions by batches and epochs.
+
     References
     -------
     .. [38] C. Vincent-Cuaz, T. Vayer, R. Flamary, M. Corneli, N. Courty, Online
@@ -583,17 +593,18 @@ def fused_gromov_wasserstein_dictionary_learning(Cs, Ys, D, nt, alpha, reg=0., p
     else:
         q = nx.to_numpy(q)

+    rng = check_random_state(random_state)
     if Cdict_init is None:
         # Initialize randomly structures of dictionary atoms based on samples
         dataset_means = [C.mean() for C in Cs]
-        Cdict = np.random.normal(loc=np.mean(dataset_means), scale=np.std(dataset_means), size=(D, nt, nt))
+        Cdict = rng.normal(loc=np.mean(dataset_means), scale=np.std(dataset_means), size=(D, nt, nt))
     else:
         Cdict = nx.to_numpy(Cdict_init).copy()
         assert Cdict.shape == (D, nt, nt)
     if Ydict_init is None:
         # Initialize randomly features of dictionary atoms based on samples distribution by feature component
         dataset_feature_means = np.stack([F.mean(axis=0) for F in Ys])
-        Ydict = np.random.normal(loc=dataset_feature_means.mean(axis=0), scale=dataset_feature_means.std(axis=0), size=(D, nt, d))
+        Ydict = rng.normal(loc=dataset_feature_means.mean(axis=0), scale=dataset_feature_means.std(axis=0), size=(D, nt, d))
     else:
         Ydict = nx.to_numpy(Ydict_init).copy()
         assert Ydict.shape == (D, nt, d)
@@ -626,7 +637,7 @@ def fused_gromov_wasserstein_dictionary_learning(Cs, Ys, D, nt, alpha, reg=0., p
         for _ in range(iter_by_epoch):

             # Batch iterations
-            batch = np.random.choice(range(dataset_size), size=batch_size, replace=False)
+            batch = rng.choice(range(dataset_size), size=batch_size, replace=False)
             cumulated_loss_over_batch = 0.
             unmixings = np.zeros((batch_size, D))
             Cs_embedded = np.zeros((batch_size, nt, nt))
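
A hypothetical end-to-end call, for context (the graph construction, epoch counts, and the (Cdict, log) unpacking are illustrative assumptions based on the docstring above, not code from this commit):

    import numpy as np
    import ot

    rng = np.random.RandomState(0)
    Cs = []
    for n in (12, 15, 10):                        # three random graphs
        A = (rng.rand(n, n) < 0.3).astype(float)  # sparse adjacency
        Cs.append((A + A.T) / 2)                  # symmetric structure matrix

    # random_state=42 fixes both the atom initialization (rng.normal) and
    # the per-epoch batch sampling (rng.choice) shown in the hunks above.
    Cdict, log = ot.gromov.gromov_wasserstein_dictionary_learning(
        Cs, D=3, nt=6, epochs=2, batch_size=2, random_state=42)
    print(Cdict.shape)                            # (3, 6, 6): D atoms of size nt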

ot/stochastic.py

Lines changed: 20 additions & 8 deletions

@@ -10,7 +10,7 @@
 # License: MIT License

 import numpy as np
-from .utils import dist
+from .utils import dist, check_random_state
 from .backend import get_backend

 ##############################################################################
@@ -69,7 +69,7 @@ def coordinate_grad_semi_dual(b, M, reg, beta, i):
     return b - khi


-def sag_entropic_transport(a, b, M, reg, numItermax=10000, lr=None):
+def sag_entropic_transport(a, b, M, reg, numItermax=10000, lr=None, random_state=None):
    r"""
    Compute the SAG algorithm to solve the regularized discrete measures optimal transport max problem

@@ -110,6 +110,9 @@ def sag_entropic_transport(a, b, M, reg, numItermax=10000, lr=None):
        Number of iteration.
    lr : float
        Learning rate.
+   random_state : int, RandomState instance or None, default=None
+       Determines random number generation. Pass an int for reproducible
+       output across multiple function calls.

    Returns
    -------
@@ -129,8 +132,9 @@ def sag_entropic_transport(a, b, M, reg, numItermax=10000, lr=None):
     cur_beta = np.zeros(n_target)
     stored_gradient = np.zeros((n_source, n_target))
     sum_stored_gradient = np.zeros(n_target)
+    rng = check_random_state(random_state)
     for _ in range(numItermax):
-        i = np.random.randint(n_source)
+        i = rng.randint(n_source)
         cur_coord_grad = a[i] * coordinate_grad_semi_dual(b, M, reg,
                                                           cur_beta, i)
         sum_stored_gradient += (cur_coord_grad - stored_gradient[i])
@@ -139,7 +143,7 @@ def sag_entropic_transport(a, b, M, reg, numItermax=10000, lr=None):
     return cur_beta


-def averaged_sgd_entropic_transport(a, b, M, reg, numItermax=300000, lr=None):
+def averaged_sgd_entropic_transport(a, b, M, reg, numItermax=300000, lr=None, random_state=None):
    r'''
    Compute the ASGD algorithm to solve the regularized semi continous measures optimal transport max problem

@@ -177,6 +181,9 @@ def averaged_sgd_entropic_transport(a, b, M, reg, numItermax=300000, lr=None):
        Number of iteration.
    lr : float
        Learning rate.
+   random_state : int, RandomState instance or None, default=None
+       Determines random number generation. Pass an int for reproducible
+       output across multiple function calls.

    Returns
    -------
@@ -195,9 +202,10 @@ def averaged_sgd_entropic_transport(a, b, M, reg, numItermax=300000, lr=None):
     n_target = np.shape(M)[1]
     cur_beta = np.zeros(n_target)
     ave_beta = np.zeros(n_target)
+    rng = check_random_state(random_state)
     for cur_iter in range(numItermax):
         k = cur_iter + 1
-        i = np.random.randint(n_source)
+        i = rng.randint(n_source)
         cur_coord_grad = coordinate_grad_semi_dual(b, M, reg, cur_beta, i)
         cur_beta += (lr / np.sqrt(k)) * cur_coord_grad
         ave_beta = (1. / k) * cur_beta + (1 - 1. / k) * ave_beta
@@ -422,7 +430,7 @@ def batch_grad_dual(a, b, M, reg, alpha, beta, batch_size, batch_alpha,
     return grad_alpha, grad_beta


-def sgd_entropic_regularization(a, b, M, reg, batch_size, numItermax, lr):
+def sgd_entropic_regularization(a, b, M, reg, batch_size, numItermax, lr, random_state=None):
    r'''
    Compute the sgd algorithm to solve the regularized discrete measures optimal transport dual problem

@@ -460,6 +468,9 @@ def sgd_entropic_regularization(a, b, M, reg, batch_size, numItermax, lr):
        number of iteration
    lr : float
        learning rate
+   random_state : int, RandomState instance or None, default=None
+       Determines random number generation. Pass an int for reproducible
+       output across multiple function calls.

    Returns
    -------
@@ -477,10 +488,11 @@ def sgd_entropic_regularization(a, b, M, reg, batch_size, numItermax, lr):
     n_target = np.shape(M)[1]
     cur_alpha = np.zeros(n_source)
     cur_beta = np.zeros(n_target)
+    rng = check_random_state(random_state)
     for cur_iter in range(numItermax):
         k = np.sqrt(cur_iter + 1)
-        batch_alpha = np.random.choice(n_source, batch_size, replace=False)
-        batch_beta = np.random.choice(n_target, batch_size, replace=False)
+        batch_alpha = rng.choice(n_source, batch_size, replace=False)
+        batch_beta = rng.choice(n_target, batch_size, replace=False)
         update_alpha, update_beta = batch_grad_dual(a, b, M, reg, cur_alpha,
                                                     cur_beta, batch_size,
                                                     batch_alpha, batch_beta)
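
What the new argument buys at call time, as a hedged sketch (the returned dual potentials and their shapes are assumed from the surrounding docstrings, not shown in this diff):

    import numpy as np
    import ot

    n = 10
    rng = np.random.RandomState(0)
    x = rng.randn(n, 2)
    a = b = ot.unif(n)
    M = ot.dist(x, x)

    # Two runs with the same int seed draw identical minibatches, so the
    # SGD iterates (and the returned duals) match exactly.
    out1 = ot.stochastic.sgd_entropic_regularization(
        a, b, M, reg=1., batch_size=5, numItermax=100, lr=0.1, random_state=7)
    out2 = ot.stochastic.sgd_entropic_regularization(
        a, b, M, reg=1., batch_size=5, numItermax=100, lr=0.1, random_state=7)
    np.testing.assert_allclose(out1, out2)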

test/test_1d_solver.py

Lines changed: 2 additions & 2 deletions

@@ -163,8 +163,8 @@ def test_emd_1d_emd2_1d():
     np.testing.assert_allclose(G, G_1d, atol=1e-15)

     # check AssertionError is raised if called on non 1d arrays
-    u = np.random.randn(n, 2)
-    v = np.random.randn(m, 2)
+    u = rng.randn(n, 2)
+    v = rng.randn(m, 2)
     with pytest.raises(AssertionError):
         ot.emd_1d(u, v, [], [])

test/test_bregman.py

Lines changed: 12 additions & 7 deletions

@@ -298,7 +298,8 @@ def test_sinkhorn_variants(nx):
 def test_sinkhorn_variants_dtype_device(nx, method):
     n = 100

-    x = np.random.randn(n, 2)
+    rng = np.random.RandomState(42)
+    x = rng.randn(n, 2)
     u = ot.utils.unif(n)

     M = ot.dist(x, x)
@@ -317,7 +318,8 @@ def test_sinkhorn_variants_dtype_device(nx, method):
 def test_sinkhorn2_variants_dtype_device(nx, method):
     n = 100

-    x = np.random.randn(n, 2)
+    rng = np.random.RandomState(42)
+    x = rng.randn(n, 2)
     u = ot.utils.unif(n)

     M = ot.dist(x, x)
@@ -337,7 +339,8 @@ def test_sinkhorn2_variants_dtype_device(nx, method):
 def test_sinkhorn2_variants_device_tf(method):
     nx = ot.backend.TensorflowBackend()
     n = 100
-    x = np.random.randn(n, 2)
+    rng = np.random.RandomState(42)
+    x = rng.randn(n, 2)
     u = ot.utils.unif(n)
     M = ot.dist(x, x)

@@ -690,11 +693,12 @@ def test_barycenter_stabilization(nx):

 @pytest.mark.parametrize("method", ["sinkhorn", "sinkhorn_log"])
 def test_wasserstein_bary_2d(nx, method):
+    rng = np.random.RandomState(42)
     size = 20  # size of a square image
-    a1 = np.random.rand(size, size)
+    a1 = rng.rand(size, size)
     a1 += a1.min()
     a1 = a1 / np.sum(a1)
-    a2 = np.random.rand(size, size)
+    a2 = rng.rand(size, size)
     a2 += a2.min()
     a2 = a2 / np.sum(a2)
     # creating matrix A containing all distributions
@@ -724,11 +728,12 @@ def test_wasserstein_bary_2d(nx, method):

 @pytest.mark.parametrize("method", ["sinkhorn", "sinkhorn_log"])
 def test_wasserstein_bary_2d_debiased(nx, method):
+    rng = np.random.RandomState(42)
     size = 20  # size of a square image
-    a1 = np.random.rand(size, size)
+    a1 = rng.rand(size, size)
     a1 += a1.min()
     a1 = a1 / np.sum(a1)
-    a2 = np.random.rand(size, size)
+    a2 = rng.rand(size, size)
     a2 += a2.min()
     a2 = a2 / np.sum(a2)
     # creating matrix A containing all distributions

test/test_coot.py

Lines changed: 7 additions & 6 deletions

@@ -223,21 +223,22 @@ def test_coot_warmstart(nx):
     xt_nx = nx.from_numpy(xt)

     # initialize warmstart
-    init_pi_sample = np.random.rand(n_samples, n_samples)
+    rng = np.random.RandomState(42)
+    init_pi_sample = rng.rand(n_samples, n_samples)
     init_pi_sample = init_pi_sample / np.sum(init_pi_sample)
     init_pi_sample_nx = nx.from_numpy(init_pi_sample)

-    init_pi_feature = np.random.rand(2, 2)
+    init_pi_feature = rng.rand(2, 2)
     init_pi_feature /= init_pi_feature / np.sum(init_pi_feature)
     init_pi_feature_nx = nx.from_numpy(init_pi_feature)

-    init_duals_sample = (np.random.random(n_samples) * 2 - 1,
-                         np.random.random(n_samples) * 2 - 1)
+    init_duals_sample = (rng.random(n_samples) * 2 - 1,
+                         rng.random(n_samples) * 2 - 1)
     init_duals_sample_nx = (nx.from_numpy(init_duals_sample[0]),
                             nx.from_numpy(init_duals_sample[1]))

-    init_duals_feature = (np.random.random(2) * 2 - 1,
-                          np.random.random(2) * 2 - 1)
+    init_duals_feature = (rng.random(2) * 2 - 1,
+                          rng.random(2) * 2 - 1)
     init_duals_feature_nx = (nx.from_numpy(init_duals_feature[0]),
                              nx.from_numpy(init_duals_feature[1]))

test/test_da.py

Lines changed: 3 additions & 5 deletions

@@ -567,12 +567,11 @@ def test_mapping_transport_class_specific_seed(nx):
     # check that it does not crash when derphi is very close to 0
     ns = 20
     nt = 30
-    np.random.seed(39)
-    Xs, ys = make_data_classif('3gauss', ns)
-    Xt, yt = make_data_classif('3gauss2', nt)
+    rng = np.random.RandomState(39)
+    Xs, ys = make_data_classif('3gauss', ns, random_state=rng)
+    Xt, yt = make_data_classif('3gauss2', nt, random_state=rng)
     otda = ot.da.MappingTransport(kernel="gaussian", bias=False)
     otda.fit(Xs=nx.from_numpy(Xs), Xt=nx.from_numpy(Xt))
-    np.random.seed(None)


 @pytest.skip_backend("jax")
@@ -712,7 +711,6 @@ def test_jcpot_barycenter(nx):
     nt = 50

     sigma = 0.1
-    np.random.seed(1985)

     ps1 = .2
     ps2 = .9

test/test_dmmot.py

Lines changed: 0 additions & 1 deletion

@@ -10,7 +10,6 @@


 def create_test_data(nx):
-    np.random.seed(1234)
     n = 4
     a1 = ot.datasets.make_1D_gauss(n, m=20, s=5)
     a2 = ot.datasets.make_1D_gauss(n, m=60, s=8)
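
The test-suite hunks all apply the same refactor: replace module-level seeding (np.random.seed) with a generator scoped to the test function, so reordered or parallel tests cannot perturb one another's draws. A minimal illustration of the pattern (hypothetical test, not from the suite):

    import numpy as np

    def test_local_rng_is_reproducible():
        rng = np.random.RandomState(42)    # local generator, no global state
        x = rng.randn(100, 2)
        y = np.random.RandomState(42).randn(100, 2)
        np.testing.assert_allclose(x, y)   # same seed, same draws

In test_dmmot.py the seed is simply dropped, since ot.datasets.make_1D_gauss builds deterministic Gaussian histograms and needs no randomness.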
