Skip to content

[WIP] low rank sinkhorn, solve_sample, OTResultLazy + test functions #542

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 5 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion ot/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
from . import factored
from . import solvers
from . import gaussian
from . import lowrank

# OT functions
from .lp import (emd, emd2, emd_1d, emd2_1d, wasserstein_1d,
Expand All @@ -50,7 +51,8 @@
gromov_barycenters, fused_gromov_wasserstein, fused_gromov_wasserstein2)
from .weak import weak_optimal_transport
from .factored import factored_optimal_transport
from .solvers import solve, solve_gromov
from .solvers import solve, solve_gromov, solve_sample
from .lowrank import lowrank_sinkhorn

# utils functions
from .utils import dist, unif, tic, toc, toq
Expand Down
210 changes: 210 additions & 0 deletions ot/lowrank.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,210 @@
#################################################################################################################
############################################## WORK IN PROGRESS #################################################
#################################################################################################################

## Implementation of the LR-Dykstra algorithm and low rank sinkhorn algorithms

import warnings
from .utils import unif, list_to_array, dist
from .backend import get_backend



################################## LR-DYKSTRA ALGORITHM ##########################################

def LR_Dysktra(eps1, eps2, eps3, p1, p2, alpha, dykstra_p, nx=None):
    r"""
    Run one pass of the LR-Dykstra projection step used by low-rank Sinkhorn.

    The routine follows the LR-Dykstra algorithm of Scetbon, Cuturi & Peyré
    (2021), "Low-Rank Sinkhorn Factorization", arXiv:2103.04737.
    TODO(review): the reviewer asked for the algorithm number in the paper to
    be cited here; add it once confirmed against the paper.

    A single update of the scalings is performed per call. The auxiliary
    variables are returned in ``dykstra_p`` so the caller can feed them back
    in on its next iteration.

    Parameters
    ----------
    eps1 : array-like, shape (n_samples_a, r)
        Kernel associated with the first low-rank factor.
    eps2 : array-like, shape (n_samples_b, r)
        Kernel associated with the second low-rank factor.
    eps3 : array-like, shape (r,)
        Kernel associated with the weight vector.
    p1 : array-like, shape (n_samples_a,)
        Target marginal for the first factor.
    p2 : array-like, shape (n_samples_b,)
        Target marginal for the second factor.
    alpha : float
        Lower bound enforced element-wise on the weight vector.
    dykstra_p : sequence of 6 array-like, each of shape (r,)
        Auxiliary variables ``[q3_1, q3_2, v1_, v2_, q1, q2]`` carried over
        from the previous call.
    nx : backend, optional
        POT backend to use. When ``None`` (default) it is inferred from the
        inputs; callers that already hold a backend should pass it to avoid
        re-detecting it on every call.

    Returns
    -------
    Q : array-like, shape (n_samples_a, r)
        First low-rank factor, ``diag(u1) @ eps1 @ diag(v1)``.
    R : array-like, shape (n_samples_b, r)
        Second low-rank factor, ``diag(u2) @ eps2 @ diag(v2)``.
    g : array-like, shape (r,)
        Updated weight vector.
    err : float
        Sum of the L1 violations of the two marginal constraints.
    dykstra_p : list of 6 array-like
        Updated auxiliary variables, in the same order as the input.

    Notes
    -----
    This is an internal helper: it always receives backend arrays, never
    Python lists, so no ``list_to_array`` conversion is performed (as
    requested in review).
    """
    # Unpack the auxiliary variables carried over from the previous call.
    q3_1, q3_2, v1_, v2_, q1, q2 = dykstra_p

    # Only detect the backend when the caller did not supply one.
    if nx is None:
        nx = get_backend(eps1, eps2, eps3, p1, p2, q3_1, q3_2, v1_, v2_, q1, q2)

    # ------- Dykstra update ------
    g_ = eps3

    # Row scalings that match the two prescribed marginals.
    u1 = p1 / nx.dot(eps1, v1_)
    u2 = p2 / nx.dot(eps2, v2_)

    # Enforce the lower bound alpha on the weight vector.
    g = nx.maximum(alpha, g_ * q3_1)
    q3_1 = (g_ * q3_1) / g
    g_ = g

    # These products are needed both for g and for the column scalings, so
    # they are computed once.
    eps1_t_u1 = nx.dot(eps1.T, u1)
    eps2_t_u2 = nx.dot(eps2.T, u2)

    # Geometric mean coupling the weight vector with both factors.
    prod1 = (v1_ * q1) * eps1_t_u1
    prod2 = (v2_ * q2) * eps2_t_u2
    g = (g_ * q3_2 * prod1 * prod2) ** (1 / 3)

    # Column scalings so that both factors share the inner marginal g.
    v1 = g / eps1_t_u1
    v2 = g / eps2_t_u2

    q1 = (v1_ * q1) / v1
    q2 = (v2_ * q2) / v2
    q3_2 = (g_ * q3_2) / g

    # L1 violation of the two outer marginal constraints.
    err1 = nx.sum(nx.abs(u1 * nx.dot(eps1, v1) - p1))
    err2 = nx.sum(nx.abs(u2 * nx.dot(eps2, v2) - p2))
    err = err1 + err2

    # Low-rank factors Q = diag(u1) eps1 diag(v1), R = diag(u2) eps2 diag(v2).
    Q = u1[:, None] * eps1 * v1[None, :]
    R = u2[:, None] * eps2 * v2[None, :]

    dykstra_p = [q3_1, q3_2, v1, v2, q1, q2]

    return Q, R, g, err, dykstra_p



