From e903625db204822f642d8fd4077001568c53eeaf Mon Sep 17 00:00:00 2001 From: 6Ulm Date: Tue, 31 Oct 2023 23:08:47 +0100 Subject: [PATCH 01/18] add new features to unbalanced solvers --- ot/unbalanced.py | 229 +++++++++++++++++++++++++--------------- test/test_unbalanced.py | 90 +++++++++++++--- 2 files changed, 221 insertions(+), 98 deletions(-) diff --git a/ot/unbalanced.py b/ot/unbalanced.py index 265006d2c..6c30a88bb 100644 --- a/ot/unbalanced.py +++ b/ot/unbalanced.py @@ -19,7 +19,8 @@ from .utils import list_to_array, get_parameter_pair -def sinkhorn_unbalanced(a, b, M, reg, reg_m, method='sinkhorn', numItermax=1000, +def sinkhorn_unbalanced(a, b, M, reg, reg_m, reg_type="entropy", warmstart=None, + method='sinkhorn', numItermax=1000, stopThr=1e-6, verbose=False, log=False, **kwargs): r""" Solve the unbalanced entropic regularization optimal transport problem @@ -39,7 +40,7 @@ def sinkhorn_unbalanced(a, b, M, reg, reg_m, method='sinkhorn', numItermax=1000, where : - :math:`\mathbf{M}` is the (`dim_a`, `dim_b`) metric cost matrix - - :math:`\Omega` is the entropic regularization term, :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})` + - :math:`\Omega` is the entropic regularization term, can be either KL divergence or negative entropy - :math:`\mathbf{a}` and :math:`\mathbf{b}` are source and target unbalanced distributions - KL is the Kullback-Leibler divergence @@ -66,8 +67,17 @@ def sinkhorn_unbalanced(a, b, M, reg, reg_m, method='sinkhorn', numItermax=1000, For semi-relaxed case, use either `reg_m=(float("inf"), scalar)` or `reg_m=(scalar, float("inf"))`. If reg_m is an array, it must have the same backend as input arrays (a, b, M). + reg_type : string, optional + Regularizer term. Can take two values: + 'entropy' (negative entropy) + :math:`\Omega(\gamma) = \sum_{i,j} \gamma_{i,j} \log(\gamma_{i,j}) - \sum_{i,j} \gamma_{i,j}`, or + 'kl' (Kullback-Leibler) + :math:`\Omega(\gamma) = \text{KL}(\gamma, \mathbf{a} \mathbf{b}^T)`. 
+ warmstart: tuple of arrays, shape (dim_a, dim_b), optional + Initialization of dual potentials. If provided, the dual potentials should be given + (that is the logarithm of the u,v sinkhorn scaling vectors). method : str - method used for the solver either 'sinkhorn', 'sinkhorn_stabilized' or + method used for the solver either 'sinkhorn', 'sinkhorn_stabilized' or 'sinkhorn_reg_scaling', see those function for specific parameters numItermax : int, optional Max number of iterations @@ -134,30 +144,30 @@ def sinkhorn_unbalanced(a, b, M, reg, reg_m, method='sinkhorn', numItermax=1000, """ if method.lower() == 'sinkhorn': - return sinkhorn_knopp_unbalanced(a, b, M, reg, reg_m, - numItermax=numItermax, + return sinkhorn_knopp_unbalanced(a, b, M, reg, reg_m, reg_type, + warmstart, numItermax=numItermax, stopThr=stopThr, verbose=verbose, log=log, **kwargs) elif method.lower() == 'sinkhorn_stabilized': - return sinkhorn_stabilized_unbalanced(a, b, M, reg, reg_m, - numItermax=numItermax, + return sinkhorn_stabilized_unbalanced(a, b, M, reg, reg_m, reg_type, + warmstart, numItermax=numItermax, stopThr=stopThr, verbose=verbose, log=log, **kwargs) elif method.lower() in ['sinkhorn_reg_scaling']: - warnings.warn('Method not implemented yet. Using classic Sinkhorn Knopp') - return sinkhorn_knopp_unbalanced(a, b, M, reg, reg_m, - numItermax=numItermax, + warnings.warn('Method not implemented yet. Using classic Sinkhorn-Knopp') + return sinkhorn_knopp_unbalanced(a, b, M, reg, reg_m, reg_type, + warmstart, numItermax=numItermax, stopThr=stopThr, verbose=verbose, log=log, **kwargs) else: raise ValueError("Unknown method '%s'." 
% method) -def sinkhorn_unbalanced2(a, b, M, reg, reg_m, method='sinkhorn', - numItermax=1000, stopThr=1e-6, verbose=False, - log=False, **kwargs): +def sinkhorn_unbalanced2(a, b, M, reg, reg_m, reg_type="entropy", warmstart=None, + method='sinkhorn', numItermax=1000, + stopThr=1e-6, verbose=False, log=False, **kwargs): r""" Solve the entropic regularization unbalanced optimal transport problem and return the loss @@ -175,7 +185,7 @@ def sinkhorn_unbalanced2(a, b, M, reg, reg_m, method='sinkhorn', where : - :math:`\mathbf{M}` is the (`dim_a`, `dim_b`) metric cost matrix - - :math:`\Omega` is the entropic regularization term, :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})` + - :math:`\Omega` is the entropic regularization term, can be either KL divergence or negative entropy - :math:`\mathbf{a}` and :math:`\mathbf{b}` are source and target unbalanced distributions - KL is the Kullback-Leibler divergence @@ -202,8 +212,17 @@ def sinkhorn_unbalanced2(a, b, M, reg, reg_m, method='sinkhorn', For semi-relaxed case, use either `reg_m=(float("inf"), scalar)` or `reg_m=(scalar, float("inf"))`. If reg_m is an array, it must have the same backend as input arrays (a, b, M). + reg_type : string, optional + Regularizer term. Can take two values: + 'entropy' (negative entropy) + :math:`\Omega(\gamma) = \sum_{i,j} \gamma_{i,j} \log(\gamma_{i,j}) - \sum_{i,j} \gamma_{i,j}`, or + 'kl' (Kullback-Leibler) + :math:`\Omega(\gamma) = \text{KL}(\gamma, \mathbf{a} \mathbf{b}^T)`. + warmstart: tuple of arrays, shape (dim_a, dim_b), optional + Initialization of dual potentials. If provided, the dual potentials should be given + (that is the logarithm of the u,v sinkhorn scaling vectors). 
method : str - method used for the solver either 'sinkhorn', 'sinkhorn_stabilized' or + method used for the solver either 'sinkhorn', 'sinkhorn_stabilized' or 'sinkhorn_reg_scaling', see those function for specific parameters numItermax : int, optional Max number of iterations @@ -263,29 +282,30 @@ def sinkhorn_unbalanced2(a, b, M, reg, reg_m, method='sinkhorn', b = b[:, None] if method.lower() == 'sinkhorn': - return sinkhorn_knopp_unbalanced(a, b, M, reg, reg_m, - numItermax=numItermax, + return sinkhorn_knopp_unbalanced(a, b, M, reg, reg_m, reg_type, + warmstart, numItermax=numItermax, stopThr=stopThr, verbose=verbose, log=log, **kwargs) elif method.lower() == 'sinkhorn_stabilized': - return sinkhorn_stabilized_unbalanced(a, b, M, reg, reg_m, - numItermax=numItermax, + return sinkhorn_stabilized_unbalanced(a, b, M, reg, reg_m, reg_type, + warmstart, numItermax=numItermax, stopThr=stopThr, verbose=verbose, log=log, **kwargs) elif method.lower() in ['sinkhorn_reg_scaling']: - warnings.warn('Method not implemented yet. Using classic Sinkhorn Knopp') - return sinkhorn_knopp_unbalanced(a, b, M, reg, reg_m, - numItermax=numItermax, + warnings.warn('Method not implemented yet. Using classic Sinkhorn-Knopp') + return sinkhorn_knopp_unbalanced(a, b, M, reg, reg_m, reg_type, + warmstart, numItermax=numItermax, stopThr=stopThr, verbose=verbose, log=log, **kwargs) else: raise ValueError('Unknown method %s.' 
% method) -def sinkhorn_knopp_unbalanced(a, b, M, reg, reg_m, numItermax=1000, - stopThr=1e-6, verbose=False, log=False, **kwargs): +def sinkhorn_knopp_unbalanced(a, b, M, reg, reg_m, reg_type="entropy", + warmstart=None, numItermax=1000, stopThr=1e-6, + verbose=False, log=False, **kwargs): r""" Solve the entropic regularization unbalanced optimal transport problem and return the OT plan @@ -304,7 +324,7 @@ def sinkhorn_knopp_unbalanced(a, b, M, reg, reg_m, numItermax=1000, where : - :math:`\mathbf{M}` is the (`dim_a`, `dim_b`) metric cost matrix - - :math:`\Omega` is the entropic regularization term, :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})` + - :math:`\Omega` is the entropic regularization term, can be either KL divergence or negative entropy - :math:`\mathbf{a}` and :math:`\mathbf{b}` are source and target unbalanced distributions - KL is the Kullback-Leibler divergence @@ -330,6 +350,15 @@ def sinkhorn_knopp_unbalanced(a, b, M, reg, reg_m, numItermax=1000, For semi-relaxed case, use either `reg_m=(float("inf"), scalar)` or `reg_m=(scalar, float("inf"))`. If reg_m is an array, it must have the same backend as input arrays (a, b, M). + reg_type : string, optional + Regularizer term. Can take two values: + 'entropy' (negative entropy) + :math:`\Omega(\gamma) = \sum_{i,j} \gamma_{i,j} \log(\gamma_{i,j}) - \sum_{i,j} \gamma_{i,j}`, or + 'kl' (Kullback-Leibler) + :math:`\Omega(\gamma) = \text{KL}(\gamma, \mathbf{a} \mathbf{b}^T)`. + warmstart: tuple of arrays, shape (dim_a, dim_b), optional + Initialization of dual potentials. If provided, the dual potentials should be given + (that is the logarithm of the u,v sinkhorn scaling vectors). 
numItermax : int, optional Max number of iterations stopThr : float, optional @@ -404,15 +433,21 @@ def sinkhorn_knopp_unbalanced(a, b, M, reg, reg_m, numItermax=1000, # we assume that no distances are null except those of the diagonal of # distances - if n_hists: - u = nx.ones((dim_a, 1), type_as=M) / dim_a - v = nx.ones((dim_b, n_hists), type_as=M) / dim_b - a = a.reshape(dim_a, 1) + if warmstart is None: + if n_hists: + u = nx.ones((dim_a, 1), type_as=M) / dim_a + v = nx.ones((dim_b, n_hists), type_as=M) / dim_b + a = a.reshape(dim_a, 1) + else: + u = nx.ones(dim_a, type_as=M) / dim_a + v = nx.ones(dim_b, type_as=M) / dim_b else: - u = nx.ones(dim_a, type_as=M) / dim_a - v = nx.ones(dim_b, type_as=M) / dim_b + u, v = nx.exp(warmstart[0]), nx.exp(warmstart[1]) - K = nx.exp(M / (-reg)) + if reg_type == "kl": + K = nx.exp(-M / reg) * a[:, None] * b[None, :] + elif reg_type == "entropy": + K = nx.exp(-M / reg) fi_1 = reg_m1 / (reg_m1 + reg) if reg_m1 != float("inf") else 1 fi_2 = reg_m2 / (reg_m2 + reg) if reg_m2 != float("inf") else 1 @@ -474,9 +509,10 @@ def sinkhorn_knopp_unbalanced(a, b, M, reg, reg_m, numItermax=1000, return u[:, None] * K * v[None, :] -def sinkhorn_stabilized_unbalanced(a, b, M, reg, reg_m, tau=1e5, numItermax=1000, - stopThr=1e-6, verbose=False, log=False, - **kwargs): +def sinkhorn_stabilized_unbalanced(a, b, M, reg, reg_m, reg_type="entropy", + warmstart=None, tau=1e5, + numItermax=1000, stopThr=1e-6, + verbose=False, log=False, **kwargs): r""" Solve the entropic regularization unbalanced optimal transport problem and return the loss @@ -496,7 +532,7 @@ def sinkhorn_stabilized_unbalanced(a, b, M, reg, reg_m, tau=1e5, numItermax=1000 where : - :math:`\mathbf{M}` is the (`dim_a`, `dim_b`) metric cost matrix - - :math:`\Omega` is the entropic regularization term, :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})` + - :math:`\Omega` is the entropic regularization term, can be either KL divergence or negative entropy - 
:math:`\mathbf{a}` and :math:`\mathbf{b}` are source and target unbalanced distributions - KL is the Kullback-Leibler divergence @@ -523,6 +559,15 @@ def sinkhorn_stabilized_unbalanced(a, b, M, reg, reg_m, tau=1e5, numItermax=1000 For semi-relaxed case, use either `reg_m=(float("inf"), scalar)` or `reg_m=(scalar, float("inf"))`. If reg_m is an array, it must have the same backend as input arrays (a, b, M). + reg_type : string, optional + Regularizer term. Can take two values: + 'entropy' (negative entropy) + :math:`\Omega(\gamma) = \sum_{i,j} \gamma_{i,j} \log(\gamma_{i,j}) - \sum_{i,j} \gamma_{i,j}`, or + 'kl' (Kullback-Leibler) + :math:`\Omega(\gamma) = \text{KL}(\gamma, \mathbf{a} \mathbf{b}^T)`. + warmstart: tuple of arrays, shape (dim_a, dim_b), optional + Initialization of dual potentials. If provided, the dual potentials should be given + (that is the logarithm of the u,v sinkhorn scaling vectors). tau : float threshold for max value in u or v for log scaling numItermax : int, optional @@ -597,16 +642,21 @@ def sinkhorn_stabilized_unbalanced(a, b, M, reg, reg_m, tau=1e5, numItermax=1000 # we assume that no distances are null except those of the diagonal of # distances - if n_hists: - u = nx.ones((dim_a, n_hists), type_as=M) / dim_a - v = nx.ones((dim_b, n_hists), type_as=M) / dim_b - a = a.reshape(dim_a, 1) + if warmstart is None: + if n_hists: + u = nx.ones((dim_a, n_hists), type_as=M) / dim_a + v = nx.ones((dim_b, n_hists), type_as=M) / dim_b + a = a.reshape(dim_a, 1) + else: + u = nx.ones(dim_a, type_as=M) / dim_a + v = nx.ones(dim_b, type_as=M) / dim_b else: - u = nx.ones(dim_a, type_as=M) / dim_a - v = nx.ones(dim_b, type_as=M) / dim_b + u, v = nx.exp(warmstart[0]), nx.exp(warmstart[1]) - # print(reg) - K = nx.exp(-M / reg) + if reg_type == "kl": + K = nx.exp(-M / reg) * a[:, None] * b[None, :] + elif reg_type == "entropy": + K = nx.exp(-M / reg) fi_1 = reg_m1 / (reg_m1 + reg) if reg_m1 != float("inf") else 1 fi_2 = reg_m2 / (reg_m2 + reg) if reg_m2 != 
float("inf") else 1 @@ -1074,7 +1124,7 @@ def barycenter_unbalanced(A, M, reg, reg_m, method="sinkhorn", weights=None, raise ValueError("Unknown method '%s'." % method) -def mm_unbalanced(a, b, M, reg_m, reg=0, div='kl', G0=None, numItermax=1000, +def mm_unbalanced(a, b, M, reg_m, c=None, reg=0, div='kl', G0=None, numItermax=1000, stopThr=1e-15, verbose=False, log=False): r""" Solve the unbalanced optimal transport problem and return the OT plan. @@ -1084,7 +1134,7 @@ def mm_unbalanced(a, b, M, reg_m, reg=0, div='kl', G0=None, numItermax=1000, W = \min_\gamma \quad \langle \gamma, \mathbf{M} \rangle_F + \mathrm{reg_{m1}} \cdot \mathrm{div}(\gamma \mathbf{1}, \mathbf{a}) + \mathrm{reg_{m2}} \cdot \mathrm{div}(\gamma^T \mathbf{1}, \mathbf{b}) + - \mathrm{reg} \cdot \mathrm{div}(\gamma, \mathbf{a} \mathbf{b}^T) + \mathrm{reg} \cdot \mathrm{div}(\gamma, \mathbf{c}) s.t. \gamma \geq 0 @@ -1094,6 +1144,7 @@ def mm_unbalanced(a, b, M, reg_m, reg=0, div='kl', G0=None, numItermax=1000, - :math:`\mathbf{M}` is the (`dim_a`, `dim_b`) metric cost matrix - :math:`\mathbf{a}` and :math:`\mathbf{b}` are source and target unbalanced distributions + - :math:`\mathbf{c}` is a reference distribution for the regularization - div is a divergence, either Kullback-Leibler or :math:`\ell_2` divergence The algorithm used for solving the problem is a maximization- @@ -1113,8 +1164,11 @@ def mm_unbalanced(a, b, M, reg_m, reg=0, div='kl', G0=None, numItermax=1000, then the same reg_m is applied to both marginal relaxations. If reg_m is an array, it must have the same backend as input arrays (a, b, M). reg : float, optional (default = 0) - Entropy regularization term >= 0. + Regularization term >= 0. By default, solve the unregularized problem + c : array-like (dim_a, dim_b), optional (default = None) + Reference measure for the regularization. + If None, then use `\mathbf{c} = \mathbf{a} \mathbf{b}^T`. div: string, optional Divergence to quantify the difference between the marginals. 
Can take two values: 'kl' (Kullback-Leibler) or 'l2' (quadratic) @@ -1172,36 +1226,33 @@ def mm_unbalanced(a, b, M, reg_m, reg=0, div='kl', G0=None, numItermax=1000, if len(b) == 0: b = nx.ones(dim_b, type_as=M) / dim_b - if G0 is None: - G = a[:, None] * b[None, :] - else: - G = G0 + G = a[:, None] * b[None, :] if G0 is None else G0 + c = a[:, None] * b[None, :] if c is None else c reg_m1, reg_m2 = get_parameter_pair(reg_m) if log: log = {'err': [], 'G': []} - if div == 'kl': - sum_r = reg + reg_m1 + reg_m2 - r1, r2, r = reg_m1 / sum_r, reg_m2 / sum_r, reg / sum_r - K = a[:, None]**(r1 + r) * b[None, :]**(r2 + r) * nx.exp(- M / sum_r) - elif div == 'l2': - K = reg_m1 * a[:, None] + reg_m2 * b[None, :] + reg * a[:, None] * b[None, :] - M - K = nx.maximum(K, nx.zeros((dim_a, dim_b), type_as=M)) - else: + if div not in ["kl", "l2"]: warnings.warn("The div parameter should be either equal to 'kl' or \ 'l2': it has been set to 'kl'.") div = 'kl' + + if div == 'kl': sum_r = reg + reg_m1 + reg_m2 r1, r2, r = reg_m1 / sum_r, reg_m2 / sum_r, reg / sum_r - K = a[:, None]**(r1 + r) * b[None, :]**(r2 + r) * nx.exp(- M / sum_r) + K = (a[:, None]**r1) * (b[None, :]**r2) * (c**r) * nx.exp(- M / sum_r) + elif div == 'l2': + K = reg_m1 * a[:, None] + reg_m2 * b[None, :] + reg * c - M + K = nx.maximum(K, nx.zeros((dim_a, dim_b), type_as=M)) for i in range(numItermax): Gprev = G if div == 'kl': - G = K * G**(r1 + r2) / (nx.sum(G, 1, keepdims=True)**r1 * nx.sum(G, 0, keepdims=True)**r2 + 1e-16) + Gd = (nx.sum(G, 1, keepdims=True)**r1) * (nx.sum(G, 0, keepdims=True)**r2) + 1e-16 + G = K * G**(r1 + r2) / Gd elif div == 'l2': Gd = reg_m1 * nx.sum(G, 1, keepdims=True) + \ reg_m2 * nx.sum(G, 0, keepdims=True) + reg * G + 1e-16 @@ -1223,7 +1274,7 @@ def mm_unbalanced(a, b, M, reg_m, reg=0, div='kl', G0=None, numItermax=1000, return G -def mm_unbalanced2(a, b, M, reg_m, reg=0, div='kl', G0=None, numItermax=1000, +def mm_unbalanced2(a, b, M, reg_m, c=None, reg=0, div='kl', G0=None, 
numItermax=1000, stopThr=1e-15, verbose=False, log=False): r""" Solve the unbalanced optimal transport problem and return the OT plan. @@ -1233,7 +1284,7 @@ def mm_unbalanced2(a, b, M, reg_m, reg=0, div='kl', G0=None, numItermax=1000, W = \min_\gamma \quad \langle \gamma, \mathbf{M} \rangle_F + \mathrm{reg_{m1}} \cdot \mathrm{div}(\gamma \mathbf{1}, \mathbf{a}) + \mathrm{reg_{m2}} \cdot \mathrm{div}(\gamma^T \mathbf{1}, \mathbf{b}) + - \mathrm{reg} \cdot \mathrm{div}(\gamma, \mathbf{a} \mathbf{b}^T) + \mathrm{reg} \cdot \mathrm{div}(\gamma, \mathbf{c}) s.t. \gamma \geq 0 @@ -1243,6 +1294,7 @@ def mm_unbalanced2(a, b, M, reg_m, reg=0, div='kl', G0=None, numItermax=1000, - :math:`\mathbf{M}` is the (`dim_a`, `dim_b`) metric cost matrix - :math:`\mathbf{a}` and :math:`\mathbf{b}` are source and target unbalanced distributions + - :math:`\mathbf{c}` is a reference distribution for the regularization - :math:`\mathrm{div}` is a divergence, either Kullback-Leibler or :math:`\ell_2` divergence The algorithm used for solving the problem is a maximization- @@ -1264,6 +1316,9 @@ def mm_unbalanced2(a, b, M, reg_m, reg=0, div='kl', G0=None, numItermax=1000, reg : float, optional (default = 0) Entropy regularization term >= 0. By default, solve the unregularized problem + c : array-like (dim_a, dim_b), optional (default = None) + Reference measure for the regularization. + If None, then use `\mathbf{c} = \mathbf{a} \mathbf{b}^T`. div: string, optional Divergence to quantify the difference between the marginals. 
Can take two values: 'kl' (Kullback-Leibler) or 'l2' (quadratic) @@ -1307,7 +1362,7 @@ def mm_unbalanced2(a, b, M, reg_m, reg=0, div='kl', G0=None, numItermax=1000, ot.lp.emd2 : Unregularized OT loss ot.unbalanced.sinkhorn_unbalanced2 : Entropic regularized OT loss """ - _, log_mm = mm_unbalanced(a, b, M, reg_m, reg=reg, div=div, G0=G0, + _, log_mm = mm_unbalanced(a, b, M, reg_m, c=c, reg=reg, div=div, G0=G0, numItermax=numItermax, stopThr=stopThr, verbose=verbose, log=True) @@ -1317,7 +1372,7 @@ def mm_unbalanced2(a, b, M, reg_m, reg=0, div='kl', G0=None, numItermax=1000, return log_mm['cost'] -def _get_loss_unbalanced(a, b, M, reg, reg_m1, reg_m2, reg_div='kl', regm_div='kl'): +def _get_loss_unbalanced(a, b, c, M, reg, reg_m1, reg_m2, reg_div='kl', regm_div='kl'): """ return the loss function (scipy.optimize compatible) for regularized unbalanced OT @@ -1326,25 +1381,25 @@ def _get_loss_unbalanced(a, b, M, reg, reg_m1, reg_m2, reg_div='kl', regm_div='k m, n = M.shape def kl(p, q): - return np.sum(p * np.log(p / q + 1e-16)) - p.sum() + q.sum() + return np.sum(p * np.log(p / q + 1e-16)) - np.sum(p) + np.sum(q) def reg_l2(G): - return np.sum((G - a[:, None] * b[None, :])**2) / 2 + return np.sum((G - c)**2) / 2 def grad_l2(G): - return G - a[:, None] * b[None, :] + return G - c def reg_kl(G): - return kl(G, a[:, None] * b[None, :]) + return kl(G, c) def grad_kl(G): - return np.log(G / (a[:, None] * b[None, :]) + 1e-16) + return np.log(G / c + 1e-16) def reg_entropy(G): - return np.sum(G * np.log(G + 1e-16)) + return np.sum(G * np.log(G + 1e-16)) - np.sum(G) def grad_entropy(G): - return np.log(G + 1e-16) + 1 + return np.log(G + 1e-16) if reg_div == 'kl': reg_fun = reg_kl @@ -1392,7 +1447,7 @@ def _func(G): return _func -def lbfgsb_unbalanced(a, b, M, reg, reg_m, reg_div='kl', regm_div='kl', G0=None, numItermax=1000, +def lbfgsb_unbalanced(a, b, M, reg, reg_m, c=None, reg_div='kl', regm_div='kl', G0=None, numItermax=1000, stopThr=1e-15, method='L-BFGS-B', 
verbose=False, log=False): r""" Solve the unbalanced optimal transport problem and return the OT plan using L-BFGS-B. @@ -1400,7 +1455,7 @@ def lbfgsb_unbalanced(a, b, M, reg, reg_m, reg_div='kl', regm_div='kl', G0=None, .. math:: W = \min_\gamma \quad \langle \gamma, \mathbf{M} \rangle_F + - + \mathrm{reg} \mathrm{div}(\gamma, \mathbf{a} \mathbf{b}^T) + + \mathrm{reg} \mathrm{div}(\gamma, \mathbf{c}) \mathrm{reg_{m1}} \cdot \mathrm{div_m}(\gamma \mathbf{1}, \mathbf{a}) + \mathrm{reg_{m2}} \cdot \mathrm{div}(\gamma^T \mathbf{1}, \mathbf{b}) @@ -1412,6 +1467,7 @@ def lbfgsb_unbalanced(a, b, M, reg, reg_m, reg_div='kl', regm_div='kl', G0=None, - :math:`\mathbf{M}` is the (`dim_a`, `dim_b`) metric cost matrix - :math:`\mathbf{a}` and :math:`\mathbf{b}` are source and target unbalanced distributions + - :math:`\mathbf{c}` is a reference distribution for the regularization - :math:`\mathrm{div}` is a divergence, either Kullback-Leibler or :math:`\ell_2` divergence The algorithm used for solving the problem is a L-BFGS-B from scipy.optimize @@ -1426,6 +1482,9 @@ def lbfgsb_unbalanced(a, b, M, reg, reg_m, reg_div='kl', regm_div='kl', G0=None, loss matrix reg: float regularization term >=0 + c : array-like (dim_a, dim_b), optional (default = None) + Reference measure for the regularization. + If None, then use `\mathbf{c} = \mathbf{a} \mathbf{b}^T`. reg_m: float or indexable object of length 1 or 2 Marginal relaxation term >= 0, but cannot be infinity. If reg_m is a scalar or an indexable object of length 1, @@ -1433,7 +1492,8 @@ def lbfgsb_unbalanced(a, b, M, reg, reg_m, reg_div='kl', regm_div='kl', G0=None, If reg_m is an array, it must be a Numpy array. reg_div: string, optional Divergence used for regularization. - Can take two values: 'kl' (Kullback-Leibler) or 'l2' (quadratic) + Can take three values: 'entropy' (negative entropy), or + 'kl' (Kullback-Leibler) or 'l2' (quadratic). regm_div: string, optional Divergence to quantify the difference between the marginals. 
Can take two values: 'kl' (Kullback-Leibler) or 'l2' (quadratic) @@ -1480,21 +1540,18 @@ def lbfgsb_unbalanced(a, b, M, reg, reg_m, reg_div='kl', regm_div='kl', G0=None, ot.unbalanced.sinkhorn_unbalanced2 : Entropic regularized OT loss """ - M, a, b = list_to_array(M, a, b) - nx = get_backend(M, a, b) - + if c is None: + c = a[:, None] * b[None, :] + M, a, b, c = list_to_array(M, a, b, c) + nx = get_backend(M, a, b, c) M0 = M + # convert to numpy - a, b, M = nx.to_numpy(a, b, M) + a, b, c, M = nx.to_numpy(a, b, c, M) + G0 = np.zeros(M.shape) if G0 is None else nx.to_numpy(G0) reg_m1, reg_m2 = get_parameter_pair(reg_m) - - if G0 is not None: - G0 = nx.to_numpy(G0) - else: - G0 = np.zeros(M.shape) - - _func = _get_loss_unbalanced(a, b, M, reg, reg_m1, reg_m2, reg_div, regm_div) + _func = _get_loss_unbalanced(a, b, c, M, reg, reg_m1, reg_m2, reg_div, regm_div) res = minimize(_func, G0.ravel(), method=method, jac=True, bounds=Bounds(0, np.inf), tol=stopThr, options=dict(maxiter=numItermax, disp=verbose)) diff --git a/test/test_unbalanced.py b/test/test_unbalanced.py index 272794cb8..a5dceede9 100644 --- a/test/test_unbalanced.py +++ b/test/test_unbalanced.py @@ -14,8 +14,9 @@ from ot.unbalanced import barycenter_unbalanced -@pytest.mark.parametrize("method", ["sinkhorn", "sinkhorn_stabilized"]) -def test_unbalanced_convergence(nx, method): +@pytest.mark.parametrize("method,reg_type", itertools.product(["sinkhorn", "sinkhorn_stabilized"], ["kl", "entropy"])) +# @pytest.mark.parametrize("method", ["sinkhorn", "sinkhorn_stabilized"]) +def test_unbalanced_convergence(nx, method, reg_type): # test generalized sinkhorn for unbalanced OT n = 100 rng = np.random.RandomState(42) @@ -25,29 +26,32 @@ def test_unbalanced_convergence(nx, method): # make dists unbalanced b = ot.utils.unif(n) * 1.5 - M = ot.dist(x, x) + a, b, M = nx.from_numpy(a, b, M) + epsilon = 1. reg_m = 1. 
- a, b, M = nx.from_numpy(a, b, M) - G, log = ot.unbalanced.sinkhorn_unbalanced(a, b, M, reg=epsilon, - reg_m=reg_m, + reg_m=reg_m,reg_type=reg_type, method=method, log=True, verbose=True) loss = nx.to_numpy(ot.unbalanced.sinkhorn_unbalanced2( - a, b, M, epsilon, reg_m, method=method, verbose=True + a, b, M, epsilon, reg_m, reg_type=reg_type, method=method, verbose=True )) # check fixed point equations # in log-domain fi = reg_m / (reg_m + epsilon) logb = nx.log(b + 1e-16) loga = nx.log(a + 1e-16) - logKtu = nx.logsumexp(log["logu"][None, :] - M.T / epsilon, axis=1) - logKv = nx.logsumexp(log["logv"][None, :] - M / epsilon, axis=1) - + if reg_type == "entropy": + logKtu = nx.logsumexp(log["logu"][None, :] - M.T / epsilon, axis=1) + logKv = nx.logsumexp(log["logv"][None, :] - M / epsilon, axis=1) + elif reg_type == "kl": + log_ab = loga[:, None] + logb[None, :] + logKtu = nx.logsumexp(log["logu"][None, :] - M.T / epsilon + log_ab.T, axis=1) + logKv = nx.logsumexp(log["logv"][None, :] - M / epsilon + log_ab, axis=1) v_final = fi * (logb - logKtu) u_final = fi * (loga - logKv) @@ -65,10 +69,12 @@ def test_unbalanced_convergence(nx, method): a, b = nx.from_numpy(a_np, b_np) G = ot.unbalanced.sinkhorn_unbalanced( - a, b, M, reg=epsilon, reg_m=reg_m, method=method, verbose=True + a, b, M, reg=epsilon, reg_m=reg_m, + reg_type=reg_type, method=method, verbose=True ) G_np = ot.unbalanced.sinkhorn_unbalanced( - a_np, b_np, M_np, reg=epsilon, reg_m=reg_m, method=method, verbose=True + a_np, b_np, M_np, reg=epsilon, reg_m=reg_m, + reg_type=reg_type, method=method, verbose=True ) np.testing.assert_allclose(G_np, nx.to_numpy(G)) @@ -394,6 +400,31 @@ def test_lbfgsb_unbalanced_relaxation_parameters(nx, reg_div, regm_div): np.testing.assert_allclose(nx.to_numpy(G), nx.to_numpy(G0), atol=1e-06) +@pytest.mark.parametrize("reg_div,regm_div", itertools.product(['kl', 'l2', 'entropy'], ['kl', 'l2'])) +def test_lbfgsb_reference_measure(nx, reg_div, regm_div): + + np.random.seed(42) + + 
xs = np.random.randn(5, 2) + xt = np.random.randn(6, 2) + M = ot.dist(xs, xt) + a = ot.unif(5) + b = ot.unif(6) + + a, b, M = nx.from_numpy(a, b, M) + c = a[:, None] * b[None, :] + + G, _ = ot.unbalanced.lbfgsb_unbalanced(a, b, M, reg=1, reg_m=10, c=None, + reg_div=reg_div, regm_div=regm_div, + log=True, verbose=False) + + G0, _ = ot.unbalanced.lbfgsb_unbalanced(a, b, M, reg=1, reg_m=10, c=c, + reg_div=reg_div, regm_div=regm_div, + log=True, verbose=False) + + np.testing.assert_allclose(nx.to_numpy(G), nx.to_numpy(G0), atol=1e-06) + + @pytest.mark.parametrize("div", ["kl", "l2"]) def test_mm_convergence(nx, div): n = 100 @@ -483,6 +514,41 @@ def test_mm_relaxation_parameters(nx, div): np.testing.assert_allclose(loss_0, loss_1, atol=1e-5) +@pytest.mark.parametrize("div", ["kl", "l2"]) +def test_mm_reference_measure(nx, div): + n = 100 + rng = np.random.RandomState(42) + x = rng.randn(n, 2) + rng = np.random.RandomState(75) + y = rng.randn(n, 2) + a_np = ot.utils.unif(n) + b_np = ot.utils.unif(n) + + M = ot.dist(x, y) + M = M / M.max() + a, b, M = nx.from_numpy(a_np, b_np, M) + c = a[:, None] * b[None, :] + + reg = 1e-2 + reg_m = 100 + + G0, _ = ot.unbalanced.mm_unbalanced(a, b, M, reg_m=reg_m, c=None, reg=reg, + div=div, verbose=False, log=True) + loss_0 = ot.unbalanced.mm_unbalanced2(a, b, M, reg_m=reg_m, c=None, reg=reg, + div=div, verbose=True) + loss_0 = nx.to_numpy(loss_0) + + G1, _ = ot.unbalanced.mm_unbalanced(a, b, M, reg_m=reg_m, c=c, + reg=reg, div=div, + verbose=False, log=True) + loss_1 = ot.unbalanced.mm_unbalanced2(a, b, M, reg_m=reg_m, c=c, + reg=reg, div=div, verbose=True) + loss_1 = nx.to_numpy(loss_1) + + np.testing.assert_allclose(nx.to_numpy(G0), nx.to_numpy(G1), atol=1e-05) + np.testing.assert_allclose(loss_0, loss_1, atol=1e-5) + + def test_mm_wrong_divergence(nx): n = 100 From ff3b057ed6e2c8de4c36f19ee2441dcc788c4827 Mon Sep 17 00:00:00 2001 From: 6Ulm Date: Tue, 31 Oct 2023 23:10:07 +0100 Subject: [PATCH 02/18] add new features to unbalanced 
solvers --- ot/unbalanced.py | 229 +++++++++++++++++++++++++--------------- test/test_unbalanced.py | 90 +++++++++++++--- 2 files changed, 221 insertions(+), 98 deletions(-) diff --git a/ot/unbalanced.py b/ot/unbalanced.py index 265006d2c..6c30a88bb 100644 --- a/ot/unbalanced.py +++ b/ot/unbalanced.py @@ -19,7 +19,8 @@ from .utils import list_to_array, get_parameter_pair -def sinkhorn_unbalanced(a, b, M, reg, reg_m, method='sinkhorn', numItermax=1000, +def sinkhorn_unbalanced(a, b, M, reg, reg_m, reg_type="entropy", warmstart=None, + method='sinkhorn', numItermax=1000, stopThr=1e-6, verbose=False, log=False, **kwargs): r""" Solve the unbalanced entropic regularization optimal transport problem @@ -39,7 +40,7 @@ def sinkhorn_unbalanced(a, b, M, reg, reg_m, method='sinkhorn', numItermax=1000, where : - :math:`\mathbf{M}` is the (`dim_a`, `dim_b`) metric cost matrix - - :math:`\Omega` is the entropic regularization term, :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})` + - :math:`\Omega` is the entropic regularization term, can be either KL divergence or negative entropy - :math:`\mathbf{a}` and :math:`\mathbf{b}` are source and target unbalanced distributions - KL is the Kullback-Leibler divergence @@ -66,8 +67,17 @@ def sinkhorn_unbalanced(a, b, M, reg, reg_m, method='sinkhorn', numItermax=1000, For semi-relaxed case, use either `reg_m=(float("inf"), scalar)` or `reg_m=(scalar, float("inf"))`. If reg_m is an array, it must have the same backend as input arrays (a, b, M). + reg_type : string, optional + Regularizer term. Can take two values: + 'entropy' (negative entropy) + :math:`\Omega(\gamma) = \sum_{i,j} \gamma_{i,j} \log(\gamma_{i,j}) - \sum_{i,j} \gamma_{i,j}`, or + 'kl' (Kullback-Leibler) + :math:`\Omega(\gamma) = \text{KL}(\gamma, \mathbf{a} \mathbf{b}^T)`. + warmstart: tuple of arrays, shape (dim_a, dim_b), optional + Initialization of dual potentials. 
If provided, the dual potentials should be given + (that is the logarithm of the u,v sinkhorn scaling vectors). method : str - method used for the solver either 'sinkhorn', 'sinkhorn_stabilized' or + method used for the solver either 'sinkhorn', 'sinkhorn_stabilized' or 'sinkhorn_reg_scaling', see those function for specific parameters numItermax : int, optional Max number of iterations @@ -134,30 +144,30 @@ def sinkhorn_unbalanced(a, b, M, reg, reg_m, method='sinkhorn', numItermax=1000, """ if method.lower() == 'sinkhorn': - return sinkhorn_knopp_unbalanced(a, b, M, reg, reg_m, - numItermax=numItermax, + return sinkhorn_knopp_unbalanced(a, b, M, reg, reg_m, reg_type, + warmstart, numItermax=numItermax, stopThr=stopThr, verbose=verbose, log=log, **kwargs) elif method.lower() == 'sinkhorn_stabilized': - return sinkhorn_stabilized_unbalanced(a, b, M, reg, reg_m, - numItermax=numItermax, + return sinkhorn_stabilized_unbalanced(a, b, M, reg, reg_m, reg_type, + warmstart, numItermax=numItermax, stopThr=stopThr, verbose=verbose, log=log, **kwargs) elif method.lower() in ['sinkhorn_reg_scaling']: - warnings.warn('Method not implemented yet. Using classic Sinkhorn Knopp') - return sinkhorn_knopp_unbalanced(a, b, M, reg, reg_m, - numItermax=numItermax, + warnings.warn('Method not implemented yet. Using classic Sinkhorn-Knopp') + return sinkhorn_knopp_unbalanced(a, b, M, reg, reg_m, reg_type, + warmstart, numItermax=numItermax, stopThr=stopThr, verbose=verbose, log=log, **kwargs) else: raise ValueError("Unknown method '%s'." 
% method) -def sinkhorn_unbalanced2(a, b, M, reg, reg_m, method='sinkhorn', - numItermax=1000, stopThr=1e-6, verbose=False, - log=False, **kwargs): +def sinkhorn_unbalanced2(a, b, M, reg, reg_m, reg_type="entropy", warmstart=None, + method='sinkhorn', numItermax=1000, + stopThr=1e-6, verbose=False, log=False, **kwargs): r""" Solve the entropic regularization unbalanced optimal transport problem and return the loss @@ -175,7 +185,7 @@ def sinkhorn_unbalanced2(a, b, M, reg, reg_m, method='sinkhorn', where : - :math:`\mathbf{M}` is the (`dim_a`, `dim_b`) metric cost matrix - - :math:`\Omega` is the entropic regularization term, :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})` + - :math:`\Omega` is the entropic regularization term, can be either KL divergence or negative entropy - :math:`\mathbf{a}` and :math:`\mathbf{b}` are source and target unbalanced distributions - KL is the Kullback-Leibler divergence @@ -202,8 +212,17 @@ def sinkhorn_unbalanced2(a, b, M, reg, reg_m, method='sinkhorn', For semi-relaxed case, use either `reg_m=(float("inf"), scalar)` or `reg_m=(scalar, float("inf"))`. If reg_m is an array, it must have the same backend as input arrays (a, b, M). + reg_type : string, optional + Regularizer term. Can take two values: + 'entropy' (negative entropy) + :math:`\Omega(\gamma) = \sum_{i,j} \gamma_{i,j} \log(\gamma_{i,j}) - \sum_{i,j} \gamma_{i,j}`, or + 'kl' (Kullback-Leibler) + :math:`\Omega(\gamma) = \text{KL}(\gamma, \mathbf{a} \mathbf{b}^T)`. + warmstart: tuple of arrays, shape (dim_a, dim_b), optional + Initialization of dual potentials. If provided, the dual potentials should be given + (that is the logarithm of the u,v sinkhorn scaling vectors). 
method : str - method used for the solver either 'sinkhorn', 'sinkhorn_stabilized' or + method used for the solver either 'sinkhorn', 'sinkhorn_stabilized' or 'sinkhorn_reg_scaling', see those function for specific parameters numItermax : int, optional Max number of iterations @@ -263,29 +282,30 @@ def sinkhorn_unbalanced2(a, b, M, reg, reg_m, method='sinkhorn', b = b[:, None] if method.lower() == 'sinkhorn': - return sinkhorn_knopp_unbalanced(a, b, M, reg, reg_m, - numItermax=numItermax, + return sinkhorn_knopp_unbalanced(a, b, M, reg, reg_m, reg_type, + warmstart, numItermax=numItermax, stopThr=stopThr, verbose=verbose, log=log, **kwargs) elif method.lower() == 'sinkhorn_stabilized': - return sinkhorn_stabilized_unbalanced(a, b, M, reg, reg_m, - numItermax=numItermax, + return sinkhorn_stabilized_unbalanced(a, b, M, reg, reg_m, reg_type, + warmstart, numItermax=numItermax, stopThr=stopThr, verbose=verbose, log=log, **kwargs) elif method.lower() in ['sinkhorn_reg_scaling']: - warnings.warn('Method not implemented yet. Using classic Sinkhorn Knopp') - return sinkhorn_knopp_unbalanced(a, b, M, reg, reg_m, - numItermax=numItermax, + warnings.warn('Method not implemented yet. Using classic Sinkhorn-Knopp') + return sinkhorn_knopp_unbalanced(a, b, M, reg, reg_m, reg_type, + warmstart, numItermax=numItermax, stopThr=stopThr, verbose=verbose, log=log, **kwargs) else: raise ValueError('Unknown method %s.' 
% method) -def sinkhorn_knopp_unbalanced(a, b, M, reg, reg_m, numItermax=1000, - stopThr=1e-6, verbose=False, log=False, **kwargs): +def sinkhorn_knopp_unbalanced(a, b, M, reg, reg_m, reg_type="entropy", + warmstart=None, numItermax=1000, stopThr=1e-6, + verbose=False, log=False, **kwargs): r""" Solve the entropic regularization unbalanced optimal transport problem and return the OT plan @@ -304,7 +324,7 @@ def sinkhorn_knopp_unbalanced(a, b, M, reg, reg_m, numItermax=1000, where : - :math:`\mathbf{M}` is the (`dim_a`, `dim_b`) metric cost matrix - - :math:`\Omega` is the entropic regularization term, :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})` + - :math:`\Omega` is the entropic regularization term, can be either KL divergence or negative entropy - :math:`\mathbf{a}` and :math:`\mathbf{b}` are source and target unbalanced distributions - KL is the Kullback-Leibler divergence @@ -330,6 +350,15 @@ def sinkhorn_knopp_unbalanced(a, b, M, reg, reg_m, numItermax=1000, For semi-relaxed case, use either `reg_m=(float("inf"), scalar)` or `reg_m=(scalar, float("inf"))`. If reg_m is an array, it must have the same backend as input arrays (a, b, M). + reg_type : string, optional + Regularizer term. Can take two values: + 'entropy' (negative entropy) + :math:`\Omega(\gamma) = \sum_{i,j} \gamma_{i,j} \log(\gamma_{i,j}) - \sum_{i,j} \gamma_{i,j}`, or + 'kl' (Kullback-Leibler) + :math:`\Omega(\gamma) = \text{KL}(\gamma, \mathbf{a} \mathbf{b}^T)`. + warmstart: tuple of arrays, shape (dim_a, dim_b), optional + Initialization of dual potentials. If provided, the dual potentials should be given + (that is the logarithm of the u,v sinkhorn scaling vectors). 
numItermax : int, optional Max number of iterations stopThr : float, optional @@ -404,15 +433,21 @@ def sinkhorn_knopp_unbalanced(a, b, M, reg, reg_m, numItermax=1000, # we assume that no distances are null except those of the diagonal of # distances - if n_hists: - u = nx.ones((dim_a, 1), type_as=M) / dim_a - v = nx.ones((dim_b, n_hists), type_as=M) / dim_b - a = a.reshape(dim_a, 1) + if warmstart is None: + if n_hists: + u = nx.ones((dim_a, 1), type_as=M) / dim_a + v = nx.ones((dim_b, n_hists), type_as=M) / dim_b + a = a.reshape(dim_a, 1) + else: + u = nx.ones(dim_a, type_as=M) / dim_a + v = nx.ones(dim_b, type_as=M) / dim_b else: - u = nx.ones(dim_a, type_as=M) / dim_a - v = nx.ones(dim_b, type_as=M) / dim_b + u, v = nx.exp(warmstart[0]), nx.exp(warmstart[1]) - K = nx.exp(M / (-reg)) + if reg_type == "kl": + K = nx.exp(-M / reg) * a[:, None] * b[None, :] + elif reg_type == "entropy": + K = nx.exp(-M / reg) fi_1 = reg_m1 / (reg_m1 + reg) if reg_m1 != float("inf") else 1 fi_2 = reg_m2 / (reg_m2 + reg) if reg_m2 != float("inf") else 1 @@ -474,9 +509,10 @@ def sinkhorn_knopp_unbalanced(a, b, M, reg, reg_m, numItermax=1000, return u[:, None] * K * v[None, :] -def sinkhorn_stabilized_unbalanced(a, b, M, reg, reg_m, tau=1e5, numItermax=1000, - stopThr=1e-6, verbose=False, log=False, - **kwargs): +def sinkhorn_stabilized_unbalanced(a, b, M, reg, reg_m, reg_type="entropy", + warmstart=None, tau=1e5, + numItermax=1000, stopThr=1e-6, + verbose=False, log=False, **kwargs): r""" Solve the entropic regularization unbalanced optimal transport problem and return the loss @@ -496,7 +532,7 @@ def sinkhorn_stabilized_unbalanced(a, b, M, reg, reg_m, tau=1e5, numItermax=1000 where : - :math:`\mathbf{M}` is the (`dim_a`, `dim_b`) metric cost matrix - - :math:`\Omega` is the entropic regularization term, :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})` + - :math:`\Omega` is the entropic regularization term, can be either KL divergence or negative entropy - 
:math:`\mathbf{a}` and :math:`\mathbf{b}` are source and target unbalanced distributions - KL is the Kullback-Leibler divergence @@ -523,6 +559,15 @@ def sinkhorn_stabilized_unbalanced(a, b, M, reg, reg_m, tau=1e5, numItermax=1000 For semi-relaxed case, use either `reg_m=(float("inf"), scalar)` or `reg_m=(scalar, float("inf"))`. If reg_m is an array, it must have the same backend as input arrays (a, b, M). + reg_type : string, optional + Regularizer term. Can take two values: + 'entropy' (negative entropy) + :math:`\Omega(\gamma) = \sum_{i,j} \gamma_{i,j} \log(\gamma_{i,j}) - \sum_{i,j} \gamma_{i,j}`, or + 'kl' (Kullback-Leibler) + :math:`\Omega(\gamma) = \text{KL}(\gamma, \mathbf{a} \mathbf{b}^T)`. + warmstart: tuple of arrays, shape (dim_a, dim_b), optional + Initialization of dual potentials. If provided, the dual potentials should be given + (that is the logarithm of the u,v sinkhorn scaling vectors). tau : float threshold for max value in u or v for log scaling numItermax : int, optional @@ -597,16 +642,21 @@ def sinkhorn_stabilized_unbalanced(a, b, M, reg, reg_m, tau=1e5, numItermax=1000 # we assume that no distances are null except those of the diagonal of # distances - if n_hists: - u = nx.ones((dim_a, n_hists), type_as=M) / dim_a - v = nx.ones((dim_b, n_hists), type_as=M) / dim_b - a = a.reshape(dim_a, 1) + if warmstart is None: + if n_hists: + u = nx.ones((dim_a, n_hists), type_as=M) / dim_a + v = nx.ones((dim_b, n_hists), type_as=M) / dim_b + a = a.reshape(dim_a, 1) + else: + u = nx.ones(dim_a, type_as=M) / dim_a + v = nx.ones(dim_b, type_as=M) / dim_b else: - u = nx.ones(dim_a, type_as=M) / dim_a - v = nx.ones(dim_b, type_as=M) / dim_b + u, v = nx.exp(warmstart[0]), nx.exp(warmstart[1]) - # print(reg) - K = nx.exp(-M / reg) + if reg_type == "kl": + K = nx.exp(-M / reg) * a[:, None] * b[None, :] + elif reg_type == "entropy": + K = nx.exp(-M / reg) fi_1 = reg_m1 / (reg_m1 + reg) if reg_m1 != float("inf") else 1 fi_2 = reg_m2 / (reg_m2 + reg) if reg_m2 != 
float("inf") else 1 @@ -1074,7 +1124,7 @@ def barycenter_unbalanced(A, M, reg, reg_m, method="sinkhorn", weights=None, raise ValueError("Unknown method '%s'." % method) -def mm_unbalanced(a, b, M, reg_m, reg=0, div='kl', G0=None, numItermax=1000, +def mm_unbalanced(a, b, M, reg_m, c=None, reg=0, div='kl', G0=None, numItermax=1000, stopThr=1e-15, verbose=False, log=False): r""" Solve the unbalanced optimal transport problem and return the OT plan. @@ -1084,7 +1134,7 @@ def mm_unbalanced(a, b, M, reg_m, reg=0, div='kl', G0=None, numItermax=1000, W = \min_\gamma \quad \langle \gamma, \mathbf{M} \rangle_F + \mathrm{reg_{m1}} \cdot \mathrm{div}(\gamma \mathbf{1}, \mathbf{a}) + \mathrm{reg_{m2}} \cdot \mathrm{div}(\gamma^T \mathbf{1}, \mathbf{b}) + - \mathrm{reg} \cdot \mathrm{div}(\gamma, \mathbf{a} \mathbf{b}^T) + \mathrm{reg} \cdot \mathrm{div}(\gamma, \mathbf{c}) s.t. \gamma \geq 0 @@ -1094,6 +1144,7 @@ def mm_unbalanced(a, b, M, reg_m, reg=0, div='kl', G0=None, numItermax=1000, - :math:`\mathbf{M}` is the (`dim_a`, `dim_b`) metric cost matrix - :math:`\mathbf{a}` and :math:`\mathbf{b}` are source and target unbalanced distributions + - :math:`\mathbf{c}` is a reference distribution for the regularization - div is a divergence, either Kullback-Leibler or :math:`\ell_2` divergence The algorithm used for solving the problem is a maximization- @@ -1113,8 +1164,11 @@ def mm_unbalanced(a, b, M, reg_m, reg=0, div='kl', G0=None, numItermax=1000, then the same reg_m is applied to both marginal relaxations. If reg_m is an array, it must have the same backend as input arrays (a, b, M). reg : float, optional (default = 0) - Entropy regularization term >= 0. + Regularization term >= 0. By default, solve the unregularized problem + c : array-like (dim_a, dim_b), optional (default = None) + Reference measure for the regularization. + If None, then use `\mathbf{c} = \mathbf{a} \mathbf{b}^T`. div: string, optional Divergence to quantify the difference between the marginals. 
Can take two values: 'kl' (Kullback-Leibler) or 'l2' (quadratic) @@ -1172,36 +1226,33 @@ def mm_unbalanced(a, b, M, reg_m, reg=0, div='kl', G0=None, numItermax=1000, if len(b) == 0: b = nx.ones(dim_b, type_as=M) / dim_b - if G0 is None: - G = a[:, None] * b[None, :] - else: - G = G0 + G = a[:, None] * b[None, :] if G0 is None else G0 + c = a[:, None] * b[None, :] if c is None else c reg_m1, reg_m2 = get_parameter_pair(reg_m) if log: log = {'err': [], 'G': []} - if div == 'kl': - sum_r = reg + reg_m1 + reg_m2 - r1, r2, r = reg_m1 / sum_r, reg_m2 / sum_r, reg / sum_r - K = a[:, None]**(r1 + r) * b[None, :]**(r2 + r) * nx.exp(- M / sum_r) - elif div == 'l2': - K = reg_m1 * a[:, None] + reg_m2 * b[None, :] + reg * a[:, None] * b[None, :] - M - K = nx.maximum(K, nx.zeros((dim_a, dim_b), type_as=M)) - else: + if div not in ["kl", "l2"]: warnings.warn("The div parameter should be either equal to 'kl' or \ 'l2': it has been set to 'kl'.") div = 'kl' + + if div == 'kl': sum_r = reg + reg_m1 + reg_m2 r1, r2, r = reg_m1 / sum_r, reg_m2 / sum_r, reg / sum_r - K = a[:, None]**(r1 + r) * b[None, :]**(r2 + r) * nx.exp(- M / sum_r) + K = (a[:, None]**r1) * (b[None, :]**r2) * (c**r) * nx.exp(- M / sum_r) + elif div == 'l2': + K = reg_m1 * a[:, None] + reg_m2 * b[None, :] + reg * c - M + K = nx.maximum(K, nx.zeros((dim_a, dim_b), type_as=M)) for i in range(numItermax): Gprev = G if div == 'kl': - G = K * G**(r1 + r2) / (nx.sum(G, 1, keepdims=True)**r1 * nx.sum(G, 0, keepdims=True)**r2 + 1e-16) + Gd = (nx.sum(G, 1, keepdims=True)**r1) * (nx.sum(G, 0, keepdims=True)**r2) + 1e-16 + G = K * G**(r1 + r2) / Gd elif div == 'l2': Gd = reg_m1 * nx.sum(G, 1, keepdims=True) + \ reg_m2 * nx.sum(G, 0, keepdims=True) + reg * G + 1e-16 @@ -1223,7 +1274,7 @@ def mm_unbalanced(a, b, M, reg_m, reg=0, div='kl', G0=None, numItermax=1000, return G -def mm_unbalanced2(a, b, M, reg_m, reg=0, div='kl', G0=None, numItermax=1000, +def mm_unbalanced2(a, b, M, reg_m, c=None, reg=0, div='kl', G0=None, 
numItermax=1000, stopThr=1e-15, verbose=False, log=False): r""" Solve the unbalanced optimal transport problem and return the OT plan. @@ -1233,7 +1284,7 @@ def mm_unbalanced2(a, b, M, reg_m, reg=0, div='kl', G0=None, numItermax=1000, W = \min_\gamma \quad \langle \gamma, \mathbf{M} \rangle_F + \mathrm{reg_{m1}} \cdot \mathrm{div}(\gamma \mathbf{1}, \mathbf{a}) + \mathrm{reg_{m2}} \cdot \mathrm{div}(\gamma^T \mathbf{1}, \mathbf{b}) + - \mathrm{reg} \cdot \mathrm{div}(\gamma, \mathbf{a} \mathbf{b}^T) + \mathrm{reg} \cdot \mathrm{div}(\gamma, \mathbf{c}) s.t. \gamma \geq 0 @@ -1243,6 +1294,7 @@ def mm_unbalanced2(a, b, M, reg_m, reg=0, div='kl', G0=None, numItermax=1000, - :math:`\mathbf{M}` is the (`dim_a`, `dim_b`) metric cost matrix - :math:`\mathbf{a}` and :math:`\mathbf{b}` are source and target unbalanced distributions + - :math:`\mathbf{c}` is a reference distribution for the regularization - :math:`\mathrm{div}` is a divergence, either Kullback-Leibler or :math:`\ell_2` divergence The algorithm used for solving the problem is a maximization- @@ -1264,6 +1316,9 @@ def mm_unbalanced2(a, b, M, reg_m, reg=0, div='kl', G0=None, numItermax=1000, reg : float, optional (default = 0) Entropy regularization term >= 0. By default, solve the unregularized problem + c : array-like (dim_a, dim_b), optional (default = None) + Reference measure for the regularization. + If None, then use `\mathbf{c} = mathbf{a} mathbf{b}^T`. div: string, optional Divergence to quantify the difference between the marginals. 
Can take two values: 'kl' (Kullback-Leibler) or 'l2' (quadratic) @@ -1307,7 +1362,7 @@ def mm_unbalanced2(a, b, M, reg_m, reg=0, div='kl', G0=None, numItermax=1000, ot.lp.emd2 : Unregularized OT loss ot.unbalanced.sinkhorn_unbalanced2 : Entropic regularized OT loss """ - _, log_mm = mm_unbalanced(a, b, M, reg_m, reg=reg, div=div, G0=G0, + _, log_mm = mm_unbalanced(a, b, M, reg_m, c=c, reg=reg, div=div, G0=G0, numItermax=numItermax, stopThr=stopThr, verbose=verbose, log=True) @@ -1317,7 +1372,7 @@ def mm_unbalanced2(a, b, M, reg_m, reg=0, div='kl', G0=None, numItermax=1000, return log_mm['cost'] -def _get_loss_unbalanced(a, b, M, reg, reg_m1, reg_m2, reg_div='kl', regm_div='kl'): +def _get_loss_unbalanced(a, b, c, M, reg, reg_m1, reg_m2, reg_div='kl', regm_div='kl'): """ return the loss function (scipy.optimize compatible) for regularized unbalanced OT @@ -1326,25 +1381,25 @@ def _get_loss_unbalanced(a, b, M, reg, reg_m1, reg_m2, reg_div='kl', regm_div='k m, n = M.shape def kl(p, q): - return np.sum(p * np.log(p / q + 1e-16)) - p.sum() + q.sum() + return np.sum(p * np.log(p / q + 1e-16)) - np.sum(p) + np.sum(q) def reg_l2(G): - return np.sum((G - a[:, None] * b[None, :])**2) / 2 + return np.sum((G - c)**2) / 2 def grad_l2(G): - return G - a[:, None] * b[None, :] + return G - c def reg_kl(G): - return kl(G, a[:, None] * b[None, :]) + return kl(G, c) def grad_kl(G): - return np.log(G / (a[:, None] * b[None, :]) + 1e-16) + return np.log(G / c + 1e-16) def reg_entropy(G): - return np.sum(G * np.log(G + 1e-16)) + return np.sum(G * np.log(G + 1e-16)) - np.sum(G) def grad_entropy(G): - return np.log(G + 1e-16) + 1 + return np.log(G + 1e-16) if reg_div == 'kl': reg_fun = reg_kl @@ -1392,7 +1447,7 @@ def _func(G): return _func -def lbfgsb_unbalanced(a, b, M, reg, reg_m, reg_div='kl', regm_div='kl', G0=None, numItermax=1000, +def lbfgsb_unbalanced(a, b, M, reg, reg_m, c=None, reg_div='kl', regm_div='kl', G0=None, numItermax=1000, stopThr=1e-15, method='L-BFGS-B', 
verbose=False, log=False): r""" Solve the unbalanced optimal transport problem and return the OT plan using L-BFGS-B. @@ -1400,7 +1455,7 @@ def lbfgsb_unbalanced(a, b, M, reg, reg_m, reg_div='kl', regm_div='kl', G0=None, .. math:: W = \min_\gamma \quad \langle \gamma, \mathbf{M} \rangle_F + - + \mathrm{reg} \mathrm{div}(\gamma, \mathbf{a} \mathbf{b}^T) + + \mathrm{reg} \mathrm{div}(\gamma, \mathbf{c}) \mathrm{reg_{m1}} \cdot \mathrm{div_m}(\gamma \mathbf{1}, \mathbf{a}) + \mathrm{reg_{m2}} \cdot \mathrm{div}(\gamma^T \mathbf{1}, \mathbf{b}) @@ -1412,6 +1467,7 @@ def lbfgsb_unbalanced(a, b, M, reg, reg_m, reg_div='kl', regm_div='kl', G0=None, - :math:`\mathbf{M}` is the (`dim_a`, `dim_b`) metric cost matrix - :math:`\mathbf{a}` and :math:`\mathbf{b}` are source and target unbalanced distributions + - :math:`\mathbf{c}` is a reference distribution for the regularization - :math:`\mathrm{div}` is a divergence, either Kullback-Leibler or :math:`\ell_2` divergence The algorithm used for solving the problem is a L-BFGS-B from scipy.optimize @@ -1426,6 +1482,9 @@ def lbfgsb_unbalanced(a, b, M, reg, reg_m, reg_div='kl', regm_div='kl', G0=None, loss matrix reg: float regularization term >=0 + c : array-like (dim_a, dim_b), optional (default = None) + Reference measure for the regularization. + If None, then use `\mathbf{c} = \mathbf{a} \mathbf{b}^T`. reg_m: float or indexable object of length 1 or 2 Marginal relaxation term >= 0, but cannot be infinity. If reg_m is a scalar or an indexable object of length 1, @@ -1433,7 +1492,8 @@ def lbfgsb_unbalanced(a, b, M, reg, reg_m, reg_div='kl', regm_div='kl', G0=None, If reg_m is an array, it must be a Numpy array. reg_div: string, optional Divergence used for regularization. - Can take two values: 'kl' (Kullback-Leibler) or 'l2' (quadratic) + Can take three values: 'entropy' (negative entropy), or + 'kl' (Kullback-Leibler) or 'l2' (quadratic). regm_div: string, optional Divergence to quantify the difference between the marginals. 
Can take two values: 'kl' (Kullback-Leibler) or 'l2' (quadratic) @@ -1480,21 +1540,18 @@ def lbfgsb_unbalanced(a, b, M, reg, reg_m, reg_div='kl', regm_div='kl', G0=None, ot.unbalanced.sinkhorn_unbalanced2 : Entropic regularized OT loss """ - M, a, b = list_to_array(M, a, b) - nx = get_backend(M, a, b) - + if c is None: + c = a[:, None] * b[None, :] + M, a, b, c = list_to_array(M, a, b, c) + nx = get_backend(M, a, b, c) M0 = M + # convert to numpy - a, b, M = nx.to_numpy(a, b, M) + a, b, c, M = nx.to_numpy(a, b, c, M) + G0 = np.zeros(M.shape) if G0 is None else nx.to_numpy(G0) reg_m1, reg_m2 = get_parameter_pair(reg_m) - - if G0 is not None: - G0 = nx.to_numpy(G0) - else: - G0 = np.zeros(M.shape) - - _func = _get_loss_unbalanced(a, b, M, reg, reg_m1, reg_m2, reg_div, regm_div) + _func = _get_loss_unbalanced(a, b, c, M, reg, reg_m1, reg_m2, reg_div, regm_div) res = minimize(_func, G0.ravel(), method=method, jac=True, bounds=Bounds(0, np.inf), tol=stopThr, options=dict(maxiter=numItermax, disp=verbose)) diff --git a/test/test_unbalanced.py b/test/test_unbalanced.py index 272794cb8..a5dceede9 100644 --- a/test/test_unbalanced.py +++ b/test/test_unbalanced.py @@ -14,8 +14,9 @@ from ot.unbalanced import barycenter_unbalanced -@pytest.mark.parametrize("method", ["sinkhorn", "sinkhorn_stabilized"]) -def test_unbalanced_convergence(nx, method): +@pytest.mark.parametrize("method,reg_type", itertools.product(["sinkhorn", "sinkhorn_stabilized"], ["kl", "entropy"])) +# @pytest.mark.parametrize("method", ["sinkhorn", "sinkhorn_stabilized"]) +def test_unbalanced_convergence(nx, method, reg_type): # test generalized sinkhorn for unbalanced OT n = 100 rng = np.random.RandomState(42) @@ -25,29 +26,32 @@ def test_unbalanced_convergence(nx, method): # make dists unbalanced b = ot.utils.unif(n) * 1.5 - M = ot.dist(x, x) + a, b, M = nx.from_numpy(a, b, M) + epsilon = 1. reg_m = 1. 
- a, b, M = nx.from_numpy(a, b, M) - G, log = ot.unbalanced.sinkhorn_unbalanced(a, b, M, reg=epsilon, - reg_m=reg_m, + reg_m=reg_m,reg_type=reg_type, method=method, log=True, verbose=True) loss = nx.to_numpy(ot.unbalanced.sinkhorn_unbalanced2( - a, b, M, epsilon, reg_m, method=method, verbose=True + a, b, M, epsilon, reg_m, reg_type=reg_type, method=method, verbose=True )) # check fixed point equations # in log-domain fi = reg_m / (reg_m + epsilon) logb = nx.log(b + 1e-16) loga = nx.log(a + 1e-16) - logKtu = nx.logsumexp(log["logu"][None, :] - M.T / epsilon, axis=1) - logKv = nx.logsumexp(log["logv"][None, :] - M / epsilon, axis=1) - + if reg_type == "entropy": + logKtu = nx.logsumexp(log["logu"][None, :] - M.T / epsilon, axis=1) + logKv = nx.logsumexp(log["logv"][None, :] - M / epsilon, axis=1) + elif reg_type == "kl": + log_ab = loga[:, None] + logb[None, :] + logKtu = nx.logsumexp(log["logu"][None, :] - M.T / epsilon + log_ab.T, axis=1) + logKv = nx.logsumexp(log["logv"][None, :] - M / epsilon + log_ab, axis=1) v_final = fi * (logb - logKtu) u_final = fi * (loga - logKv) @@ -65,10 +69,12 @@ def test_unbalanced_convergence(nx, method): a, b = nx.from_numpy(a_np, b_np) G = ot.unbalanced.sinkhorn_unbalanced( - a, b, M, reg=epsilon, reg_m=reg_m, method=method, verbose=True + a, b, M, reg=epsilon, reg_m=reg_m, + reg_type=reg_type, method=method, verbose=True ) G_np = ot.unbalanced.sinkhorn_unbalanced( - a_np, b_np, M_np, reg=epsilon, reg_m=reg_m, method=method, verbose=True + a_np, b_np, M_np, reg=epsilon, reg_m=reg_m, + reg_type=reg_type, method=method, verbose=True ) np.testing.assert_allclose(G_np, nx.to_numpy(G)) @@ -394,6 +400,31 @@ def test_lbfgsb_unbalanced_relaxation_parameters(nx, reg_div, regm_div): np.testing.assert_allclose(nx.to_numpy(G), nx.to_numpy(G0), atol=1e-06) +@pytest.mark.parametrize("reg_div,regm_div", itertools.product(['kl', 'l2', 'entropy'], ['kl', 'l2'])) +def test_lbfgsb_reference_measure(nx, reg_div, regm_div): + + np.random.seed(42) + + 
xs = np.random.randn(5, 2) + xt = np.random.randn(6, 2) + M = ot.dist(xs, xt) + a = ot.unif(5) + b = ot.unif(6) + + a, b, M = nx.from_numpy(a, b, M) + c = a[:, None] * b[None, :] + + G, _ = ot.unbalanced.lbfgsb_unbalanced(a, b, M, reg=1, reg_m=10, c=None, + reg_div=reg_div, regm_div=regm_div, + log=True, verbose=False) + + G0, _ = ot.unbalanced.lbfgsb_unbalanced(a, b, M, reg=1, reg_m=10, c=c, + reg_div=reg_div, regm_div=regm_div, + log=True, verbose=False) + + np.testing.assert_allclose(nx.to_numpy(G), nx.to_numpy(G0), atol=1e-06) + + @pytest.mark.parametrize("div", ["kl", "l2"]) def test_mm_convergence(nx, div): n = 100 @@ -483,6 +514,41 @@ def test_mm_relaxation_parameters(nx, div): np.testing.assert_allclose(loss_0, loss_1, atol=1e-5) +@pytest.mark.parametrize("div", ["kl", "l2"]) +def test_mm_reference_measure(nx, div): + n = 100 + rng = np.random.RandomState(42) + x = rng.randn(n, 2) + rng = np.random.RandomState(75) + y = rng.randn(n, 2) + a_np = ot.utils.unif(n) + b_np = ot.utils.unif(n) + + M = ot.dist(x, y) + M = M / M.max() + a, b, M = nx.from_numpy(a_np, b_np, M) + c = a[:, None] * b[None, :] + + reg = 1e-2 + reg_m = 100 + + G0, _ = ot.unbalanced.mm_unbalanced(a, b, M, reg_m=reg_m, c=None, reg=reg, + div=div, verbose=False, log=True) + loss_0 = ot.unbalanced.mm_unbalanced2(a, b, M, reg_m=reg_m, c=None, reg=reg, + div=div, verbose=True) + loss_0 = nx.to_numpy(loss_0) + + G1, _ = ot.unbalanced.mm_unbalanced(a, b, M, reg_m=reg_m, c=c, + reg=reg, div=div, + verbose=False, log=True) + loss_1 = ot.unbalanced.mm_unbalanced2(a, b, M, reg_m=reg_m, c=c, + reg=reg, div=div, verbose=True) + loss_1 = nx.to_numpy(loss_1) + + np.testing.assert_allclose(nx.to_numpy(G0), nx.to_numpy(G1), atol=1e-05) + np.testing.assert_allclose(loss_0, loss_1, atol=1e-5) + + def test_mm_wrong_divergence(nx): n = 100 From 09120267f6d81645fb4d7f26aaa00cd07fb9e92f Mon Sep 17 00:00:00 2001 From: 6Ulm Date: Tue, 31 Oct 2023 23:32:23 +0100 Subject: [PATCH 03/18] fix bug in test --- 
test/test_unbalanced.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/test/test_unbalanced.py b/test/test_unbalanced.py index a5dceede9..ce72c971b 100644 --- a/test/test_unbalanced.py +++ b/test/test_unbalanced.py @@ -33,12 +33,13 @@ def test_unbalanced_convergence(nx, method, reg_type): reg_m = 1. G, log = ot.unbalanced.sinkhorn_unbalanced(a, b, M, reg=epsilon, - reg_m=reg_m,reg_type=reg_type, + reg_m=reg_m, reg_type=reg_type, method=method, log=True, verbose=True) loss = nx.to_numpy(ot.unbalanced.sinkhorn_unbalanced2( - a, b, M, epsilon, reg_m, reg_type=reg_type, method=method, verbose=True + a, b, M, reg=epsilon, reg_m=reg_m, reg_type=reg_type, + method=method, verbose=True )) # check fixed point equations # in log-domain From 4fa2195be2ec13b6b9db9c779835ff2096ab41ed Mon Sep 17 00:00:00 2001 From: 6Ulm Date: Tue, 31 Oct 2023 23:44:44 +0100 Subject: [PATCH 04/18] remove stab_sinkhorn --- test/test_unbalanced.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/test_unbalanced.py b/test/test_unbalanced.py index ce72c971b..b0f38c95a 100644 --- a/test/test_unbalanced.py +++ b/test/test_unbalanced.py @@ -14,7 +14,7 @@ from ot.unbalanced import barycenter_unbalanced -@pytest.mark.parametrize("method,reg_type", itertools.product(["sinkhorn", "sinkhorn_stabilized"], ["kl", "entropy"])) +@pytest.mark.parametrize("method,reg_type", itertools.product(["sinkhorn"], ["kl", "entropy"])) # @pytest.mark.parametrize("method", ["sinkhorn", "sinkhorn_stabilized"]) def test_unbalanced_convergence(nx, method, reg_type): # test generalized sinkhorn for unbalanced OT From a288be060d45c165ae04cb7badb5f9dc0fa8a13d Mon Sep 17 00:00:00 2001 From: 6Ulm Date: Tue, 31 Oct 2023 23:50:01 +0100 Subject: [PATCH 05/18] remove kl --- test/test_unbalanced.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/test_unbalanced.py b/test/test_unbalanced.py index b0f38c95a..666004435 100644 --- a/test/test_unbalanced.py +++ 
b/test/test_unbalanced.py @@ -14,7 +14,7 @@ from ot.unbalanced import barycenter_unbalanced -@pytest.mark.parametrize("method,reg_type", itertools.product(["sinkhorn"], ["kl", "entropy"])) +@pytest.mark.parametrize("method,reg_type", itertools.product(["sinkhorn", "sinkhorn_stabilized"], ["entropy"])) # @pytest.mark.parametrize("method", ["sinkhorn", "sinkhorn_stabilized"]) def test_unbalanced_convergence(nx, method, reg_type): # test generalized sinkhorn for unbalanced OT From e5c3f84ee8515cb40ac9e4b0aa057463dc2a7cf7 Mon Sep 17 00:00:00 2001 From: 6Ulm Date: Wed, 1 Nov 2023 00:11:06 +0100 Subject: [PATCH 06/18] fix bug in lbfgsb_unbalanced --- ot/unbalanced.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/ot/unbalanced.py b/ot/unbalanced.py index 6c30a88bb..5b05a2592 100644 --- a/ot/unbalanced.py +++ b/ot/unbalanced.py @@ -1540,8 +1540,6 @@ def lbfgsb_unbalanced(a, b, M, reg, reg_m, c=None, reg_div='kl', regm_div='kl', ot.unbalanced.sinkhorn_unbalanced2 : Entropic regularized OT loss """ - if c is None: - c = a[:, None] * b[None, :] M, a, b, c = list_to_array(M, a, b, c) nx = get_backend(M, a, b, c) M0 = M @@ -1549,6 +1547,7 @@ def lbfgsb_unbalanced(a, b, M, reg, reg_m, c=None, reg_div='kl', regm_div='kl', # convert to numpy a, b, c, M = nx.to_numpy(a, b, c, M) G0 = np.zeros(M.shape) if G0 is None else nx.to_numpy(G0) + c = a[:, None] * b[None, :] if c is None else nx.to_numpy(c) reg_m1, reg_m2 = get_parameter_pair(reg_m) _func = _get_loss_unbalanced(a, b, c, M, reg, reg_m1, reg_m2, reg_div, regm_div) From 3787cdc5363c903ffe07ec69de384df6abe3edd6 Mon Sep 17 00:00:00 2001 From: 6Ulm Date: Wed, 1 Nov 2023 00:30:58 +0100 Subject: [PATCH 07/18] fix bug in lbfgsb_unbalanced --- ot/unbalanced.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/ot/unbalanced.py b/ot/unbalanced.py index 5b05a2592..14beba931 100644 --- a/ot/unbalanced.py +++ b/ot/unbalanced.py @@ -445,7 +445,7 @@ def sinkhorn_knopp_unbalanced(a, b, M, reg, 
reg_m, reg_type="entropy", u, v = nx.exp(warmstart[0]), nx.exp(warmstart[1]) if reg_type == "kl": - K = nx.exp(-M / reg) * a[:, None] * b[None, :] + K = nx.exp(-M / reg) * a.squeeze()[:, None] * b.squeeze()[None, :] elif reg_type == "entropy": K = nx.exp(-M / reg) @@ -654,7 +654,7 @@ def sinkhorn_stabilized_unbalanced(a, b, M, reg, reg_m, reg_type="entropy", u, v = nx.exp(warmstart[0]), nx.exp(warmstart[1]) if reg_type == "kl": - K = nx.exp(-M / reg) * a[:, None] * b[None, :] + K = nx.exp(-M / reg) * a.squeeze()[:, None] * b.squeeze()[None, :] elif reg_type == "entropy": K = nx.exp(-M / reg) @@ -1540,12 +1540,12 @@ def lbfgsb_unbalanced(a, b, M, reg, reg_m, c=None, reg_div='kl', regm_div='kl', ot.unbalanced.sinkhorn_unbalanced2 : Entropic regularized OT loss """ - M, a, b, c = list_to_array(M, a, b, c) - nx = get_backend(M, a, b, c) + M, a, b = list_to_array(M, a, b) + nx = get_backend(M, a, b) M0 = M # convert to numpy - a, b, c, M = nx.to_numpy(a, b, c, M) + a, b, M = nx.to_numpy(a, b, M) G0 = np.zeros(M.shape) if G0 is None else nx.to_numpy(G0) c = a[:, None] * b[None, :] if c is None else nx.to_numpy(c) From 810dfe18d26adc68faa3b640693cda3e0e6da52a Mon Sep 17 00:00:00 2001 From: 6Ulm Date: Wed, 1 Nov 2023 01:31:55 +0100 Subject: [PATCH 08/18] fix bug in KL in sinkhorn_unbalanced --- ot/unbalanced.py | 15 +++++++++------ test/test_unbalanced.py | 2 +- 2 files changed, 10 insertions(+), 7 deletions(-) diff --git a/ot/unbalanced.py b/ot/unbalanced.py index 14beba931..a49a093ec 100644 --- a/ot/unbalanced.py +++ b/ot/unbalanced.py @@ -654,9 +654,12 @@ def sinkhorn_stabilized_unbalanced(a, b, M, reg, reg_m, reg_type="entropy", u, v = nx.exp(warmstart[0]), nx.exp(warmstart[1]) if reg_type == "kl": - K = nx.exp(-M / reg) * a.squeeze()[:, None] * b.squeeze()[None, :] - elif reg_type == "entropy": - K = nx.exp(-M / reg) + log_ab = nx.log(a + 1e-16).squeeze()[:, None] + nx.log(b + 1e-16).squeeze()[None, :] + M0 = M - reg * log_ab + else: + M0 = M + + K = nx.exp(-M0 / reg) 
fi_1 = reg_m1 / (reg_m1 + reg) if reg_m1 != float("inf") else 1 fi_2 = reg_m2 / (reg_m2 + reg) if reg_m2 != float("inf") else 1 @@ -691,7 +694,7 @@ def sinkhorn_stabilized_unbalanced(a, b, M, reg, reg_m, reg_type="entropy", else: alpha = alpha + reg * nx.log(nx.max(u)) beta = beta + reg * nx.log(nx.max(v)) - K = nx.exp((alpha[:, None] + beta[None, :] - M) / reg) + K = nx.exp((alpha[:, None] + beta[None, :] - M0) / reg) v = nx.ones(v.shape, type_as=v) Kv = nx.dot(K, v) @@ -737,7 +740,7 @@ def sinkhorn_stabilized_unbalanced(a, b, M, reg, reg_m, reg_type="entropy", nx.log(M + 1e-100)[:, :, None] + logu[:, None, :] + logv[None, :, :] - - M[:, :, None] / reg, + - M0[:, :, None] / reg, axis=(0, 1) ) res = nx.exp(res) @@ -747,7 +750,7 @@ def sinkhorn_stabilized_unbalanced(a, b, M, reg, reg_m, reg_type="entropy", return res else: # return OT matrix - ot_matrix = nx.exp(logu[:, None] + logv[None, :] - M / reg) + ot_matrix = nx.exp(logu[:, None] + logv[None, :] - M0 / reg) if log: return ot_matrix, log else: diff --git a/test/test_unbalanced.py b/test/test_unbalanced.py index 666004435..ce72c971b 100644 --- a/test/test_unbalanced.py +++ b/test/test_unbalanced.py @@ -14,7 +14,7 @@ from ot.unbalanced import barycenter_unbalanced -@pytest.mark.parametrize("method,reg_type", itertools.product(["sinkhorn", "sinkhorn_stabilized"], ["entropy"])) +@pytest.mark.parametrize("method,reg_type", itertools.product(["sinkhorn", "sinkhorn_stabilized"], ["kl", "entropy"])) # @pytest.mark.parametrize("method", ["sinkhorn", "sinkhorn_stabilized"]) def test_unbalanced_convergence(nx, method, reg_type): # test generalized sinkhorn for unbalanced OT From 48b6729034bdea5a76b9b7d760f5dbfef0611b4b Mon Sep 17 00:00:00 2001 From: 6Ulm Date: Wed, 1 Nov 2023 09:20:24 +0100 Subject: [PATCH 09/18] edit release.md --- RELEASES.md | 1 + 1 file changed, 1 insertion(+) diff --git a/RELEASES.md b/RELEASES.md index f943886d5..b79f2e165 100644 --- a/RELEASES.md +++ b/RELEASES.md @@ -13,6 +13,7 @@ + Update wheels 
to Python 3.12 and remove old i686 arch that do not have scipy wheels (PR #543) + Upgraded unbalanced OT solvers for more flexibility (PR #539) + Add LazyTensor for modeling plans and low rank tensor in large scale OT (PR #544) ++ Further upgraded unbalanced OT solvers for more flexibility and future use (PR #551) #### Closed issues - Fix line search evaluating cost outside of the interpolation range (Issue #502, PR #504) From 2ad05abe78814c0dce0daf7d6ba0b14f0884de73 Mon Sep 17 00:00:00 2001 From: 6Ulm Date: Wed, 1 Nov 2023 09:24:33 +0100 Subject: [PATCH 10/18] fix test --- test/test_unbalanced.py | 17 ++++------------- 1 file changed, 4 insertions(+), 13 deletions(-) diff --git a/test/test_unbalanced.py b/test/test_unbalanced.py index c6c7322f7..be13674a8 100644 --- a/test/test_unbalanced.py +++ b/test/test_unbalanced.py @@ -32,22 +32,13 @@ def test_unbalanced_convergence(nx, method, reg_type): epsilon = 1. reg_m = 1. - G, log = ot.unbalanced.sinkhorn_unbalanced(a, b, M, reg=epsilon, -<<<<<<< HEAD - reg_m=reg_m, reg_type=reg_type, -======= - reg_m=reg_m,reg_type=reg_type, ->>>>>>> e903625db204822f642d8fd4077001568c53eeaf - method=method, - log=True, - verbose=True) + G, log = ot.unbalanced.sinkhorn_unbalanced( + a, b, M, reg=epsilon, reg_m=reg_m, reg_type=reg_type, + method=method, log=True, verbose=True + ) loss = nx.to_numpy(ot.unbalanced.sinkhorn_unbalanced2( -<<<<<<< HEAD a, b, M, reg=epsilon, reg_m=reg_m, reg_type=reg_type, method=method, verbose=True -======= - a, b, M, epsilon, reg_m, reg_type=reg_type, method=method, verbose=True ->>>>>>> e903625db204822f642d8fd4077001568c53eeaf )) # check fixed point equations # in log-domain From 2ccb2aa9f470bdbeb23d14f407bff8e1c5739a16 Mon Sep 17 00:00:00 2001 From: 6Ulm Date: Thu, 2 Nov 2023 09:39:44 +0100 Subject: [PATCH 11/18] add test and rearrange arguments --- ot/unbalanced.py | 38 +++++++++++++-------------- test/test_unbalanced.py | 58 ++++++++++++++++++++++++++++++++++++----- 2 files changed, 71 insertions(+), 
25 deletions(-) diff --git a/ot/unbalanced.py b/ot/unbalanced.py index a58fcebd4..85a58c134 100644 --- a/ot/unbalanced.py +++ b/ot/unbalanced.py @@ -19,8 +19,8 @@ from .utils import list_to_array, get_parameter_pair -def sinkhorn_unbalanced(a, b, M, reg, reg_m, reg_type="entropy", warmstart=None, - method='sinkhorn', numItermax=1000, +def sinkhorn_unbalanced(a, b, M, reg, reg_m, method='sinkhorn', + reg_type="entropy", warmstart=None, numItermax=1000, stopThr=1e-6, verbose=False, log=False, **kwargs): r""" Solve the unbalanced entropic regularization optimal transport problem @@ -67,6 +67,9 @@ def sinkhorn_unbalanced(a, b, M, reg, reg_m, reg_type="entropy", warmstart=None, For semi-relaxed case, use either `reg_m=(float("inf"), scalar)` or `reg_m=(scalar, float("inf"))`. If reg_m is an array, it must have the same backend as input arrays (a, b, M). + method : str + method used for the solver either 'sinkhorn', 'sinkhorn_stabilized' or + 'sinkhorn_reg_scaling', see those function for specific parameters reg_type : string, optional Regularizer term. Can take two values: 'entropy' (negative entropy) @@ -75,10 +78,7 @@ def sinkhorn_unbalanced(a, b, M, reg, reg_m, reg_type="entropy", warmstart=None, :math:`\Omega(\gamma) = \text{KL}(\gamma, \mathbf{a} \mathbf{b}^T)`. warmstart: tuple of arrays, shape (dim_a, dim_b), optional Initialization of dual potentials. If provided, the dual potentials should be given - (that is the logarithm of the u,v sinkhorn scaling vectors).s - method : str - method used for the solver either 'sinkhorn', 'sinkhorn_stabilized' or - 'sinkhorn_reg_scaling', see those function for specific parameters + (that is the logarithm of the u,v sinkhorn scaling vectors). numItermax : int, optional Max number of iterations stopThr : float, optional @@ -165,8 +165,8 @@ def sinkhorn_unbalanced(a, b, M, reg, reg_m, reg_type="entropy", warmstart=None, raise ValueError("Unknown method '%s'." 
% method) -def sinkhorn_unbalanced2(a, b, M, reg, reg_m, reg_type="entropy", warmstart=None, - method='sinkhorn', numItermax=1000, +def sinkhorn_unbalanced2(a, b, M, reg, reg_m, method='sinkhorn', + reg_type="entropy", warmstart=None, numItermax=1000, stopThr=1e-6, verbose=False, log=False, **kwargs): r""" Solve the entropic regularization unbalanced optimal transport problem and @@ -212,6 +212,9 @@ def sinkhorn_unbalanced2(a, b, M, reg, reg_m, reg_type="entropy", warmstart=None For semi-relaxed case, use either `reg_m=(float("inf"), scalar)` or `reg_m=(scalar, float("inf"))`. If reg_m is an array, it must have the same backend as input arrays (a, b, M). + method : str + method used for the solver either 'sinkhorn', 'sinkhorn_stabilized' or + 'sinkhorn_reg_scaling', see those function for specific parameters reg_type : string, optional Regularizer term. Can take two values: 'entropy' (negative entropy) @@ -221,9 +224,6 @@ def sinkhorn_unbalanced2(a, b, M, reg, reg_m, reg_type="entropy", warmstart=None warmstart: tuple of arrays, shape (dim_a, dim_b), optional Initialization of dual potentials. If provided, the dual potentials should be given (that is the logarithm of the u,v sinkhorn scaling vectors). 
- method : str - method used for the solver either 'sinkhorn', 'sinkhorn_stabilized' or - 'sinkhorn_reg_scaling', see those function for specific parameters numItermax : int, optional Max number of iterations stopThr : float, optional @@ -435,12 +435,12 @@ def sinkhorn_knopp_unbalanced(a, b, M, reg, reg_m, reg_type="entropy", # distances if warmstart is None: if n_hists: - u = nx.ones((dim_a, 1), type_as=M) / dim_a - v = nx.ones((dim_b, n_hists), type_as=M) / dim_b + u = nx.ones((dim_a, 1), type_as=M) + v = nx.ones((dim_b, n_hists), type_as=M) a = a.reshape(dim_a, 1) else: - u = nx.ones(dim_a, type_as=M) / dim_a - v = nx.ones(dim_b, type_as=M) / dim_b + u = nx.ones(dim_a, type_as=M) + v = nx.ones(dim_b, type_as=M) else: u, v = nx.exp(warmstart[0]), nx.exp(warmstart[1]) @@ -644,12 +644,12 @@ def sinkhorn_stabilized_unbalanced(a, b, M, reg, reg_m, reg_type="entropy", # distances if warmstart is None: if n_hists: - u = nx.ones((dim_a, n_hists), type_as=M) / dim_a - v = nx.ones((dim_b, n_hists), type_as=M) / dim_b + u = nx.ones((dim_a, n_hists), type_as=M) + v = nx.ones((dim_b, n_hists), type_as=M) a = a.reshape(dim_a, 1) else: - u = nx.ones(dim_a, type_as=M) / dim_a - v = nx.ones(dim_b, type_as=M) / dim_b + u = nx.ones(dim_a, type_as=M) + v = nx.ones(dim_b, type_as=M) else: u, v = nx.exp(warmstart[0]), nx.exp(warmstart[1]) diff --git a/test/test_unbalanced.py b/test/test_unbalanced.py index be13674a8..17d4380e8 100644 --- a/test/test_unbalanced.py +++ b/test/test_unbalanced.py @@ -33,12 +33,12 @@ def test_unbalanced_convergence(nx, method, reg_type): reg_m = 1. 
G, log = ot.unbalanced.sinkhorn_unbalanced( - a, b, M, reg=epsilon, reg_m=reg_m, reg_type=reg_type, - method=method, log=True, verbose=True + a, b, M, reg=epsilon, reg_m=reg_m, method=method, + reg_type=reg_type, log=True, verbose=True ) loss = nx.to_numpy(ot.unbalanced.sinkhorn_unbalanced2( - a, b, M, reg=epsilon, reg_m=reg_m, reg_type=reg_type, - method=method, verbose=True + a, b, M, reg=epsilon, reg_m=reg_m, method=method, + reg_type=reg_type, verbose=True )) # check fixed point equations # in log-domain @@ -70,15 +70,61 @@ def test_unbalanced_convergence(nx, method, reg_type): G = ot.unbalanced.sinkhorn_unbalanced( a, b, M, reg=epsilon, reg_m=reg_m, - reg_type=reg_type, method=method, verbose=True + method=method, reg_type=reg_type, verbose=True ) G_np = ot.unbalanced.sinkhorn_unbalanced( a_np, b_np, M_np, reg=epsilon, reg_m=reg_m, - reg_type=reg_type, method=method, verbose=True + method=method, reg_type=reg_type, verbose=True ) np.testing.assert_allclose(G_np, nx.to_numpy(G)) +@pytest.mark.parametrize("method,reg_type", itertools.product(["sinkhorn", "sinkhorn_stabilized"], ["kl", "entropy"])) +# @pytest.mark.parametrize("method", ["sinkhorn", "sinkhorn_stabilized"]) +def test_unbalanced_warmstart(nx, method, reg_type): + # test generalized sinkhorn for unbalanced OT + n = 100 + rng = np.random.RandomState(42) + + x = rng.randn(n, 2) + a = ot.utils.unif(n) + + # make dists unbalanced + b = ot.utils.unif(n) * 1.5 + M = ot.dist(x, x) + a, b, M = nx.from_numpy(a, b, M) + + epsilon = 1. + reg_m = 1. 
+ + dim_a, dim_b = M.shape + warmstart = (nx.zeros(dim_a, type_as=M), nx.zeros(dim_b, type_as=M)) + G, log = ot.unbalanced.sinkhorn_unbalanced( + a, b, M, reg=epsilon, reg_m=reg_m, method=method, + reg_type=reg_type, warmstart=warmstart, log=True, verbose=True + ) + loss = nx.to_numpy(ot.unbalanced.sinkhorn_unbalanced2( + a, b, M, reg=epsilon, reg_m=reg_m, method=method, + reg_type=reg_type, warmstart=warmstart, verbose=True + )) + + G0, log0 = ot.unbalanced.sinkhorn_unbalanced( + a, b, M, reg=epsilon, reg_m=reg_m, method=method, + reg_type=reg_type, warmstart=None, log=True, verbose=True + ) + loss0 = nx.to_numpy(ot.unbalanced.sinkhorn_unbalanced2( + a, b, M, reg=epsilon, reg_m=reg_m, method=method, + reg_type=reg_type, warmstart=None, verbose=True + )) + + np.testing.assert_allclose(loss, loss0, atol=1e-6) + np.testing.assert_allclose( + nx.to_numpy(log["logu"]), nx.to_numpy(log0["logu"]), atol=1e-05) + np.testing.assert_allclose( + nx.to_numpy(log["logv"]), nx.to_numpy(log0["logv"]), atol=1e-05) + np.testing.assert_allclose(nx.to_numpy(G), nx.to_numpy(G0), atol=1e-05) + + @pytest.mark.parametrize("method,reg_m", itertools.product(["sinkhorn", "sinkhorn_stabilized"], [1, float("inf")])) def test_unbalanced_relaxation_parameters(nx, method, reg_m): # test generalized sinkhorn for unbalanced OT From 51d6e24aac1b2588c0780f5eeea63c97e86abbfb Mon Sep 17 00:00:00 2001 From: 6Ulm Date: Thu, 2 Nov 2023 09:50:20 +0100 Subject: [PATCH 12/18] fix test --- test/test_unbalanced.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/test_unbalanced.py b/test/test_unbalanced.py index 17d4380e8..9b5b7f9eb 100644 --- a/test/test_unbalanced.py +++ b/test/test_unbalanced.py @@ -117,7 +117,7 @@ def test_unbalanced_warmstart(nx, method, reg_type): reg_type=reg_type, warmstart=None, verbose=True )) - np.testing.assert_allclose(loss, loss0, atol=1e-6) + np.testing.assert_allclose(loss, loss0, atol=1e-5) np.testing.assert_allclose( nx.to_numpy(log["logu"]), 
nx.to_numpy(log0["logu"]), atol=1e-05) np.testing.assert_allclose( From 60cd0d7db5408b0f57e50b13e393b795a41c0a9c Mon Sep 17 00:00:00 2001 From: 6Ulm Date: Thu, 2 Nov 2023 10:48:22 +0100 Subject: [PATCH 13/18] fix test --- test/test_unbalanced.py | 55 +++++++++++++++++++++++++++-------------- 1 file changed, 36 insertions(+), 19 deletions(-) diff --git a/test/test_unbalanced.py b/test/test_unbalanced.py index 9b5b7f9eb..96dc6158d 100644 --- a/test/test_unbalanced.py +++ b/test/test_unbalanced.py @@ -88,41 +88,58 @@ def test_unbalanced_warmstart(nx, method, reg_type): x = rng.randn(n, 2) a = ot.utils.unif(n) - - # make dists unbalanced - b = ot.utils.unif(n) * 1.5 + b = ot.utils.unif(n) M = ot.dist(x, x) a, b, M = nx.from_numpy(a, b, M) epsilon = 1. reg_m = 1. - dim_a, dim_b = M.shape - warmstart = (nx.zeros(dim_a, type_as=M), nx.zeros(dim_b, type_as=M)) - G, log = ot.unbalanced.sinkhorn_unbalanced( + G0, log0 = ot.unbalanced.sinkhorn_unbalanced( a, b, M, reg=epsilon, reg_m=reg_m, method=method, - reg_type=reg_type, warmstart=warmstart, log=True, verbose=True + reg_type=reg_type, warmstart=None, log=True, verbose=True ) - loss = nx.to_numpy(ot.unbalanced.sinkhorn_unbalanced2( + loss0 = ot.unbalanced.sinkhorn_unbalanced2( a, b, M, reg=epsilon, reg_m=reg_m, method=method, - reg_type=reg_type, warmstart=warmstart, verbose=True - )) + reg_type=reg_type, warmstart=None, verbose=True + ) - G0, log0 = ot.unbalanced.sinkhorn_unbalanced( + # dim_a, dim_b = M.shape + # warmstart = (nx.zeros(dim_a, type_as=M), nx.zeros(dim_b, type_as=M)) + # G, log = ot.unbalanced.sinkhorn_unbalanced( + # a, b, M, reg=epsilon, reg_m=reg_m, method=method, + # reg_type=reg_type, warmstart=warmstart, log=True, verbose=True + # ) + # loss = ot.unbalanced.sinkhorn_unbalanced2( + # a, b, M, reg=epsilon, reg_m=reg_m, method=method, + # reg_type=reg_type, warmstart=warmstart, verbose=True + # ) + + _, log = ot.lp.emd(a, b, M, log=True) + warmstart1 = (log["u"], log["v"]) + G1, log1 = 
ot.unbalanced.sinkhorn_unbalanced( a, b, M, reg=epsilon, reg_m=reg_m, method=method, - reg_type=reg_type, warmstart=None, log=True, verbose=True + reg_type=reg_type, warmstart=warmstart1, log=True, verbose=True ) - loss0 = nx.to_numpy(ot.unbalanced.sinkhorn_unbalanced2( + loss1 = ot.unbalanced.sinkhorn_unbalanced2( a, b, M, reg=epsilon, reg_m=reg_m, method=method, - reg_type=reg_type, warmstart=None, verbose=True - )) + reg_type=reg_type, warmstart=warmstart1, verbose=True + ) + + # np.testing.assert_allclose(nx.to_numpy(loss), nx.to_numpy(loss0), atol=1e-5) + np.testing.assert_allclose(nx.to_numpy(loss0), nx.to_numpy(loss1), atol=1e-5) - np.testing.assert_allclose(loss, loss0, atol=1e-5) + # np.testing.assert_allclose( + # nx.to_numpy(log["logu"]), nx.to_numpy(log0["logu"]), atol=1e-05) + # np.testing.assert_allclose( + # nx.to_numpy(log["logv"]), nx.to_numpy(log0["logv"]), atol=1e-05) np.testing.assert_allclose( - nx.to_numpy(log["logu"]), nx.to_numpy(log0["logu"]), atol=1e-05) + nx.to_numpy(log0["logu"]), nx.to_numpy(log1["logu"]), atol=1e-05) np.testing.assert_allclose( - nx.to_numpy(log["logv"]), nx.to_numpy(log0["logv"]), atol=1e-05) - np.testing.assert_allclose(nx.to_numpy(G), nx.to_numpy(G0), atol=1e-05) + nx.to_numpy(log0["logv"]), nx.to_numpy(log1["logv"]), atol=1e-05) + + # np.testing.assert_allclose(nx.to_numpy(G), nx.to_numpy(G0), atol=1e-05) + np.testing.assert_allclose(nx.to_numpy(G0), nx.to_numpy(G1), atol=1e-05) @pytest.mark.parametrize("method,reg_m", itertools.product(["sinkhorn", "sinkhorn_stabilized"], [1, float("inf")])) From c0d8391cbd589a88a1fa2fe59dbbec6a93726646 Mon Sep 17 00:00:00 2001 From: 6Ulm Date: Thu, 2 Nov 2023 10:55:28 +0100 Subject: [PATCH 14/18] fix test --- ot/unbalanced.py | 4 ++++ test/test_unbalanced.py | 32 ++++++++++++++++---------------- 2 files changed, 20 insertions(+), 16 deletions(-) diff --git a/ot/unbalanced.py b/ot/unbalanced.py index 85a58c134..37e804a09 100644 --- a/ot/unbalanced.py +++ b/ot/unbalanced.py @@ 
-443,6 +443,8 @@ def sinkhorn_knopp_unbalanced(a, b, M, reg, reg_m, reg_type="entropy", v = nx.ones(dim_b, type_as=M) else: u, v = nx.exp(warmstart[0]), nx.exp(warmstart[1]) + if not n_hists: + u, v = u.reshape(-1), v.reshape(-1) if reg_type == "kl": K = nx.exp(-M / reg) * a.reshape(-1)[:, None] * b.reshape(-1)[None, :] @@ -652,6 +654,8 @@ def sinkhorn_stabilized_unbalanced(a, b, M, reg, reg_m, reg_type="entropy", v = nx.ones(dim_b, type_as=M) else: u, v = nx.exp(warmstart[0]), nx.exp(warmstart[1]) + if not n_hists: + u, v = u.reshape(-1), v.reshape(-1) if reg_type == "kl": log_ab = nx.log(a + 1e-16).reshape(-1)[:, None] + nx.log(b + 1e-16).reshape(-1)[None, :] diff --git a/test/test_unbalanced.py b/test/test_unbalanced.py index 96dc6158d..70f9f15f9 100644 --- a/test/test_unbalanced.py +++ b/test/test_unbalanced.py @@ -104,16 +104,16 @@ def test_unbalanced_warmstart(nx, method, reg_type): reg_type=reg_type, warmstart=None, verbose=True ) - # dim_a, dim_b = M.shape - # warmstart = (nx.zeros(dim_a, type_as=M), nx.zeros(dim_b, type_as=M)) - # G, log = ot.unbalanced.sinkhorn_unbalanced( - # a, b, M, reg=epsilon, reg_m=reg_m, method=method, - # reg_type=reg_type, warmstart=warmstart, log=True, verbose=True - # ) - # loss = ot.unbalanced.sinkhorn_unbalanced2( - # a, b, M, reg=epsilon, reg_m=reg_m, method=method, - # reg_type=reg_type, warmstart=warmstart, verbose=True - # ) + dim_a, dim_b = M.shape + warmstart = (nx.zeros(dim_a, type_as=M), nx.zeros(dim_b, type_as=M)) + G, log = ot.unbalanced.sinkhorn_unbalanced( + a, b, M, reg=epsilon, reg_m=reg_m, method=method, + reg_type=reg_type, warmstart=warmstart, log=True, verbose=True + ) + loss = ot.unbalanced.sinkhorn_unbalanced2( + a, b, M, reg=epsilon, reg_m=reg_m, method=method, + reg_type=reg_type, warmstart=warmstart, verbose=True + ) _, log = ot.lp.emd(a, b, M, log=True) warmstart1 = (log["u"], log["v"]) @@ -126,19 +126,19 @@ def test_unbalanced_warmstart(nx, method, reg_type): reg_type=reg_type, warmstart=warmstart1, 
verbose=True ) - # np.testing.assert_allclose(nx.to_numpy(loss), nx.to_numpy(loss0), atol=1e-5) + np.testing.assert_allclose(nx.to_numpy(loss), nx.to_numpy(loss0), atol=1e-5) np.testing.assert_allclose(nx.to_numpy(loss0), nx.to_numpy(loss1), atol=1e-5) - # np.testing.assert_allclose( - # nx.to_numpy(log["logu"]), nx.to_numpy(log0["logu"]), atol=1e-05) - # np.testing.assert_allclose( - # nx.to_numpy(log["logv"]), nx.to_numpy(log0["logv"]), atol=1e-05) + np.testing.assert_allclose( + nx.to_numpy(log["logu"]), nx.to_numpy(log0["logu"]), atol=1e-05) + np.testing.assert_allclose( + nx.to_numpy(log["logv"]), nx.to_numpy(log0["logv"]), atol=1e-05) np.testing.assert_allclose( nx.to_numpy(log0["logu"]), nx.to_numpy(log1["logu"]), atol=1e-05) np.testing.assert_allclose( nx.to_numpy(log0["logv"]), nx.to_numpy(log1["logv"]), atol=1e-05) - # np.testing.assert_allclose(nx.to_numpy(G), nx.to_numpy(G0), atol=1e-05) + np.testing.assert_allclose(nx.to_numpy(G), nx.to_numpy(G0), atol=1e-05) np.testing.assert_allclose(nx.to_numpy(G0), nx.to_numpy(G1), atol=1e-05) From 4db194a865ae8363240f08094a5ab7d330da4989 Mon Sep 17 00:00:00 2001 From: 6Ulm Date: Thu, 2 Nov 2023 13:33:13 +0100 Subject: [PATCH 15/18] fix bug in test --- ot/unbalanced.py | 69 +++++++++++++++++++++++++++-------------- test/test_unbalanced.py | 12 +++---- 2 files changed, 50 insertions(+), 31 deletions(-) diff --git a/ot/unbalanced.py b/ot/unbalanced.py index 37e804a09..56387aa10 100644 --- a/ot/unbalanced.py +++ b/ot/unbalanced.py @@ -277,30 +277,55 @@ def sinkhorn_unbalanced2(a, b, M, reg, reg_m, method='sinkhorn', ot.unbalanced.sinkhorn_reg_scaling: Unbalanced Sinkhorn with epsilon scaling :ref:`[9, 10] ` """ - b = list_to_array(b) + M, a, b = list_to_array(M, a, b) + nx = get_backend(M, a, b) + if len(b.shape) < 2: - b = b[:, None] + if method.lower() == 'sinkhorn': + res = sinkhorn_knopp_unbalanced(a, b, M, reg, reg_m, reg_type, + warmstart, numItermax=numItermax, + stopThr=stopThr, verbose=verbose, + log=log, 
**kwargs) + + elif method.lower() == 'sinkhorn_stabilized': + res = sinkhorn_stabilized_unbalanced(a, b, M, reg, reg_m, reg_type, + warmstart, numItermax=numItermax, + stopThr=stopThr, verbose=verbose, + log=log, **kwargs) + elif method.lower() in ['sinkhorn_reg_scaling']: + warnings.warn('Method not implemented yet. Using classic Sinkhorn-Knopp') + res = sinkhorn_knopp_unbalanced(a, b, M, reg, reg_m, reg_type, + warmstart, numItermax=numItermax, + stopThr=stopThr, verbose=verbose, + log=log, **kwargs) + else: + raise ValueError('Unknown method %s.' % method) - if method.lower() == 'sinkhorn': - return sinkhorn_knopp_unbalanced(a, b, M, reg, reg_m, reg_type, - warmstart, numItermax=numItermax, - stopThr=stopThr, verbose=verbose, - log=log, **kwargs) + if log: + return nx.sum(M * res[0]), res[1] + else: + return nx.sum(M * res) - elif method.lower() == 'sinkhorn_stabilized': - return sinkhorn_stabilized_unbalanced(a, b, M, reg, reg_m, reg_type, - warmstart, numItermax=numItermax, - stopThr=stopThr, - verbose=verbose, - log=log, **kwargs) - elif method.lower() in ['sinkhorn_reg_scaling']: - warnings.warn('Method not implemented yet. Using classic Sinkhorn-Knopp') - return sinkhorn_knopp_unbalanced(a, b, M, reg, reg_m, reg_type, - warmstart, numItermax=numItermax, - stopThr=stopThr, verbose=verbose, - log=log, **kwargs) else: - raise ValueError('Unknown method %s.' % method) + if method.lower() == 'sinkhorn': + return sinkhorn_knopp_unbalanced(a, b, M, reg, reg_m, reg_type, + warmstart, numItermax=numItermax, + stopThr=stopThr, verbose=verbose, + log=log, **kwargs) + + elif method.lower() == 'sinkhorn_stabilized': + return sinkhorn_stabilized_unbalanced(a, b, M, reg, reg_m, reg_type, + warmstart, numItermax=numItermax, + stopThr=stopThr, verbose=verbose, + log=log, **kwargs) + elif method.lower() in ['sinkhorn_reg_scaling']: + warnings.warn('Method not implemented yet. 
Using classic Sinkhorn-Knopp') + return sinkhorn_knopp_unbalanced(a, b, M, reg, reg_m, reg_type, + warmstart, numItermax=numItermax, + stopThr=stopThr, verbose=verbose, + log=log, **kwargs) + else: + raise ValueError('Unknown method %s.' % method) def sinkhorn_knopp_unbalanced(a, b, M, reg, reg_m, reg_type="entropy", @@ -443,8 +468,6 @@ def sinkhorn_knopp_unbalanced(a, b, M, reg, reg_m, reg_type="entropy", v = nx.ones(dim_b, type_as=M) else: u, v = nx.exp(warmstart[0]), nx.exp(warmstart[1]) - if not n_hists: - u, v = u.reshape(-1), v.reshape(-1) if reg_type == "kl": K = nx.exp(-M / reg) * a.reshape(-1)[:, None] * b.reshape(-1)[None, :] @@ -654,8 +677,6 @@ def sinkhorn_stabilized_unbalanced(a, b, M, reg, reg_m, reg_type="entropy", v = nx.ones(dim_b, type_as=M) else: u, v = nx.exp(warmstart[0]), nx.exp(warmstart[1]) - if not n_hists: - u, v = u.reshape(-1), v.reshape(-1) if reg_type == "kl": log_ab = nx.log(a + 1e-16).reshape(-1)[:, None] + nx.log(b + 1e-16).reshape(-1)[None, :] diff --git a/test/test_unbalanced.py b/test/test_unbalanced.py index 70f9f15f9..0884dc0ef 100644 --- a/test/test_unbalanced.py +++ b/test/test_unbalanced.py @@ -15,7 +15,6 @@ @pytest.mark.parametrize("method,reg_type", itertools.product(["sinkhorn", "sinkhorn_stabilized"], ["kl", "entropy"])) -# @pytest.mark.parametrize("method", ["sinkhorn", "sinkhorn_stabilized"]) def test_unbalanced_convergence(nx, method, reg_type): # test generalized sinkhorn for unbalanced OT n = 100 @@ -80,7 +79,6 @@ def test_unbalanced_convergence(nx, method, reg_type): @pytest.mark.parametrize("method,reg_type", itertools.product(["sinkhorn", "sinkhorn_stabilized"], ["kl", "entropy"])) -# @pytest.mark.parametrize("method", ["sinkhorn", "sinkhorn_stabilized"]) def test_unbalanced_warmstart(nx, method, reg_type): # test generalized sinkhorn for unbalanced OT n = 100 @@ -115,8 +113,8 @@ def test_unbalanced_warmstart(nx, method, reg_type): reg_type=reg_type, warmstart=warmstart, verbose=True ) - _, log = ot.lp.emd(a, b, 
M, log=True) - warmstart1 = (log["u"], log["v"]) + _, log_emd = ot.lp.emd(a, b, M, log=True) + warmstart1 = (log_emd["u"], log_emd["v"]) G1, log1 = ot.unbalanced.sinkhorn_unbalanced( a, b, M, reg=epsilon, reg_m=reg_m, method=method, reg_type=reg_type, warmstart=warmstart1, log=True, verbose=True @@ -126,9 +124,6 @@ def test_unbalanced_warmstart(nx, method, reg_type): reg_type=reg_type, warmstart=warmstart1, verbose=True ) - np.testing.assert_allclose(nx.to_numpy(loss), nx.to_numpy(loss0), atol=1e-5) - np.testing.assert_allclose(nx.to_numpy(loss0), nx.to_numpy(loss1), atol=1e-5) - np.testing.assert_allclose( nx.to_numpy(log["logu"]), nx.to_numpy(log0["logu"]), atol=1e-05) np.testing.assert_allclose( @@ -141,6 +136,9 @@ def test_unbalanced_warmstart(nx, method, reg_type): np.testing.assert_allclose(nx.to_numpy(G), nx.to_numpy(G0), atol=1e-05) np.testing.assert_allclose(nx.to_numpy(G0), nx.to_numpy(G1), atol=1e-05) + np.testing.assert_allclose(nx.to_numpy(loss), nx.to_numpy(loss0), atol=1e-5) + np.testing.assert_allclose(nx.to_numpy(loss0), nx.to_numpy(loss1), atol=1e-5) + @pytest.mark.parametrize("method,reg_m", itertools.product(["sinkhorn", "sinkhorn_stabilized"], [1, float("inf")])) def test_unbalanced_relaxation_parameters(nx, method, reg_m): From e8298d05bba5fb16eb844bdb88f4d95b5f73e590 Mon Sep 17 00:00:00 2001 From: 6Ulm Date: Thu, 2 Nov 2023 14:00:48 +0100 Subject: [PATCH 16/18] fix bug in doctest --- ot/unbalanced.py | 18 +++++++----------- 1 file changed, 7 insertions(+), 11 deletions(-) diff --git a/ot/unbalanced.py b/ot/unbalanced.py index 56387aa10..8ea126a94 100644 --- a/ot/unbalanced.py +++ b/ot/unbalanced.py @@ -110,9 +110,8 @@ def sinkhorn_unbalanced(a, b, M, reg, reg_m, method='sinkhorn', >>> b=[.5, .5] >>> M=[[0., 1.], [1., 0.]] >>> ot.sinkhorn_unbalanced(a, b, M, 1, 1) - array([[0.51122823, 0.18807035], - [0.18807035, 0.51122823]]) - + array([[0.51122814, 0.18807032], + [0.18807032, 0.51122814]]) .. 
_references-sinkhorn-unbalanced: References @@ -249,8 +248,7 @@ def sinkhorn_unbalanced2(a, b, M, reg, reg_m, method='sinkhorn', >>> b=[.5, .5] >>> M=[[0., 1.],[1., 0.]] >>> ot.unbalanced.sinkhorn_unbalanced2(a, b, M, 1., 1.) - array([0.31912866]) - + 0.3191285827553562 .. _references-sinkhorn-unbalanced2: References @@ -415,9 +413,8 @@ def sinkhorn_knopp_unbalanced(a, b, M, reg, reg_m, reg_type="entropy", >>> b=[.5, .5] >>> M=[[0., 1.],[1., 0.]] >>> ot.unbalanced.sinkhorn_knopp_unbalanced(a, b, M, 1., 1.) - array([[0.51122823, 0.18807035], - [0.18807035, 0.51122823]]) - + array([[0.51122814, 0.18807032], + [0.18807032, 0.51122814]]) .. _references-sinkhorn-knopp-unbalanced: References @@ -625,9 +622,8 @@ def sinkhorn_stabilized_unbalanced(a, b, M, reg, reg_m, reg_type="entropy", >>> b=[.5, .5] >>> M=[[0., 1.],[1., 0.]] >>> ot.unbalanced.sinkhorn_stabilized_unbalanced(a, b, M, 1., 1.) - array([[0.51122823, 0.18807035], - [0.18807035, 0.51122823]]) - + array([[0.51122814, 0.18807032], + [0.18807032, 0.51122814]]) .. _references-sinkhorn-stabilized-unbalanced: References From 1bccd478fc89e1a72b5b6b478051906b98370d0f Mon Sep 17 00:00:00 2001 From: 6Ulm Date: Thu, 2 Nov 2023 14:41:15 +0100 Subject: [PATCH 17/18] fix bug in doctest --- ot/unbalanced.py | 5 +++-- test/test_unbalanced.py | 8 ++++---- 2 files changed, 7 insertions(+), 6 deletions(-) diff --git a/ot/unbalanced.py b/ot/unbalanced.py index 8ea126a94..73667b324 100644 --- a/ot/unbalanced.py +++ b/ot/unbalanced.py @@ -244,11 +244,12 @@ def sinkhorn_unbalanced2(a, b, M, reg, reg_m, method='sinkhorn', -------- >>> import ot + >>> import numpy as np >>> a=[.5, .10] >>> b=[.5, .5] >>> M=[[0., 1.],[1., 0.]] - >>> ot.unbalanced.sinkhorn_unbalanced2(a, b, M, 1., 1.) - 0.3191285827553562 + >>> np.round(ot.unbalanced.sinkhorn_unbalanced2(a, b, M, 1., 1.), 8) + 0.31912858 .. 
_references-sinkhorn-unbalanced2: References diff --git a/test/test_unbalanced.py b/test/test_unbalanced.py index 0884dc0ef..6b23e3060 100644 --- a/test/test_unbalanced.py +++ b/test/test_unbalanced.py @@ -14,7 +14,7 @@ from ot.unbalanced import barycenter_unbalanced -@pytest.mark.parametrize("method,reg_type", itertools.product(["sinkhorn", "sinkhorn_stabilized"], ["kl", "entropy"])) +@pytest.mark.parametrize("method,reg_type", itertools.product(["sinkhorn", "sinkhorn_stabilized", "sinkhorn_reg_scaling"], ["kl", "entropy"])) def test_unbalanced_convergence(nx, method, reg_type): # test generalized sinkhorn for unbalanced OT n = 100 @@ -78,7 +78,7 @@ def test_unbalanced_convergence(nx, method, reg_type): np.testing.assert_allclose(G_np, nx.to_numpy(G)) -@pytest.mark.parametrize("method,reg_type", itertools.product(["sinkhorn", "sinkhorn_stabilized"], ["kl", "entropy"])) +@pytest.mark.parametrize("method,reg_type", itertools.product(["sinkhorn", "sinkhorn_stabilized", "sinkhorn_reg_scaling"], ["kl", "entropy"])) def test_unbalanced_warmstart(nx, method, reg_type): # test generalized sinkhorn for unbalanced OT n = 100 @@ -140,7 +140,7 @@ def test_unbalanced_warmstart(nx, method, reg_type): np.testing.assert_allclose(nx.to_numpy(loss0), nx.to_numpy(loss1), atol=1e-5) -@pytest.mark.parametrize("method,reg_m", itertools.product(["sinkhorn", "sinkhorn_stabilized"], [1, float("inf")])) +@pytest.mark.parametrize("method,reg_m", itertools.product(["sinkhorn", "sinkhorn_stabilized", "sinkhorn_reg_scaling"], [1, float("inf")])) def test_unbalanced_relaxation_parameters(nx, method, reg_m): # test generalized sinkhorn for unbalanced OT n = 100 @@ -184,7 +184,7 @@ def test_unbalanced_relaxation_parameters(nx, method, reg_m): nx.to_numpy(loss), nx.to_numpy(loss_opt), atol=1e-05) -@pytest.mark.parametrize("method", ["sinkhorn", "sinkhorn_stabilized"]) +@pytest.mark.parametrize("method", ["sinkhorn", "sinkhorn_stabilized", "sinkhorn_reg_scaling"]) def 
test_unbalanced_multiple_inputs(nx, method): # test generalized sinkhorn for unbalanced OT n = 100 From c2c0e96949ed9a0469269563deb4f0420d957be7 Mon Sep 17 00:00:00 2001 From: 6Ulm Date: Thu, 2 Nov 2023 15:17:03 +0100 Subject: [PATCH 18/18] add test for more coverage --- test/test_unbalanced.py | 39 ++++++++++++++++++++++++++++++++++----- 1 file changed, 34 insertions(+), 5 deletions(-) diff --git a/test/test_unbalanced.py b/test/test_unbalanced.py index 6b23e3060..7007e336b 100644 --- a/test/test_unbalanced.py +++ b/test/test_unbalanced.py @@ -140,6 +140,36 @@ def test_unbalanced_warmstart(nx, method, reg_type): np.testing.assert_allclose(nx.to_numpy(loss0), nx.to_numpy(loss1), atol=1e-5) +@pytest.mark.parametrize("method,reg_type, log", itertools.product(["sinkhorn", "sinkhorn_stabilized", "sinkhorn_reg_scaling"], ["kl", "entropy"], [True, False])) +def test_sinkhorn_unbalanced2(nx, method, reg_type, log): + n = 100 + rng = np.random.RandomState(42) + + x = rng.randn(n, 2) + a = ot.utils.unif(n) + + # make dists unbalanced + b = ot.utils.unif(n) * 1.5 + M = ot.dist(x, x) + a, b, M = nx.from_numpy(a, b, M) + + epsilon = 1. + reg_m = 1. 
+ + loss = nx.to_numpy(ot.unbalanced.sinkhorn_unbalanced2( + a, b, M, reg=epsilon, reg_m=reg_m, method=method, + reg_type=reg_type, log=False, verbose=True + )) + + res = ot.unbalanced.sinkhorn_unbalanced2( + a, b, M, reg=epsilon, reg_m=reg_m, method=method, + reg_type=reg_type, log=log, verbose=True + ) + loss0 = res[0] if log else res + + np.testing.assert_allclose(nx.to_numpy(loss), nx.to_numpy(loss0), atol=1e-5) + + @pytest.mark.parametrize("method,reg_m", itertools.product(["sinkhorn", "sinkhorn_stabilized", "sinkhorn_reg_scaling"], [1, float("inf")])) def test_unbalanced_relaxation_parameters(nx, method, reg_m): # test generalized sinkhorn for unbalanced OT @@ -202,11 +232,10 @@ def test_unbalanced_multiple_inputs(nx, method): a, b, M = nx.from_numpy(a, b, M) - loss, log = ot.unbalanced.sinkhorn_unbalanced(a, b, M, reg=epsilon, - reg_m=reg_m, - method=method, - log=True, - verbose=True) + G, log = ot.unbalanced.sinkhorn_unbalanced(a, b, M, reg=epsilon, + reg_m=reg_m, method=method, + log=True, verbose=True) + # check fixed point equations # in log-domain fi = reg_m / (reg_m + epsilon)