[WIP] low rank sinkhorn, solve_sample, OTResultLazy + test functions #542
Changes from all commits:

- f49f6b4
- 3c4b50f
- 3034e57
- 085863a
- ca07641
@@ -0,0 +1,210 @@
#################################################################################################################
############################################## WORK IN PROGRESS #################################################
#################################################################################################################

## Implementation of the LR-Dykstra algorithm and low rank sinkhorn algorithms

import warnings
from .utils import unif, list_to_array, dist
from .backend import get_backend

################################## LR-DYKSTRA ALGORITHM ##########################################

def LR_Dysktra(eps1, eps2, eps3, p1, p2, alpha, dykstra_p):
    """
    Implementation of the Dykstra algorithm for low rank sinkhorn
    """

    # get dykstra parameters
    q3_1, q3_2, v1_, v2_, q1, q2 = dykstra_p

    # POT backend
    eps1, eps2, eps3, p1, p2 = list_to_array(eps1, eps2, eps3, p1, p2)
Review comment: Not needed; LR_Dysktra is used inside POT functions and will never be interfaced directly, so it will receive arrays, not lists.
    q3_1, q3_2, v1_, v2_, q1, q2 = list_to_array(q3_1, q3_2, v1_, v2_, q1, q2)

    nx = get_backend(eps1, eps2, eps3, p1, p2, q3_1, q3_2, v1_, v2_, q1, q2)
Review comment: This should be done only when nx is not given to the function (i.e. left as None); otherwise use the nx from the calling function (you need to add nx as a parameter to the function).
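A minimal sketch of what the reviewer asks for, assuming the backend argument is simply appended to the existing signature (the parameter name and default are assumptions, not part of this diff):

```python
from ot.backend import get_backend


def LR_Dysktra(eps1, eps2, eps3, p1, p2, alpha, dykstra_p, nx=None):
    """Dykstra step for low rank sinkhorn, with an optional backend argument."""
    # infer the backend only when the caller does not provide one
    if nx is None:
        nx = get_backend(eps1, eps2, eps3, p1, p2, *dykstra_p)
    # ... rest of the update unchanged, using nx as before ...
```

The calling function (lowrank_sinkhorn) would then pass its own nx, avoiding a second backend lookup at every iteration.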
    # ------- Dykstra algorithm ------
    g_ = eps3

    u1 = p1 / nx.dot(eps1, v1_)
    u2 = p2 / nx.dot(eps2, v2_)

    g = nx.maximum(alpha, g_ * q3_1)
    q3_1 = (g_ * q3_1) / g
    g_ = g

    prod1 = (v1_ * q1) * nx.dot(eps1.T, u1)
    prod2 = (v2_ * q2) * nx.dot(eps2.T, u2)
    g = (g_ * q3_2 * prod1 * prod2) ** (1 / 3)

    v1 = g / nx.dot(eps1.T, u1)
    v2 = g / nx.dot(eps2.T, u2)

    q1 = (v1_ * q1) / v1
    q2 = (v2_ * q2) / v2
    q3_2 = (g_ * q3_2) / g

    v1_, v2_ = v1, v2
    g_ = g

    # Compute error
    err1 = nx.sum(nx.abs(u1 * (eps1 @ v1) - p1))
    err2 = nx.sum(nx.abs(u2 * (eps2 @ v2) - p2))
    err = err1 + err2

    # Compute low rank matrices Q, R
    Q = u1[:, None] * eps1 * v1[None, :]
    R = u2[:, None] * eps2 * v2[None, :]

    dykstra_p = [q3_1, q3_2, v1_, v2_, q1, q2]

    return Q, R, g, err, dykstra_p

#################################### LOW RANK SINKHORN ALGORITHM #########################################


def lowrank_sinkhorn(X_s, X_t, a=None, b=None, reg=0, rank=2, metric='sqeuclidean', alpha="auto",
                     numItermax=10000, stopThr=1e-9, warn=True, verbose=False):
    r'''
    Solve the entropic regularization optimal transport problem under low-nonnegative-rank constraints.

    This function returns the two matrices of the low-rank decomposition of the OT plan (Q, R), as well as the weight vector g.

    Parameters
    ----------
    X_s : array-like, shape (n_samples_a, dim)
        samples in the source domain
    X_t : array-like, shape (n_samples_b, dim)
        samples in the target domain
    a : array-like, shape (n_samples_a,)
        samples weights in the source domain
    b : array-like, shape (n_samples_b,)
        samples weights in the target domain
    reg : float, optional
        Regularization term >0
    rank : int, optional
        Nonnegative rank of the OT plan
    metric : str, optional
        Metric used to compute the cost matrix
    alpha : float, optional
        Lower bound for the weight vector g (>0 and <1/r)
    numItermax : int, optional
        Max number of iterations
    stopThr : float, optional
        Stop threshold on error (>0)

    Returns
    -------
    Q : array-like, shape (n_samples_a, r)
        First low-rank matrix decomposition of the OT plan
    R : array-like, shape (n_samples_b, r)
        Second low-rank matrix decomposition of the OT plan
    g : array-like, shape (r,)
        Weight vector for the low-rank decomposition of the OT plan

    References
    ----------
    .. Scetbon, M., Cuturi, M., & Peyré, G. (2021).
        Low-Rank Sinkhorn Factorization. arXiv preprint arXiv:2103.04737.

    '''

    X_s, X_t = list_to_array(X_s, X_t)
    nx = get_backend(X_s, X_t)

    ns, nt = X_s.shape[0], X_t.shape[0]
    if a is None:
        a = nx.from_numpy(unif(ns), type_as=X_s)
Review comment: Suggested change.
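A plausible version of such a suggestion, assuming ot.utils.unif accepts a type_as argument (an assumption, not something shown in this diff):

```python
    if a is None:
        a = unif(ns, type_as=X_s)  # hypothetical: relies on unif(..., type_as=...) existing
```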
    if b is None:
        b = nx.from_numpy(unif(nt), type_as=X_s)
Review comment: Same here.
    # Compute cost matrix
    M = dist(X_s, X_t, metric=metric)
Review comment: M should never be computed as an n times n matrix. It should be stored in a factorized version (see the discussion in the paper with D = AB) and only used in that factorized form later (when computing dot products).

Review comment: In fact I'm OK with implementing only squared Euclidean and raising NotImplementedError for other metrics, until we have an efficient function for obtaining a low-rank factorization of the metric.
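For the squared Euclidean metric, the factorized cost the reviewer refers to has a closed form of rank d + 2. A minimal sketch in plain NumPy (the helper name is hypothetical and this is not part of the PR); nx.dot(M, R) in the iterations could then be replaced by nx.dot(M1, nx.dot(M2, R)):

```python
import numpy as np


def factorized_sqeuclidean_cost(X_s, X_t):
    """Return M1 (ns x (d+2)) and M2 ((d+2) x nt) such that M1 @ M2 equals the
    pairwise squared Euclidean cost matrix between X_s and X_t."""
    ns, nt = X_s.shape[0], X_t.shape[0]
    M1 = np.concatenate(
        (np.sum(X_s ** 2, axis=1, keepdims=True), np.ones((ns, 1)), -2 * X_s), axis=1
    )
    M2 = np.concatenate(
        (np.ones((nt, 1)), np.sum(X_t ** 2, axis=1, keepdims=True), X_t), axis=1
    ).T
    return M1, M2


# sanity check against the dense cost
X_s, X_t = np.random.randn(5, 3), np.random.randn(4, 3)
M1, M2 = factorized_sqeuclidean_cost(X_s, X_t)
M_dense = ((X_s[:, None, :] - X_t[None, :, :]) ** 2).sum(axis=-1)
assert np.allclose(M1 @ M2, M_dense)
```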
    # Compute rank
    rank = min(ns, nt, rank)
    r = rank

    if alpha == 'auto':
        alpha = 1.0 / (r + 1)

    if (1 / r < alpha) or (alpha < 0):
        warnings.warn("The provided alpha value might lead to instabilities.")

    # Compute gamma
    L = nx.sqrt((2 / (alpha ** 4)) * (nx.norm(M) ** 2) + (reg + (2 / (alpha ** 3)) * (nx.norm(M)) ** 2))
    gamma = 1 / (2 * L)

    # Initialisation
    Q, R, g = nx.ones((ns, r)), nx.ones((nt, r)), nx.ones(r)
    q3_1, q3_2 = nx.ones(r), nx.ones(r)
    v1_, v2_ = nx.ones(r), nx.ones(r)
    q1, q2 = nx.ones(r), nx.ones(r)
    dykstra_p = [q3_1, q3_2, v1_, v2_, q1, q2]
    err = 1

    for ii in range(numItermax):
        CR = nx.dot(M, R)
        C_t_Q = nx.dot(M.T, Q)
        diag_g = (1 / g)[:, None]

        eps1 = nx.exp(-gamma * (nx.dot(CR, diag_g)) - ((gamma * reg) - 1) * nx.log(Q))
Review comment: Add a few more comments with references to the equations and algorithm line numbers in the paper, please.
        eps2 = nx.exp(-gamma * (nx.dot(C_t_Q, diag_g)) - ((gamma * reg) - 1) * nx.log(R))
        omega = nx.diag(nx.dot(Q.T, CR))
        eps3 = nx.exp(gamma * omega / (g ** 2) - (gamma * reg - 1) * nx.log(g))

        Q, R, g, err, dykstra_p = LR_Dysktra(eps1, eps2, eps3, a, b, alpha, dykstra_p)

        if err < stopThr:
            break

        if verbose:
            if ii % 200 == 0:
                print(
                    '{:5s}|{:12s}'.format('It.', 'Err') + '\n' + '-' * 19)
                print('{:5d}|{:8e}|'.format(ii, err))

    else:
        if warn:
            warnings.warn("Sinkhorn did not converge. You might want to "
                          "increase the number of iterations `numItermax` "
                          "or the regularization parameter `reg`.")

    return Q, R, g

############################################################################
## Test with X_s, X_t from ot.datasets
############################################################################

# import numpy as np
# import ot

# Xs, _ = ot.datasets.make_data_classif('3gauss', n=1000)
# Xt, _ = ot.datasets.make_data_classif('3gauss2', n=1500)

# ns = Xs.shape[0]
# nt = Xt.shape[0]

# a = unif(ns)
# b = unif(nt)

# Q, R, g = lowrank_sinkhorn(Xs, Xt, reg=0.1, metric='euclidean', verbose=True, numItermax=100)
# M = ot.dist(Xs, Xt)
# P = np.dot(Q, np.dot(np.diag(1 / g), R.T))

# print(np.sum(P))

@@ -848,3 +848,169 @@ def solve_gromov(Ca, Cb, M=None, a=None, b=None, loss='L2', symmetric=None,
        value_linear=value_linear, value_quad=value_quad, plan=plan, status=status, backend=nx)

    return res


################################## WORK IN PROGRESS #####################################

## Implementation of the ot.solve_sample function
## Function isn't complete, still work in progress for reg == 0 / reg is None case (and unbalanced cases)


from .utils import unif, list_to_array, dist, OTResultLazy
Review comment: Those imports should be at the top of the file (move them when nearing merge).
from .bregman import empirical_sinkhorn

def solve_sample(X_s, X_t, a=None, b=None, metric='sqeuclidean', reg=None, reg_type="KL", unbalanced=None,
                 unbalanced_type='KL', is_Lazy=False, batch_size=None, n_threads=1, max_iter=None, plan_init=None,
                 potentials_init=None, tol=None, verbose=False):

    r"""Solve the discrete optimal transport problem using the samples in the source and target domains.
    It returns either an :any:`OTResult` or an :any:`OTResultLazy` object.

    The function solves the following general optimal transport problem

    .. math::
        \min_{\mathbf{T}\geq 0} \quad \sum_{i,j} T_{i,j}M_{i,j} + \lambda_r R(\mathbf{T}) +
        \lambda_u U(\mathbf{T}\mathbf{1},\mathbf{a}) +
        \lambda_u U(\mathbf{T}^T\mathbf{1},\mathbf{b})

    The regularization is selected with `reg` (:math:`\lambda_r`) and `reg_type`. By
    default ``reg=None`` and there is no regularization. The unbalanced marginal
    penalization can be selected with `unbalanced` (:math:`\lambda_u`) and
    `unbalanced_type`. By default ``unbalanced=None`` and the function
    solves the exact optimal transport problem (respecting the marginals).

    Parameters
    ----------
    X_s : array-like, shape (n_samples_a, dim)
        samples in the source domain
    X_t : array-like, shape (n_samples_b, dim)
        samples in the target domain
    a : array-like, shape (dim_a,), optional
        Samples weights in the source domain (default is uniform)
    b : array-like, shape (dim_b,), optional
        Samples weights in the target domain (default is uniform)
    metric : str, optional
        Metric used to compute the cost matrix, by default "sqeuclidean"
    reg : float, optional
        Regularization weight :math:`\lambda_r`, by default None (no reg., exact
        OT)
    reg_type : str, optional
        Type of regularization :math:`R` either "KL", "L2", "entropy", by default "KL"
    unbalanced : float, optional
        Unbalanced penalization weight :math:`\lambda_u`, by default None
        (balanced OT)
    unbalanced_type : str, optional
        Type of unbalanced penalization function :math:`U` either "KL", "L2", "TV", by default "KL"
    is_Lazy : bool, optional
        Return an :any:`OTResultLazy` object to reduce the memory cost when True, by default False
    batch_size : int, optional
        Batch size used by the lazy Sinkhorn solver, by default None (100)
    n_threads : int, optional
        Number of OMP threads for the exact OT solver, by default 1
    max_iter : int, optional
        Maximum number of iterations, by default None (default values in each solver)
    plan_init : array_like, shape (dim_a, dim_b), optional
        Initialization of the OT plan for iterative methods, by default None
    potentials_init : (array_like(dim_a,), array_like(dim_b,)), optional
        Initialization of the OT dual potentials for iterative methods, by default None
    tol : float, optional
        Tolerance for solution precision, by default None (default values in each solver)
    verbose : bool, optional
        Print information in the solver, by default False

    Returns
    -------
    res_lazy : OTResultLazy()
        Result of the optimization problem, returned when ``is_Lazy=True``. This class only stores a partial OT plan
        and the OT dual potentials to reduce memory costs. The information can be obtained as follows:

        - res.lazy_plan : OT plan computed on a subsample of X_s and X_t
        - res.potentials : OT dual potentials

        See :any:`OTResultLazy` for more information.

    res : OTResult()
        Result of the optimization problem. The information can be obtained as follows:

        - res.plan : OT plan :math:`\mathbf{T}`
        - res.potentials : OT dual potentials
        - res.value : Optimal value of the optimization problem
        - res.value_linear : Linear OT loss with the optimal OT plan

        See :any:`OTResult` for more information.

    """
    X_s, X_t = list_to_array(X_s, X_t)

    # detect backend
    arr = [X_s, X_t]
    if a is not None:
        arr.append(a)
    if b is not None:
        arr.append(b)
    nx = get_backend(*arr)

    # create uniform weights if not given
    ns, nt = X_s.shape[0], X_t.shape[0]
    if a is None:
        a = nx.from_numpy(unif(ns), type_as=X_s)
    if b is None:
        b = nx.from_numpy(unif(nt), type_as=X_s)

    # default values for solutions
    potentials = None
    lazy_plan = None

    if max_iter is None:
Review comment: Move these defaults to just before the call to empirical_sinkhorn; those values are only relevant to that function.
        max_iter = 1000
    if tol is None:
        tol = 1e-9
    if batch_size is None:
        batch_size = 100

    if is_Lazy:
        ################# WIP ####################
        if reg is None or reg == 0:  # EMD solver for isLazy ?

            if unbalanced is None:  # balanced EMD solver for isLazy ?
                raise (NotImplementedError('Not implemented balanced with no regularization'))

            else:
                raise (NotImplementedError('Not implemented unbalanced_type="{}" with no regularization'.format(unbalanced_type)))

        #############################################

        else:
            if unbalanced is None:
                u, v, log = empirical_sinkhorn(X_s, X_t, reg, a, b, metric='sqeuclidean', numIterMax=max_iter, stopThr=tol,
                                               isLazy=True, batchSize=batch_size, verbose=verbose, log=True)
                # compute potentials
                potentials = (log["u"], log["v"])

                # compute lazy_plan
                ns_lazy, nt_lazy = 100, 100  # size of the lazy_plan (subplan)
                M = dist(X_s[:ns_lazy, :], X_t[:nt_lazy, :], metric)
                K = nx.exp(M / (-reg))
                lazy_plan = u[:ns_lazy].reshape((-1, 1)) * K * v[:nt_lazy].reshape((1, -1))

                res_lazy = OTResultLazy(potentials=potentials, lazy_plan=lazy_plan, backend=nx)
                return res_lazy

            else:
                raise (NotImplementedError('Not implemented unbalanced_type="{}" with regularization'.format(unbalanced_type)))

    else:
        # compute cost matrix M and use solve function
        M = dist(X_s, X_t, metric)

        res = solve(M, a, b, reg, reg_type, unbalanced, unbalanced_type, n_threads, max_iter, plan_init, potentials_init, tol, verbose)
        return res

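A short usage sketch of the new function, assuming it is eventually exposed as ot.solve_sample with the WIP signature above (this PR is still work in progress, so the entry point and defaults are assumptions):

```python
import numpy as np
import ot

# two small point clouds
Xs = np.random.randn(200, 2)
Xt = np.random.randn(300, 2) + 1.0

# dense path: builds the full cost matrix and calls ot.solve internally
res = ot.solve_sample(Xs, Xt, reg=0.1)
print(res.value, res.plan.shape)           # full (200, 300) plan

# lazy path: batched entropic solver, no full plan kept in memory
res_lazy = ot.solve_sample(Xs, Xt, reg=0.1, is_Lazy=True, batch_size=50)
u, v = res_lazy.potentials                 # OT dual potentials
print(res_lazy.lazy_plan.shape)            # (100, 100) subsampled plan
```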
Review comment: A bit more documentation with references to the paper and to the algorithm number in the paper, please.