"""
Low rank OT solvers
"""

# Author: Laurène David <laurene.david@ip-paris.fr>
#
# License: MIT License


import warnings
from .utils import unif, get_lowrank_lazytensor
from .backend import get_backend


def compute_lr_sqeuclidean_matrix(X_s, X_t, nx=None):
    """
    Compute the low rank decomposition of a squared euclidean distance matrix.
    This function won't work for any other distance metric.

    See "Section 3.5, Proposition 1" in [65].

    Parameters
    ----------
    X_s : array-like, shape (n_samples_a, dim)
        samples in the source domain
    X_t : array-like, shape (n_samples_b, dim)
        samples in the target domain
    nx : POT backend, default None
        If None, the backend is inferred from `X_s` and `X_t`.

    Returns
    -------
    M1 : array-like, shape (n_samples_a, dim+2)
        First low rank decomposition of the distance matrix
    M2 : array-like, shape (n_samples_b, dim+2)
        Second low rank decomposition of the distance matrix

    References
    ----------
    .. [65] Scetbon, M., Cuturi, M., & Peyré, G. (2021).
        "Low-rank Sinkhorn Factorization". In International Conference on Machine Learning.
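
    Examples
    --------
    A minimal sketch of how the factors relate to the full cost matrix,
    assuming NumPy inputs (any backend array works the same way); the
    solver itself never materializes ``M1 @ M2.T``:

    >>> import numpy as np
    >>> rng = np.random.RandomState(0)
    >>> X_s, X_t = rng.randn(5, 3), rng.randn(4, 3)
    >>> M1, M2 = compute_lr_sqeuclidean_matrix(X_s, X_t)
    >>> C = ((X_s[:, None, :] - X_t[None, :, :]) ** 2).sum(-1)
    >>> assert np.allclose(M1 @ M2.T, C)  # M1 @ M2.T is the squared euclidean cost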
    """

    if nx is None:
        nx = get_backend(X_s, X_t)

    ns = X_s.shape[0]
    nt = X_t.shape[0]

    # First low rank decomposition of the cost matrix (A)
    array1 = nx.reshape(nx.sum(X_s**2, 1), (-1, 1))
    array2 = nx.reshape(nx.ones(ns, type_as=X_s), (-1, 1))
    M1 = nx.concatenate((array1, array2, -2 * X_s), axis=1)

    # Second low rank decomposition of the cost matrix (B)
    array1 = nx.reshape(nx.ones(nt, type_as=X_s), (-1, 1))
    array2 = nx.reshape(nx.sum(X_t**2, 1), (-1, 1))
    M2 = nx.concatenate((array1, array2, X_t), axis=1)

    return M1, M2


def _LR_Dykstra(eps1, eps2, eps3, p1, p2, alpha, stopThr, numItermax, warn, nx=None):
    """
    Implementation of the Dykstra algorithm for the Low Rank Sinkhorn OT solver.
    This function is specific to lowrank_sinkhorn.

    Parameters
    ----------
    eps1 : array-like, shape (n_samples_a, r)
        First input parameter of the Dykstra algorithm
    eps2 : array-like, shape (n_samples_b, r)
        Second input parameter of the Dykstra algorithm
    eps3 : array-like, shape (r,)
        Third input parameter of the Dykstra algorithm
    p1 : array-like, shape (n_samples_a,)
        Samples weights in the source domain (same as "a" in lowrank_sinkhorn)
    p2 : array-like, shape (n_samples_b,)
        Samples weights in the target domain (same as "b" in lowrank_sinkhorn)
    alpha : float
        Lower bound for the weight vector g (same as "alpha" in lowrank_sinkhorn)
    stopThr : float
        Stop threshold on error
    numItermax : int
        Max number of iterations
    warn : bool, optional
        if True, raises a warning if the algorithm does not converge.
    nx : POT backend, default None
        If None, the backend is inferred from the input arrays.

    Returns
    -------
    Q : array-like, shape (n_samples_a, r)
        Dykstra update of the first low-rank matrix decomposition Q
    R : array-like, shape (n_samples_b, r)
        Dykstra update of the second low-rank matrix decomposition R
    g : array-like, shape (r,)
        Dykstra update of the weight vector g

    References
    ----------
    .. [65] Scetbon, M., Cuturi, M., & Peyré, G. (2021).
        "Low-rank Sinkhorn Factorization". In International Conference on Machine Learning.

    """

    # POT backend if None
    if nx is None:
        nx = get_backend(eps1, eps2, eps3, p1, p2)

    # ----------------- Initialisation of Dykstra algorithm -----------------
    r = len(eps3)  # rank
    g_ = nx.copy(eps3)  # \tilde{g}
    q3_1, q3_2 = nx.ones(r, type_as=p1), nx.ones(r, type_as=p1)  # q^{(3)}_1, q^{(3)}_2
    v1_, v2_ = nx.ones(r, type_as=p1), nx.ones(r, type_as=p1)  # \tilde{v}^{(1)}, \tilde{v}^{(2)}
    q1, q2 = nx.ones(r, type_as=p1), nx.ones(r, type_as=p1)  # q^{(1)}, q^{(2)}
    err = 1  # initial error

    # --------------------- Dykstra algorithm -------------------------

    # See Section 3.3 - "Algorithm 2 LR-Dykstra" in the paper
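    # Each iteration alternates KL projections of the scaled kernels onto the
    # constraint sets of the low-rank coupling (Q 1_r = p1, R 1_r = p2 and
    # Q^T 1_n = R^T 1_m = g with g >= alpha); the q variables carry the Dykstra
    # correction terms that make the iterates converge to the projection onto
    # the intersection of these sets.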

    for ii in range(numItermax):
        if err > stopThr:
            # Compute u^{(1)} and u^{(2)}
            u1 = p1 / nx.dot(eps1, v1_)
            u2 = p2 / nx.dot(eps2, v2_)

            # Compute g, q^{(3)}_1 and update \tilde{g}
            g = nx.maximum(alpha, g_ * q3_1)
            q3_1 = (g_ * q3_1) / g
            g_ = nx.copy(g)

            # Compute new value of g with \prod
            prod1 = (v1_ * q1) * nx.dot(eps1.T, u1)
            prod2 = (v2_ * q2) * nx.dot(eps2.T, u2)
            g = (g_ * q3_2 * prod1 * prod2) ** (1 / 3)

            # Compute v^{(1)} and v^{(2)}
            v1 = g / nx.dot(eps1.T, u1)
            v2 = g / nx.dot(eps2.T, u2)

            # Compute q^{(1)}, q^{(2)} and q^{(3)}_2
            q1 = (v1_ * q1) / v1
            q2 = (v2_ * q2) / v2
            q3_2 = (g_ * q3_2) / g

            # Update values of \tilde{v}^{(1)}, \tilde{v}^{(2)} and \tilde{g}
            v1_, v2_ = nx.copy(v1), nx.copy(v2)
            g_ = nx.copy(g)

            # Compute error on the marginal constraints
            err1 = nx.sum(nx.abs(u1 * nx.dot(eps1, v1) - p1))
            err2 = nx.sum(nx.abs(u2 * nx.dot(eps2, v2) - p2))
            err = err1 + err2

        else:
            break

    else:
        if warn:
            warnings.warn(
                "Dykstra did not converge. You might want to "
                "increase the number of iterations `numItermax`."
            )

    # Compute low rank matrices Q, R
    Q = u1[:, None] * eps1 * v1[None, :]
    R = u2[:, None] * eps2 * v2[None, :]

    return Q, R, g


def lowrank_sinkhorn(X_s, X_t, a=None, b=None, reg=0, rank=None, alpha=None,
                     numItermax=1000, stopThr=1e-9, warn=True, log=False):
    r"""
    Solve the entropic regularization optimal transport problem under low-nonnegative rank constraints.

    The function solves the following optimization problem:

    .. math::
        \mathop{\inf_{(Q, R, g) \in \mathcal{C}(\mathbf{a}, \mathbf{b}, r)}} \langle C, Q \mathrm{diag}(1/g) R^T \rangle -
        \mathrm{reg} \cdot H((Q, R, g))

    where :

    - :math:`C` is the (`n_samples_a`, `n_samples_b`) metric cost matrix
    - :math:`H((Q, R, g))` is the sum of the entropies of :math:`Q`, :math:`R` and :math:`g`, each evaluated separately
    - :math:`Q` and :math:`R` are the low-rank matrix decompositions of the OT plan
    - :math:`g` is the weight vector for the low-rank decomposition of the OT plan
    - :math:`\mathbf{a}` and :math:`\mathbf{b}` are source and target weights (histograms, both sum to 1)
    - :math:`r` is the rank of the OT plan
    - :math:`\mathcal{C}(\mathbf{a}, \mathbf{b}, r)` is the set of low-rank couplings of the OT problem


    Parameters
    ----------
    X_s : array-like, shape (n_samples_a, dim)
        samples in the source domain
    X_t : array-like, shape (n_samples_b, dim)
        samples in the target domain
    a : array-like, shape (n_samples_a,), optional
        samples weights in the source domain. If None, uniform weights are used.
    b : array-like, shape (n_samples_b,), optional
        samples weights in the target domain. If None, uniform weights are used.
    reg : float, optional
        Regularization term (>=0). Default is 0 (no entropic regularization).
    rank : int, optional. Default is None. (>0)
        Nonnegative rank of the OT plan. If None, min(ns, nt) is used.
    alpha : float, optional. Default is None. (>0 and <1/r)
        Lower bound for the weight vector g. If None, 1e-10 is used.
    numItermax : int, optional
        Max number of iterations for the inner Dykstra algorithm
    stopThr : float, optional
        Stop threshold on error (>0) for the inner Dykstra algorithm
    warn : bool, optional
        if True, raises a warning if the Dykstra algorithm does not converge.
    log : bool, optional
        record log if True


    Returns
    -------
    Q : array-like, shape (n_samples_a, r)
        First low-rank matrix decomposition of the OT plan
    R : array-like, shape (n_samples_b, r)
        Second low-rank matrix decomposition of the OT plan
    g : array-like, shape (r,)
        Weight vector for the low-rank decomposition of the OT plan
    log : dict, optional
        Only returned if `log` is True. Contains the optimal value of the
        optimization problem ("value"), the linear OT loss obtained with the
        optimal plan ("value_linear") and the OT plan as a :any:`LazyTensor`
        of shape (n_samples_a, n_samples_b) ("lazy_plan").


    References
    ----------
    .. [65] Scetbon, M., Cuturi, M., & Peyré, G. (2021).
        "Low-rank Sinkhorn Factorization". In International Conference on Machine Learning.

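    Examples
    --------
    A minimal usage sketch with NumPy inputs; for small problems the factored
    plan can be densified as ``Q @ np.diag(1 / g) @ R.T`` for inspection:

    >>> import numpy as np
    >>> rng = np.random.RandomState(0)
    >>> X_s, X_t = rng.randn(20, 2), rng.randn(30, 2)
    >>> Q, R, g = lowrank_sinkhorn(X_s, X_t, reg=0.1, rank=3)
    >>> P = Q @ np.diag(1 / g) @ R.T  # dense plan with nonnegative rank <= 3
    >>> assert np.allclose(P.sum(axis=1), np.ones(20) / 20, atol=1e-5)
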
    """

    # POT backend
    nx = get_backend(X_s, X_t)
    ns, nt = X_s.shape[0], X_t.shape[0]

    # Initialize weights a, b
    if a is None:
        a = unif(ns, type_as=X_s)
    if b is None:
        b = unif(nt, type_as=X_t)

    # Compute rank (see Section 3.1, def 1)
    r = rank
    if rank is None:
        r = min(ns, nt)

    if r <= 0:
        raise ValueError("The rank parameter must be a positive integer")

    if alpha is None:
        alpha = 1e-10

    # Dykstra won't converge if 1/rank < alpha (alpha is the lower bound for 1/rank)
    # (see "Section 3.2: The Low-rank OT Problem (LOT)" in the paper)
    if 1 / r < alpha:
        raise ValueError("alpha ({a}) should be smaller than 1/rank ({r}) for the Dykstra algorithm to converge.".format(
            a=alpha, r=1 / r))

    # Low rank decomposition of the sqeuclidean cost matrix (A, B)
    M1, M2 = compute_lr_sqeuclidean_matrix(X_s, X_t, nx=nx)

    # Compute gamma (see "Section 3.4, proposition 4" in the paper)
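    # gamma acts as the step size of the mirror-descent iterations below;
    # L estimates the smoothness constant of the objective, and the choice
    # gamma = 1 / (2 * L) follows the step-size condition of Proposition 4
    # in [65].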
    L = nx.sqrt(
        3 * (2 / (alpha**4)) * ((nx.norm(M1) * nx.norm(M2)) ** 2) +
        (reg + (2 / (alpha**3)) * (nx.norm(M1) * nx.norm(M2))) ** 2
    )
    gamma = 1 / (2 * L)

    # Initialize the low rank matrices Q, R, g
    Q = nx.ones((ns, r), type_as=a)
    R = nx.ones((nt, r), type_as=a)
    g = nx.ones(r, type_as=a)
    k = 100  # fixed number of outer iterations

    # -------------------------- Low rank algorithm ------------------------------
    # see "Section 3.3, Algorithm 3 LOT" in the paper

    for ii in range(k):
        # Compute the C @ R product using the low-rank decomposition of C
        CR_ = nx.dot(M2.T, R)
        CR = nx.dot(M1, CR_)

        # Compute the C.T @ Q product using the low-rank decomposition of C
        CQ_ = nx.dot(M1.T, Q)
        CQ = nx.dot(M2, CQ_)

        diag_g = (1 / g)[None, :]

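        # Mirror-descent step: each eps is the current iterate raised to the
        # power (1 - gamma * reg), times the exponential of minus gamma times
        # the gradient of the linear term, i.e. a multiplicative KL update.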
        eps1 = nx.exp(-gamma * (CR * diag_g) - ((gamma * reg) - 1) * nx.log(Q))
        eps2 = nx.exp(-gamma * (CQ * diag_g) - ((gamma * reg) - 1) * nx.log(R))
        omega = nx.diag(nx.dot(Q.T, CR))
        eps3 = nx.exp(gamma * omega / (g**2) - (gamma * reg - 1) * nx.log(g))

        Q, R, g = _LR_Dykstra(
            eps1, eps2, eps3, a, b, alpha, stopThr, numItermax, warn, nx
        )
        Q = Q + 1e-16  # avoid exact zeros before the logarithms above and below
        R = R + 1e-16

    # ----------------- Compute lazy_plan, value and value_linear ------------------
    # see "Section 3.2: The Low-rank OT Problem" in the paper

    # Compute lazy plan (using LazyTensor class)
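    # The plan factorizes as P = Q diag(1/g) R^T; the LazyTensor keeps it in
    # factored form so the full (n_samples_a, n_samples_b) matrix is never built.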
    lazy_plan = get_lowrank_lazytensor(Q, R, 1 / g)

    # Compute value_linear (using trace formula)
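    # <C, P> = trace(C^T P) with C = M1 @ M2.T and P = Q diag(1/g) R^T, so the
    # value is obtained from products of the small factors only, without ever
    # forming the full cost matrix.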
    diag_g = (1 / g)[None, :]  # recompute with the final g returned by Dykstra
    v1 = nx.dot(Q.T, M1)
    v2 = nx.dot(R, (v1.T * diag_g).T)
    value_linear = nx.sum(nx.diag(nx.dot(M2.T, v2)))

    # Compute value with entropy reg (entropy of Q, R, g must be computed separately, see "Section 3.2" in the paper)
    reg_Q = nx.sum(Q * nx.log(Q + 1e-16))  # entropy for Q
    reg_g = nx.sum(g * nx.log(g + 1e-16))  # entropy for g
    reg_R = nx.sum(R * nx.log(R + 1e-16))  # entropy for R
    value = value_linear + reg * (reg_Q + reg_g + reg_R)

    if log:
        dict_log = dict()
        dict_log["value"] = value
        dict_log["value_linear"] = value_linear
        dict_log["lazy_plan"] = lazy_plan

        return Q, R, g, dict_log

    return Q, R, g