#################################### LOW RANK SINKHORN ALGORITHM #########################################


def lowrank_sinkhorn(X_s, X_t, a=None, b=None, reg=0, rank=2, metric='sqeuclidean', alpha="auto",
                     numItermax=10000, stopThr=1e-9, warn=True, verbose=False):
    r'''
    Solve the entropic regularization optimal transport problem under low-nonnegative rank constraints.

    This function returns the two low-rank matrix decomposition of the OT plan (Q,R), as well as the weight vector g.
    The OT plan itself is never formed here; it can be rebuilt as
    ``Q @ diag(1 / g) @ R.T``.

    Parameters
    ----------
    X_s : array-like, shape (n_samples_a, dim)
        samples in the source domain
    X_t : array-like, shape (n_samples_b, dim)
        samples in the target domain
    a : array-like, shape (n_samples_a,), optional
        samples weights in the source domain (uniform when None)
    b : array-like, shape (n_samples_b,), optional
        samples weights in the target domain (uniform when None)
    reg : float, optional
        Regularization term >0
    rank : int, optional
        Nonnegative rank of the OT plan. It is capped at
        ``min(n_samples_a, n_samples_b)``.
    metric : str, optional
        Metric forwarded to :func:`dist` to build the cost matrix
    alpha : float or "auto", optional
        Lower bound for the weight vector g (>0 and <1/r). With ``"auto"``
        the value ``1 / (r + 1)`` is used.
    numItermax : int, optional
        Max number of iterations
    stopThr : float, optional
        Stop threshold on error (>0)
    warn : bool, optional
        If True, emit a warning when the iteration budget is exhausted
        without reaching ``stopThr``
    verbose : bool, optional
        If True, print the marginal error at every iteration


    Returns
    -------
    Q : array-like, shape (n_samples_a, r)
        First low-rank matrix decomposition of the OT plan
    R: array-like, shape (n_samples_b, r)
        Second low-rank matrix decomposition of the OT plan
    g : array-like, shape (r, )
        Weight vector for the low-rank decomposition of the OT plan


    References
    ----------

    .. Scetbon, M., Cuturi, M., & Peyré, G (2021).
        Low-Rank Sinkhorn Factorization. arXiv preprint arXiv:2103.04737.

    '''

    X_s, X_t = list_to_array(X_s, X_t)
    nx = get_backend(X_s, X_t)

    ns, nt = X_s.shape[0], X_t.shape[0]
    # TODO(review): the reviewer suggested simplifying both defaults below to
    # ``unif(n, type_as=X_s)``.
    if a is None:
        a = nx.from_numpy(unif(ns), type_as=X_s)
    if b is None:
        b = nx.from_numpy(unif(nt), type_as=X_s)

    # Compute cost matrix
    # TODO(review): the reviewer asked that M never be materialised as a
    # dense (ns, nt) matrix. It should be stored in factorised form (see the
    # paper's discussion of D = AB) and only used through that factorisation
    # in the dot products below. Supporting only 'sqeuclidean' and raising
    # NotImplementedError for other metrics was deemed acceptable until an
    # efficient low-rank factorisation of other metrics is available.
    M = dist(X_s, X_t, metric=metric)

    # The rank cannot exceed the number of samples on either side.
    r = min(ns, nt, rank)

    if alpha == 'auto':
        alpha = 1.0 / (r + 1)

    if (1 / r < alpha) or (alpha < 0):
        warnings.warn("The provided alpha value might lead to instabilities.")

    # Step size gamma = 1 / (2 L).
    # NOTE(review): the constant L should be double-checked against the step
    # size given in the paper; it is kept exactly as originally written.
    # TODO(review): the reviewer asked for references to the paper's
    # equations and algorithm line numbers in the comments of this section.
    L = nx.sqrt((2 / (alpha**4)) * (nx.norm(M)**2) + (reg + (2 / (alpha**3)) * (nx.norm(M))**2))
    gamma = 1 / (2 * L)

    # Initialisation. `type_as` keeps every array on the dtype/device of the
    # inputs so that non-NumPy backends do not mix array types.
    Q = nx.ones((ns, r), type_as=X_s)
    R = nx.ones((nt, r), type_as=X_s)
    g = nx.ones(r, type_as=X_s)
    q3_1, q3_2 = nx.ones(r, type_as=X_s), nx.ones(r, type_as=X_s)
    v1_, v2_ = nx.ones(r, type_as=X_s), nx.ones(r, type_as=X_s)
    q1, q2 = nx.ones(r, type_as=X_s), nx.ones(r, type_as=X_s)
    dykstra_p = [q3_1, q3_2, v1_, v2_, q1, q2]
    err = 1

    for ii in range(numItermax):
        CR = nx.dot(M, R)
        C_t_Q = nx.dot(M.T, Q)

        # Right-multiplying by diag(1 / g) scales column k by 1 / g[k]; the
        # result must keep shape (n, r). A dot product with a (r, 1) column
        # would wrongly collapse the rank axis.
        inv_g = (1 / g)[None, :]

        # Kernels handed to the Dykstra step for Q, R and g respectively.
        eps1 = nx.exp(-gamma * (CR * inv_g) - ((gamma * reg) - 1) * nx.log(Q))
        eps2 = nx.exp(-gamma * (C_t_Q * inv_g) - ((gamma * reg) - 1) * nx.log(R))
        omega = nx.diag(nx.dot(Q.T, CR))
        eps3 = nx.exp(gamma * omega / (g**2) - (gamma * reg - 1) * nx.log(g))

        Q, R, g, err, dykstra_p = LR_Dysktra(eps1, eps2, eps3, a, b, alpha, dykstra_p)

        if err < stopThr:
            break

        if verbose:
            # Re-print the table header every 200 iterations.
            if ii % 200 == 0:
                print(
                    '{:5s}|{:12s}'.format('It.', 'Err') + '\n' + '-' * 19)
            print('{:5d}|{:8e}|'.format(ii, err))

    else:
        # Reached only when the loop ran out of iterations without a break.
        if warn:
            warnings.warn("Sinkhorn did not converge. You might want to "
                          "increase the number of iterations `numItermax` "
                          "or the regularization parameter `reg`.")

    return Q, R, g





