Overhaul random state management #512

Merged · 2 commits · Aug 25, 2023
10 changes: 7 additions & 3 deletions ot/dr.py
@@ -25,7 +25,7 @@
 import pymanopt.optimizers

 from .bregman import sinkhorn as sinkhorn_bregman
-from .utils import dist as dist_utils
+from .utils import dist as dist_utils, check_random_state


 def dist(x1, x2):
@@ -267,7 +267,7 @@ def proj(X):
     return Popt.point, proj


-def projection_robust_wasserstein(X, Y, a, b, tau, U0=None, reg=0.1, k=2, stopThr=1e-3, maxiter=100, verbose=0):
+def projection_robust_wasserstein(X, Y, a, b, tau, U0=None, reg=0.1, k=2, stopThr=1e-3, maxiter=100, verbose=0, random_state=None):
     r"""
     Projection Robust Wasserstein Distance :ref:`[32] <references-projection-robust-wasserstein>`

@@ -303,6 +303,9 @@ def projection_robust_wasserstein(X, Y, a, b, tau, U0=None, reg=0.1, k=2, stopTh
         Stop threshold on error (>0)
     verbose : int, optional
         Print information along iterations.
+    random_state : int, RandomState instance or None, default=None
+        Determines random number generation for initial value of projection
+        operator when U0 is not given.

     Returns
     -------
@@ -332,7 +335,8 @@ def projection_robust_wasserstein(X, Y, a, b, tau, U0=None, reg=0.1, k=2, stopTh
     assert d > k

     if U0 is None:
-        U = np.random.randn(d, k)
+        rng = check_random_state(random_state)
+        U = rng.randn(d, k)
         U, _ = np.linalg.qr(U)
     else:
         U = U0
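Note on the helper: `ot.utils.check_random_state` follows the scikit-learn seeding convention, so `random_state=None` keeps the old global-RNG behavior while an int or a `RandomState` makes the draw reproducible. A minimal sketch of that contract, assuming POT's helper mirrors sklearn's (the exact error message may differ):

```python
import numbers
import numpy as np

def check_random_state(seed):
    """Turn seed into a np.random.RandomState instance (sklearn convention)."""
    if seed is None or seed is np.random:
        return np.random.mtrand._rand  # the global RandomState singleton
    if isinstance(seed, numbers.Integral):
        return np.random.RandomState(seed)
    if isinstance(seed, np.random.RandomState):
        return seed
    raise ValueError(f"{seed!r} cannot be used to seed a RandomState instance")
```

With that contract, `projection_robust_wasserstein(..., random_state=0)` becomes deterministic whenever `U0` is not supplied.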
27 changes: 19 additions & 8 deletions ot/gromov/_dictionary.py
@@ -11,13 +11,13 @@
 import numpy as np


-from ..utils import unif
+from ..utils import unif, check_random_state
 from ..backend import get_backend
 from ._gw import gromov_wasserstein, fused_gromov_wasserstein


 def gromov_wasserstein_dictionary_learning(Cs, D, nt, reg=0., ps=None, q=None, epochs=20, batch_size=32, learning_rate=1., Cdict_init=None, projection='nonnegative_symmetric', use_log=True,
-                                           tol_outer=10**(-5), tol_inner=10**(-5), max_iter_outer=20, max_iter_inner=200, use_adam_optimizer=True, verbose=False, **kwargs):
+                                           tol_outer=10**(-5), tol_inner=10**(-5), max_iter_outer=20, max_iter_inner=200, use_adam_optimizer=True, verbose=False, random_state=None, **kwargs):
     r"""
     Infer Gromov-Wasserstein linear dictionary :math:`\{ (\mathbf{C_{dict}[d]}, q) \}_{d \in [D]}` from the list of structures :math:`\{ (\mathbf{C_s},\mathbf{p_s}) \}_s`

@@ -81,6 +81,9 @@ def gromov_wasserstein_dictionary_learning(Cs, D, nt, reg=0., ps=None, q=None, e
         Maximum number of iterations for the Conjugate Gradient. Default is 200.
     verbose : bool, optional
         Print the reconstruction loss every epoch. Default is False.
+    random_state : int, RandomState instance or None, default=None
+        Determines random number generation. Pass an int for reproducible
+        output across multiple function calls.

     Returns
     -------
@@ -90,6 +93,7 @@ def gromov_wasserstein_dictionary_learning(Cs, D, nt, reg=0., ps=None, q=None, e
         The dictionary leading to the best loss over an epoch is saved and returned.
     log: dict
         If use_log is True, contains loss evolutions by batches and epochs.
+
     References
     -------
     .. [38] C. Vincent-Cuaz, T. Vayer, R. Flamary, M. Corneli, N. Courty, Online
@@ -110,10 +114,11 @@ def gromov_wasserstein_dictionary_learning(Cs, D, nt, reg=0., ps=None, q=None, e
         q = unif(nt)
     else:
         q = nx.to_numpy(q)
+    rng = check_random_state(random_state)
     if Cdict_init is None:
         # Initialize randomly structures of dictionary atoms based on samples
         dataset_means = [C.mean() for C in Cs]
-        Cdict = np.random.normal(loc=np.mean(dataset_means), scale=np.std(dataset_means), size=(D, nt, nt))
+        Cdict = rng.normal(loc=np.mean(dataset_means), scale=np.std(dataset_means), size=(D, nt, nt))
     else:
         Cdict = nx.to_numpy(Cdict_init).copy()
         assert Cdict.shape == (D, nt, nt)
@@ -141,7 +146,7 @@ def gromov_wasserstein_dictionary_learning(Cs, D, nt, reg=0., ps=None, q=None, e

         for _ in range(iter_by_epoch):
             # batch sampling
-            batch = np.random.choice(range(dataset_size), size=batch_size, replace=False)
+            batch = rng.choice(range(dataset_size), size=batch_size, replace=False)
             cumulated_loss_over_batch = 0.
             unmixings = np.zeros((batch_size, D))
             Cs_embedded = np.zeros((batch_size, nt, nt))
@@ -469,7 +474,8 @@ def _linesearch_gromov_wasserstein_unmixing(w, grad_w, x, Cdict, Cembedded, cons

 def fused_gromov_wasserstein_dictionary_learning(Cs, Ys, D, nt, alpha, reg=0., ps=None, q=None, epochs=20, batch_size=32, learning_rate_C=1., learning_rate_Y=1.,
                                                  Cdict_init=None, Ydict_init=None, projection='nonnegative_symmetric', use_log=False,
-                                                 tol_outer=10**(-5), tol_inner=10**(-5), max_iter_outer=20, max_iter_inner=200, use_adam_optimizer=True, verbose=False, **kwargs):
+                                                 tol_outer=10**(-5), tol_inner=10**(-5), max_iter_outer=20, max_iter_inner=200, use_adam_optimizer=True, verbose=False,
+                                                 random_state=None, **kwargs):
     r"""
     Infer Fused Gromov-Wasserstein linear dictionary :math:`\{ (\mathbf{C_{dict}[d]}, \mathbf{Y_{dict}[d]}, \mathbf{q}) \}_{d \in [D]}` from the list of S attributed structures :math:`\{ (\mathbf{C_s}, \mathbf{Y_s},\mathbf{p_s}) \}_s`

@@ -548,6 +554,9 @@ def fused_gromov_wasserstein_dictionary_learning(Cs, Ys, D, nt, alpha, reg=0., p
         Maximum number of iterations for the Conjugate Gradient. Default is 200.
     verbose : bool, optional
         Print the reconstruction loss every epoch. Default is False.
+    random_state : int, RandomState instance or None, default=None
+        Determines random number generation. Pass an int for reproducible
+        output across multiple function calls.

     Returns
     -------
@@ -560,6 +569,7 @@ def fused_gromov_wasserstein_dictionary_learning(Cs, Ys, D, nt, alpha, reg=0., p
         The dictionary leading to the best loss over an epoch is saved and returned.
     log: dict
         If use_log is True, contains loss evolutions by batches and epochs.
+
     References
     -------
     .. [38] C. Vincent-Cuaz, T. Vayer, R. Flamary, M. Corneli, N. Courty, Online
@@ -583,17 +593,18 @@ def fused_gromov_wasserstein_dictionary_learning(Cs, Ys, D, nt, alpha, reg=0., p
     else:
         q = nx.to_numpy(q)

+    rng = check_random_state(random_state)
     if Cdict_init is None:
         # Initialize randomly structures of dictionary atoms based on samples
         dataset_means = [C.mean() for C in Cs]
-        Cdict = np.random.normal(loc=np.mean(dataset_means), scale=np.std(dataset_means), size=(D, nt, nt))
+        Cdict = rng.normal(loc=np.mean(dataset_means), scale=np.std(dataset_means), size=(D, nt, nt))
     else:
         Cdict = nx.to_numpy(Cdict_init).copy()
         assert Cdict.shape == (D, nt, nt)
     if Ydict_init is None:
         # Initialize randomly features of dictionary atoms based on samples distribution by feature component
         dataset_feature_means = np.stack([F.mean(axis=0) for F in Ys])
-        Ydict = np.random.normal(loc=dataset_feature_means.mean(axis=0), scale=dataset_feature_means.std(axis=0), size=(D, nt, d))
+        Ydict = rng.normal(loc=dataset_feature_means.mean(axis=0), scale=dataset_feature_means.std(axis=0), size=(D, nt, d))
     else:
         Ydict = nx.to_numpy(Ydict_init).copy()
         assert Ydict.shape == (D, nt, d)
@@ -626,7 +637,7 @@ def fused_gromov_wasserstein_dictionary_learning(Cs, Ys, D, nt, alpha, reg=0., p
         for _ in range(iter_by_epoch):

             # Batch iterations
-            batch = np.random.choice(range(dataset_size), size=batch_size, replace=False)
+            batch = rng.choice(range(dataset_size), size=batch_size, replace=False)
             cumulated_loss_over_batch = 0.
             unmixings = np.zeros((batch_size, D))
             Cs_embedded = np.zeros((batch_size, nt, nt))
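With both the atom initialization and the per-epoch batch sampling drawn from one checked generator, two calls with the same `random_state` now learn identical dictionaries. A hypothetical usage sketch with toy sizes (assumes `ot.gromov` re-exports `gromov_wasserstein_dictionary_learning` and that these toy inputs are valid structure matrices):

```python
import numpy as np
import ot

rng = np.random.RandomState(0)
Cs = [np.abs(rng.randn(10, 10)) for _ in range(8)]  # toy structure matrices
Cs = [0.5 * (C + C.T) for C in Cs]                  # symmetrize

# Same random_state -> same initial atoms and same batch draws,
# so the learned dictionaries coincide.
Cdict1, _ = ot.gromov.gromov_wasserstein_dictionary_learning(
    Cs, D=3, nt=6, epochs=2, batch_size=4, random_state=42)
Cdict2, _ = ot.gromov.gromov_wasserstein_dictionary_learning(
    Cs, D=3, nt=6, epochs=2, batch_size=4, random_state=42)
np.testing.assert_allclose(Cdict1, Cdict2)
```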
28 changes: 20 additions & 8 deletions ot/stochastic.py
@@ -10,7 +10,7 @@
 # License: MIT License

 import numpy as np
-from .utils import dist
+from .utils import dist, check_random_state
 from .backend import get_backend

 ##############################################################################
@@ -69,7 +69,7 @@ def coordinate_grad_semi_dual(b, M, reg, beta, i):
     return b - khi


-def sag_entropic_transport(a, b, M, reg, numItermax=10000, lr=None):
+def sag_entropic_transport(a, b, M, reg, numItermax=10000, lr=None, random_state=None):
     r"""
     Compute the SAG algorithm to solve the regularized discrete measures optimal transport max problem

@@ -110,6 +110,9 @@ def sag_entropic_transport(a, b, M, reg, numItermax=10000, lr=None):
         Number of iteration.
     lr : float
         Learning rate.
+    random_state : int, RandomState instance or None, default=None
+        Determines random number generation. Pass an int for reproducible
+        output across multiple function calls.

     Returns
     -------
@@ -129,8 +132,9 @@ def sag_entropic_transport(a, b, M, reg, numItermax=10000, lr=None):
     cur_beta = np.zeros(n_target)
     stored_gradient = np.zeros((n_source, n_target))
     sum_stored_gradient = np.zeros(n_target)
+    rng = check_random_state(random_state)
     for _ in range(numItermax):
-        i = np.random.randint(n_source)
+        i = rng.randint(n_source)
         cur_coord_grad = a[i] * coordinate_grad_semi_dual(b, M, reg,
                                                           cur_beta, i)
         sum_stored_gradient += (cur_coord_grad - stored_gradient[i])
@@ -139,7 +143,7 @@ def sag_entropic_transport(a, b, M, reg, numItermax=10000, lr=None):
     return cur_beta


-def averaged_sgd_entropic_transport(a, b, M, reg, numItermax=300000, lr=None):
+def averaged_sgd_entropic_transport(a, b, M, reg, numItermax=300000, lr=None, random_state=None):
    r'''
    Compute the ASGD algorithm to solve the regularized semi-continuous measures optimal transport max problem

@@ -177,6 +181,9 @@ def averaged_sgd_entropic_transport(a, b, M, reg, numItermax=300000, lr=None):
         Number of iteration.
     lr : float
         Learning rate.
+    random_state : int, RandomState instance or None, default=None
+        Determines random number generation. Pass an int for reproducible
+        output across multiple function calls.

     Returns
     -------
@@ -195,9 +202,10 @@ def averaged_sgd_entropic_transport(a, b, M, reg, numItermax=300000, lr=None):
     n_target = np.shape(M)[1]
     cur_beta = np.zeros(n_target)
     ave_beta = np.zeros(n_target)
+    rng = check_random_state(random_state)
     for cur_iter in range(numItermax):
         k = cur_iter + 1
-        i = np.random.randint(n_source)
+        i = rng.randint(n_source)
         cur_coord_grad = coordinate_grad_semi_dual(b, M, reg, cur_beta, i)
         cur_beta += (lr / np.sqrt(k)) * cur_coord_grad
         ave_beta = (1. / k) * cur_beta + (1 - 1. / k) * ave_beta
@@ -422,7 +430,7 @@ def batch_grad_dual(a, b, M, reg, alpha, beta, batch_size, batch_alpha,
     return grad_alpha, grad_beta


-def sgd_entropic_regularization(a, b, M, reg, batch_size, numItermax, lr):
+def sgd_entropic_regularization(a, b, M, reg, batch_size, numItermax, lr, random_state=None):
    r'''
    Compute the SGD algorithm to solve the regularized discrete measures optimal transport dual problem

@@ -460,6 +468,9 @@ def sgd_entropic_regularization(a, b, M, reg, batch_size, numItermax, lr):
         number of iteration
     lr : float
         learning rate
+    random_state : int, RandomState instance or None, default=None
+        Determines random number generation. Pass an int for reproducible
+        output across multiple function calls.

     Returns
     -------
@@ -477,10 +488,11 @@ def sgd_entropic_regularization(a, b, M, reg, batch_size, numItermax, lr):
     n_target = np.shape(M)[1]
     cur_alpha = np.zeros(n_source)
     cur_beta = np.zeros(n_target)
+    rng = check_random_state(random_state)
     for cur_iter in range(numItermax):
         k = np.sqrt(cur_iter + 1)
-        batch_alpha = np.random.choice(n_source, batch_size, replace=False)
-        batch_beta = np.random.choice(n_target, batch_size, replace=False)
+        batch_alpha = rng.choice(n_source, batch_size, replace=False)
+        batch_beta = rng.choice(n_target, batch_size, replace=False)
         update_alpha, update_beta = batch_grad_dual(a, b, M, reg, cur_alpha,
                                                     cur_beta, batch_size,
                                                     batch_alpha, batch_beta)
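The stochastic solvers get the same treatment: every coordinate or batch index is drawn from the checked generator, so a fixed seed reproduces the full sampling sequence. A small sketch on toy inputs (sizes, `reg`, and iteration counts here are arbitrary, for illustration only):

```python
import numpy as np
import ot

n = 30
a, b = ot.utils.unif(n), ot.utils.unif(n)
rng = np.random.RandomState(0)
M = ot.dist(rng.randn(n, 2), rng.randn(n, 2))  # toy cost matrix

# Identical seeds -> identical sampled coordinates -> identical dual potential.
beta1 = ot.stochastic.sag_entropic_transport(a, b, M, reg=1., numItermax=2000, random_state=7)
beta2 = ot.stochastic.sag_entropic_transport(a, b, M, reg=1., numItermax=2000, random_state=7)
np.testing.assert_allclose(beta1, beta2)
```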
4 changes: 2 additions & 2 deletions test/test_1d_solver.py
@@ -163,8 +163,8 @@ def test_emd_1d_emd2_1d():
     np.testing.assert_allclose(G, G_1d, atol=1e-15)

     # check AssertionError is raised if called on non 1d arrays
-    u = np.random.randn(n, 2)
-    v = np.random.randn(m, 2)
+    u = rng.randn(n, 2)
+    v = rng.randn(m, 2)
     with pytest.raises(AssertionError):
         ot.emd_1d(u, v, [], [])

19 changes: 12 additions & 7 deletions test/test_bregman.py
@@ -298,7 +298,8 @@ def test_sinkhorn_variants(nx):
 def test_sinkhorn_variants_dtype_device(nx, method):
     n = 100

-    x = np.random.randn(n, 2)
+    rng = np.random.RandomState(42)
+    x = rng.randn(n, 2)
     u = ot.utils.unif(n)

     M = ot.dist(x, x)
@@ -317,7 +318,8 @@ def test_sinkhorn2_variants_dtype_device(nx, method):
 def test_sinkhorn2_variants_dtype_device(nx, method):
     n = 100

-    x = np.random.randn(n, 2)
+    rng = np.random.RandomState(42)
+    x = rng.randn(n, 2)
     u = ot.utils.unif(n)

     M = ot.dist(x, x)
@@ -337,7 +339,8 @@ def test_sinkhorn2_variants_device_tf(method):
 def test_sinkhorn2_variants_device_tf(method):
     nx = ot.backend.TensorflowBackend()
     n = 100
-    x = np.random.randn(n, 2)
+    rng = np.random.RandomState(42)
+    x = rng.randn(n, 2)
     u = ot.utils.unif(n)
     M = ot.dist(x, x)

@@ -690,11 +693,12 @@ def test_barycenter_stabilization(nx):

 @pytest.mark.parametrize("method", ["sinkhorn", "sinkhorn_log"])
 def test_wasserstein_bary_2d(nx, method):
+    rng = np.random.RandomState(42)
     size = 20  # size of a square image
-    a1 = np.random.rand(size, size)
+    a1 = rng.rand(size, size)
     a1 += a1.min()
     a1 = a1 / np.sum(a1)
-    a2 = np.random.rand(size, size)
+    a2 = rng.rand(size, size)
     a2 += a2.min()
     a2 = a2 / np.sum(a2)
     # creating matrix A containing all distributions
@@ -724,11 +728,12 @@ def test_wasserstein_bary_2d(nx, method):

 @pytest.mark.parametrize("method", ["sinkhorn", "sinkhorn_log"])
 def test_wasserstein_bary_2d_debiased(nx, method):
+    rng = np.random.RandomState(42)
     size = 20  # size of a square image
-    a1 = np.random.rand(size, size)
+    a1 = rng.rand(size, size)
     a1 += a1.min()
     a1 = a1 / np.sum(a1)
-    a2 = np.random.rand(size, size)
+    a2 = rng.rand(size, size)
     a2 += a2.min()
     a2 = a2 / np.sum(a2)
     # creating matrix A containing all distributions
13 changes: 7 additions & 6 deletions test/test_coot.py
@@ -223,21 +223,22 @@ def test_coot_warmstart(nx):
     xt_nx = nx.from_numpy(xt)

     # initialize warmstart
-    init_pi_sample = np.random.rand(n_samples, n_samples)
+    rng = np.random.RandomState(42)
+    init_pi_sample = rng.rand(n_samples, n_samples)
     init_pi_sample = init_pi_sample / np.sum(init_pi_sample)
     init_pi_sample_nx = nx.from_numpy(init_pi_sample)

-    init_pi_feature = np.random.rand(2, 2)
+    init_pi_feature = rng.rand(2, 2)
     init_pi_feature = init_pi_feature / np.sum(init_pi_feature)
     init_pi_feature_nx = nx.from_numpy(init_pi_feature)

-    init_duals_sample = (np.random.random(n_samples) * 2 - 1,
-                         np.random.random(n_samples) * 2 - 1)
+    init_duals_sample = (rng.random(n_samples) * 2 - 1,
+                         rng.random(n_samples) * 2 - 1)
     init_duals_sample_nx = (nx.from_numpy(init_duals_sample[0]),
                             nx.from_numpy(init_duals_sample[1]))

-    init_duals_feature = (np.random.random(2) * 2 - 1,
-                          np.random.random(2) * 2 - 1)
+    init_duals_feature = (rng.random(2) * 2 - 1,
+                          rng.random(2) * 2 - 1)
     init_duals_feature_nx = (nx.from_numpy(init_duals_feature[0]),
                              nx.from_numpy(init_duals_feature[1]))

8 changes: 3 additions & 5 deletions test/test_da.py
@@ -567,12 +567,11 @@ def test_mapping_transport_class_specific_seed(nx):
     # check that it does not crash when derphi is very close to 0
     ns = 20
     nt = 30
-    np.random.seed(39)
-    Xs, ys = make_data_classif('3gauss', ns)
-    Xt, yt = make_data_classif('3gauss2', nt)
+    rng = np.random.RandomState(39)
+    Xs, ys = make_data_classif('3gauss', ns, random_state=rng)
+    Xt, yt = make_data_classif('3gauss2', nt, random_state=rng)
     otda = ot.da.MappingTransport(kernel="gaussian", bias=False)
     otda.fit(Xs=nx.from_numpy(Xs), Xt=nx.from_numpy(Xt))
-    np.random.seed(None)


 @pytest.skip_backend("jax")
@@ -712,7 +711,6 @@ def test_jcpot_barycenter(nx):
     nt = 50

     sigma = 0.1
-    np.random.seed(1985)

     ps1 = .2
     ps2 = .9
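All the test edits apply one rule: never seed the global RNG. `np.random.seed(...)` mutates process-wide state and leaks into every later `np.random` call in the session, whereas a local `RandomState` threaded through `random_state` keeps each test's draws deterministic and isolated. The pattern, sketched on the call above (toy sample counts):

```python
import numpy as np
from ot.datasets import make_data_classif

# A local generator: reproducible for this test, invisible to everything else.
rng = np.random.RandomState(39)
Xs, ys = make_data_classif('3gauss', 20, random_state=rng)
Xt, yt = make_data_classif('3gauss2', 30, random_state=rng)
```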
1 change: 0 additions & 1 deletion test/test_dmmot.py
@@ -10,7 +10,6 @@


 def create_test_data(nx):
-    np.random.seed(1234)
     n = 4
     a1 = ot.datasets.make_1D_gauss(n, m=20, s=5)
     a2 = ot.datasets.make_1D_gauss(n, m=60, s=8)