############################################################################
## Test with X_s, X_t from ot.datasets
#############################################################################

# import numpy as np
# import ot

# Xs, _ = ot.datasets.make_data_classif('3gauss', n=1000)
# Xt, _ = ot.datasets.make_data_classif('3gauss2', n=1500)

# ns = Xs.shape[0]
# nt = Xt.shape[0]

# a = unif(ns)
# b = unif(nt)

# Q, R, g = lowrank_sinkhorn(Xs, Xt, reg=0.1, metric='euclidean', verbose=True, numItermax=100)
# M = ot.dist(Xs,Xt)
# P = np.dot(Q,np.dot(np.diag(1/g),R.T))

# print(np.sum(P))




166 changes: 166 additions & 0 deletions ot/solvers.py
Original file line number Diff line number Diff line change
Expand Up @@ -848,3 +848,169 @@ def solve_gromov(Ca, Cb, M=None, a=None, b=None, loss='L2', symmetric=None,
value_linear=value_linear, value_quad=value_quad, plan=plan, status=status, backend=nx)

return res






################################## WORK IN PROGRESS #####################################

## Implementation of the ot.solve_sample function
## Function isn't complete, still work in progress for reg == 0 / reg is None case (and unbalanced cases)


from .utils import unif, list_to_array, dist, OTResultLazy
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These imports should be at the top of the file (move them when nearing merge).

from .bregman import empirical_sinkhorn


def solve_sample(X_s, X_t, a=None, b=None, metric='sqeuclidean', reg=None, reg_type="KL", unbalanced=None,
                 unbalanced_type='KL', is_Lazy=False, batch_size=None, n_threads=1, max_iter=None, plan_init=None,
                 potentials_init=None, tol=None, verbose=False):

    r"""Solve the discrete optimal transport problem using the samples in the source and target domains.
    It returns either a :any:`OTResult` or :any:`OTResultLazy` object.

    The function solves the following general optimal transport problem

    .. math::
        \min_{\mathbf{T}\geq 0} \quad \sum_{i,j} T_{i,j}M_{i,j} + \lambda_r R(\mathbf{T}) +
        \lambda_u U(\mathbf{T}\mathbf{1},\mathbf{a}) +
        \lambda_u U(\mathbf{T}^T\mathbf{1},\mathbf{b})

    The regularization is selected with `reg` (:math:`\lambda_r`) and `reg_type`. By
    default ``reg=None`` and there is no regularization. The unbalanced marginal
    penalization can be selected with `unbalanced` (:math:`\lambda_u`) and
    `unbalanced_type`. By default ``unbalanced=None`` and the function
    solves the exact optimal transport problem (respecting the marginals).

    Parameters
    ----------
    X_s : array-like, shape (n_samples_a, dim)
        samples in the source domain
    X_t : array-like, shape (n_samples_b, dim)
        samples in the target domain
    a : array-like, shape (dim_a,), optional
        Samples weights in the source domain (default is uniform)
    b : array-like, shape (dim_b,), optional
        Samples weights in the target domain (default is uniform)
    metric : str, optional
        Metric used to compute the cost between samples, by default
        'sqeuclidean'
    reg : float, optional
        Regularization weight :math:`\lambda_r`, by default None (no reg., exact
        OT)
    reg_type : str, optional
        Type of regularization :math:`R`  either "KL", "L2", "entropy", by default "KL"
    unbalanced : float, optional
        Unbalanced penalization weight :math:`\lambda_u`, by default None
        (balanced OT)
    unbalanced_type : str, optional
        Type of unbalanced penalization function :math:`U` either "KL", "L2", "TV", by default "KL"
    is_Lazy : bool, optional
        Return :any:`OTResultlazy` object to reduce memory cost when True, by default False
    batch_size : int, optional
        Batch size forwarded to the lazy Sinkhorn solver, by default None
        (100 is used). Only used when ``is_Lazy`` is True.
    n_threads : int, optional
        Number of OMP threads for exact OT solver, by default 1
    max_iter : int, optional
        Maximum number of iteration, by default None (default values in each
        solvers; 1000 for the lazy Sinkhorn solver)
    plan_init : array_like, shape (dim_a, dim_b), optional
        Initialization of the OT plan for iterative methods, by default None
    potentials_init : (array_like(dim_a,),array_like(dim_b,)), optional
        Initialization of the OT dual potentials for iterative methods, by default None
    tol : float, optional
        Tolerance for solution precision, by default None (default values in
        each solvers; 1e-9 for the lazy Sinkhorn solver)
    verbose : bool, optional
        Print information in the solver, by default False

    Returns
    -------

    res_lazy : OTResultLazy()
        Result of the optimization problem. This class only returns a partial OT plan and the OT dual potentials to reduce memory costs.
        The information can be obtained as follows:

        - res.lazy_plan : OT plan computed on a subsample of X_s and X_t
        - res.potentials : OT dual potentials

        See :any:`OTResultLazy` for more information.

    res : OTResult()
        Result of the optimization problem. The information can be obtained as follows:

        - res.plan : OT plan :math:`\mathbf{T}`
        - res.potentials : OT dual potentials
        - res.value : Optimal value of the optimization problem
        - res.value_linear : Linear OT loss with the optimal OT plan

        See :any:`OTResult` for more information.

    Raises
    ------
    NotImplementedError
        When ``is_Lazy`` is True and either no regularization is requested
        (``reg`` is None or 0) or an unbalanced problem is requested.

    Notes
    -----
    The lazy path does not use `reg_type`, `n_threads`, `plan_init` or
    `potentials_init`.
    """

    X_s, X_t = list_to_array(X_s, X_t)

    # detect backend from every array that was actually provided
    arr = [X_s, X_t]
    if a is not None:
        arr.append(a)
    if b is not None:
        arr.append(b)
    nx = get_backend(*arr)

    # create uniform weights if not given
    ns, nt = X_s.shape[0], X_t.shape[0]
    if a is None:
        a = nx.from_numpy(unif(ns), type_as=X_s)
    if b is None:
        b = nx.from_numpy(unif(nt), type_as=X_s)

    if not is_Lazy:
        # Dense path: build the full cost matrix and delegate to `solve`.
        # `max_iter` and `tol` are forwarded untouched so that None still
        # selects each solver's own default, as documented.
        M = dist(X_s, X_t, metric)
        res = solve(M, a, b, reg, reg_type, unbalanced, unbalanced_type, n_threads, max_iter,
                    plan_init, potentials_init, tol, verbose)
        return res

    # ----------------------------- lazy path (WIP) -----------------------------
    if reg is None or reg == 0:  # EMD solver for isLazy ?
        if unbalanced is None:  # balanced EMD solver for isLazy ?
            raise NotImplementedError('Not implemented balanced with no regularization')
        raise NotImplementedError(
            'Not implemented unbalanced_type="{}" with no regularization'.format(unbalanced_type))

    if unbalanced is not None:
        raise NotImplementedError(
            'Not implemented unbalanced_type="{}" with regularization'.format(unbalanced_type))

    # Defaults that only concern the lazy Sinkhorn solver (the reviewer asked
    # for them to sit next to the `empirical_sinkhorn` call).
    if max_iter is None:
        max_iter = 1000
    if tol is None:
        tol = 1e-9
    if batch_size is None:
        batch_size = 100

    # Use the caller's metric so the solver and the sub-plan computed below
    # agree on the cost.
    u, v, log = empirical_sinkhorn(X_s, X_t, reg, a, b, metric=metric, numIterMax=max_iter, stopThr=tol,
                                   isLazy=True, batchSize=batch_size, verbose=verbose, log=True)
    # compute potentials
    potentials = (log["u"], log["v"])

    # compute lazy_plan on the leading samples only, to bound memory use
    ns_lazy, nt_lazy = 100, 100  # size of the lazy_plan (subplan)
    M = dist(X_s[:ns_lazy, :], X_t[:nt_lazy, :], metric)
    K = nx.exp(M / (-reg))
    lazy_plan = u[:ns_lazy].reshape((-1, 1)) * K * v[:nt_lazy].reshape((1, -1))

    res_lazy = OTResultLazy(potentials=potentials, lazy_plan=lazy_plan, backend=nx)
    return res_lazy




Loading