From f33062a62efaf9748ef9ee843e7bf40676e05a8d Mon Sep 17 00:00:00 2001 From: fegounna Date: Mon, 3 Feb 2025 19:40:01 +0100 Subject: [PATCH 1/8] implementation of low rank ot via factor relaxation paper --- ot/low_rank/__init__.py | 12 ++ ot/low_rank/_factor_relaxation.py | 217 ++++++++++++++++++++++++++++++ 2 files changed, 229 insertions(+) create mode 100644 ot/low_rank/__init__.py create mode 100644 ot/low_rank/_factor_relaxation.py diff --git a/ot/low_rank/__init__.py b/ot/low_rank/__init__.py new file mode 100644 index 000000000..3d5351393 --- /dev/null +++ b/ot/low_rank/__init__.py @@ -0,0 +1,12 @@ +# -*- coding: utf-8 -*- +""" +Low rank Solvers +""" + +# Author: Yessin Moakher +# +# License: MIT License + +from ._factor_relaxation import solve_balanced_FRLC + +__all__ = ["solve_balanced_FRLC"] diff --git a/ot/low_rank/_factor_relaxation.py b/ot/low_rank/_factor_relaxation.py new file mode 100644 index 000000000..717387072 --- /dev/null +++ b/ot/low_rank/_factor_relaxation.py @@ -0,0 +1,217 @@ +# -*- coding: utf-8 -*- +""" +Low rank Solvers +""" + +# Author: Yessin Moakher +# +# License: MIT License + +from ..utils import list_to_array +from ..backend import get_backend +from ..bregman import sinkhorn +from ..unbalanced import sinkhorn_unbalanced + + +def _initialize_couplings(a, b, r, nx, reg_init=1, random_state=42): + """Initialize the couplings Q, R, T for the Factor Relaxation algorithm.""" + + n = a.shape[0] + m = b.shape[0] + + nx.seed(seed=random_state) + M_Q = nx.rand(n, r, type_as=a) + M_R = nx.rand(m, r, type_as=a) + M_T = nx.rand(r, r, type_as=a) + + g_Q, g_R = ( + nx.full(r, 1 / r, type_as=a), + nx.full(r, 1 / r, type_as=a), + ) # Shape (r,) and (r,) + + Q = sinkhorn(a, g_Q, M_Q, reg_init, method="sinkhorn_log") + R = sinkhorn(b, g_R, M_R, reg_init, method="sinkhorn_log") + T = sinkhorn( + nx.dot(Q.T, nx.ones(n, type_as=a)), + nx.dot(R.T, nx.ones(m, type_as=a)), + M_T, + reg_init, + method="sinkhorn_log", + ) + + return Q, R, T + + +def 
_compute_gradient_Q(M, Q, R, X, g_Q, nx): + """Compute the gradient of the loss with respect to Q.""" + + n = Q.shape[0] + + term1 = nx.dot( + nx.dot(M, R), X.T + ) # The order of multiplications is important because r<0) + log : bool, optional + Print cost value at each iteration. + + Returns + ------- + P : array-like, shape (n, m) + The computed low-rank optimal transportation matrix. + + References + ---------- + [1] Halmos, P., Liu, X., Gold, J., & Raphael, B. (2024). Low-Rank Optimal Transport through Factor Relaxation with Latent Coupling. + In Proceedings of the Thirty-eighth Annual Conference on Neural Information Processing Systems (NeurIPS 2024). + """ + + a, b, M = list_to_array(a, b, M) + + nx = get_backend(M, a, b) + + n, m = M.shape + + ones_n, ones_m = ( + nx.ones(n, type_as=M), + nx.ones(m, type_as=M), + ) # Shape (n,) and (m,) + + Q, R, T = _initialize_couplings(a, b, r, nx) # Shape (n,r), (m,r), (r,r) + g_Q, g_R = nx.dot(Q.T, ones_n), nx.dot(R.T, ones_m) # Shape (r,) and (r,) + X = nx.dot(nx.dot(nx.diag(1 / g_Q), T), nx.diag(1 / g_R)) # Shape (r,r) + + for i in range(numItermax): + grad_Q = _compute_gradient_Q(M, Q, R, X, g_Q, nx) # Shape (n,r) + grad_R = _compute_gradient_R(M, Q, R, X, g_R, nx) # Shape (m,r) + + gamma_k = gamma / max( + nx.max(nx.abs(grad_Q)), nx.max(nx.abs(grad_R)) + ) # l-inf normalization + + # We can parallelize the calculation of Q_new and R_new + Q_new = sinkhorn_unbalanced( + a=a, + b=g_Q, + M=grad_Q, + reg=1 / gamma_k, + reg_m=[float("inf"), tau], + method="sinkhorn_stabilized", + ) + + R_new = sinkhorn_unbalanced( + a=b, + b=g_R, + M=grad_R, + reg=1 / gamma_k, + reg_m=[float("inf"), tau], + method="sinkhorn_stabilized", + ) + + g_Q = nx.dot(Q_new.T, ones_n) + g_R = nx.dot(R_new.T, ones_m) + + grad_T = _compute_gradient_T(Q_new, R_new, M, g_Q, g_R, nx) # Shape (r, r) + + gamma_T = gamma / nx.max(nx.abs(grad_T)) + + T_new = sinkhorn( + g_R, g_Q, grad_T, reg=1 / gamma_T, method="sinkhorn_log" + ) # Shape (r, r) + + X_new = 
nx.dot(nx.dot(nx.diag(1 / g_Q), T_new), nx.diag(1 / g_R)) # Shape (r,r) + + if log: + print(f"iteration {i} ", nx.sum(M * nx.dot(nx.dot(Q_new, X_new), R_new.T))) + + if ( + _compute_distance(Q_new, R_new, T_new, Q, R, T, nx) + < gamma_k * gamma_k * stopThr + ): + return nx.dot(nx.dot(Q_new, X_new), R_new.T) # Shape (n, m) + + Q, R, T, X = Q_new, R_new, T_new, X_new From 070ee3c0b98f3c7d4fc7bda0bbb63e21ba58073e Mon Sep 17 00:00:00 2001 From: fegounna Date: Mon, 3 Feb 2025 19:52:33 +0100 Subject: [PATCH 2/8] implementation of low rank ot via factor relaxation paper --- RELEASES.md | 1 + 1 file changed, 1 insertion(+) diff --git a/RELEASES.md b/RELEASES.md index 745a7de67..a66a125f7 100644 --- a/RELEASES.md +++ b/RELEASES.md @@ -8,6 +8,7 @@ - Automatic PR labeling and release file update check (PR #704) - Reorganize sub-module `ot/lp/__init__.py` into separate files (PR #714) - Fix documentation in the module `ot.gaussian` (PR #718) +- Implement low rank through Factor Relaxation with Latent Coupling #### Closed issues - Fixed `ot.mapping` solvers which depended on deprecated `cvxpy` `ECOS` solver (PR #692, Issue #668) From cdde716262aeed9afafc71c2117ba4c03e8afc4c Mon Sep 17 00:00:00 2001 From: fegounna Date: Mon, 3 Feb 2025 20:03:37 +0100 Subject: [PATCH 3/8] implementation of low rank ot via factor relaxation paper --- ot/low_rank/_factor_relaxation.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ot/low_rank/_factor_relaxation.py b/ot/low_rank/_factor_relaxation.py index 717387072..cded9aa51 100644 --- a/ot/low_rank/_factor_relaxation.py +++ b/ot/low_rank/_factor_relaxation.py @@ -51,7 +51,7 @@ def _compute_gradient_Q(M, Q, R, X, g_Q, nx): nx.dot(M, R), X.T ) # The order of multiplications is important because r< Date: Mon, 3 Feb 2025 20:06:13 +0100 Subject: [PATCH 4/8] implementation of low rank ot via factor relaxation paper --- RELEASES.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/RELEASES.md b/RELEASES.md index 
a66a125f7..9487f4ba6 100644 --- a/RELEASES.md +++ b/RELEASES.md @@ -8,7 +8,7 @@ - Automatic PR labeling and release file update check (PR #704) - Reorganize sub-module `ot/lp/__init__.py` into separate files (PR #714) - Fix documentation in the module `ot.gaussian` (PR #718) -- Implement low rank through Factor Relaxation with Latent Coupling +- Implement low rank through Factor Relaxation with Latent Coupling (PR #719) #### Closed issues - Fixed `ot.mapping` solvers which depended on deprecated `cvxpy` `ECOS` solver (PR #692, Issue #668) From c85bf21efd054daaab4c9b9f29421db14a505181 Mon Sep 17 00:00:00 2001 From: fegounna Date: Tue, 4 Feb 2025 12:06:47 +0100 Subject: [PATCH 5/8] add the definition of r to the description --- ot/low_rank/_factor_relaxation.py | 45 ++++++++++++++++++++++++++++--- 1 file changed, 42 insertions(+), 3 deletions(-) diff --git a/ot/low_rank/_factor_relaxation.py b/ot/low_rank/_factor_relaxation.py index cded9aa51..789a0579e 100644 --- a/ot/low_rank/_factor_relaxation.py +++ b/ot/low_rank/_factor_relaxation.py @@ -105,11 +105,11 @@ def solve_balanced_FRLC( .. 
math:: \textbf{P} = \mathop{\arg \min}_P \quad \langle \textbf{P}, \mathbf{M} \rangle_F - \text{s.t.} \ \textbf{P} = \textbf{Q} \operatorname{diag}(1/g_Q)\textbf{T}\operatorname{diag}(1/g_R)\textbf{R}^T + \text{s.t.} \textbf{P} = \textbf{Q} \operatorname{diag}(1/g_Q)\textbf{T}\operatorname{diag}(1/g_R)\textbf{R}^T - \textbf{Q} &\in \Pi_{a,\cdot}, \quad \textbf{R} \in \Pi_{b,\cdot}, \quad \textbf{T} \in \Pi_{g_Q,g_R} + \textbf{Q} \in \Pi_{a,\cdot}, \quad \textbf{R} \in \Pi_{b,\cdot}, \quad \textbf{T} \in \Pi_{g_Q,g_R} - \textbf{Q}, \textbf{R}, \textbf{T} &\geq 0 + \textbf{Q} \in \mathbb{R}^+_{n,r},\textbf{R} \in \mathbb{R}^+_{m,r},\textbf{T} \in \mathbb{R}^+_{r,r} where: @@ -215,3 +215,42 @@ def solve_balanced_FRLC( return nx.dot(nx.dot(Q_new, X_new), R_new.T) # Shape (n, m) Q, R, T, X = Q_new, R_new, T_new, X_new + + +if __name__ == "__main__": + import torch + + grid_size = 4 + torch.manual_seed(42) + x_vals = torch.linspace(0, 3, grid_size) + y_vals = torch.linspace(0, 3, grid_size) + X, Y = torch.meshgrid(x_vals, y_vals, indexing="ij") + source_points = torch.stack([X.ravel(), Y.ravel()], dim=-1) # (16, 2) + a = torch.ones(len(source_points)) / len(source_points) # Uniform distribution + + # Generate Target Distribution (Gaussian Samples) + mean = torch.tensor([2.0, 2.0]) + cov = torch.tensor([[1.0, 0.5], [0.5, 1.0]]) + target_points = torch.distributions.MultivariateNormal( + mean, covariance_matrix=cov + ).sample((len(source_points),)) # (16, 2) + b = torch.ones(len(target_points)) / len(target_points) # Uniform distribution + + # Compute Cost Matrix (Squared Euclidean Distance) + C = torch.cdist(source_points, target_points, p=2) ** 2 + + # Solve OT problem (assuming you have PyTorch versions of these functions) + print(type(a.numpy())) + P = solve_balanced_FRLC( + a.to(torch.float64), + b.to(torch.float64), + C.to(torch.float64), + 10, + tau=1e2, + gamma=1e2, + stopThr=1e-7, + numItermax=100, + log=True, + ) + P = sinkhorn(a, b, C, reg=1) + print(torch.sum(P 
* C)) From ba515c2a3aee9ca2234e629d53c22cac9a478502 Mon Sep 17 00:00:00 2001 From: fegounna Date: Tue, 4 Feb 2025 12:12:13 +0100 Subject: [PATCH 6/8] update default value --- ot/low_rank/_factor_relaxation.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ot/low_rank/_factor_relaxation.py b/ot/low_rank/_factor_relaxation.py index 789a0579e..4bb6a8882 100644 --- a/ot/low_rank/_factor_relaxation.py +++ b/ot/low_rank/_factor_relaxation.py @@ -92,7 +92,7 @@ def solve_balanced_FRLC( r, tau, gamma, - stopThr=1e-5, + stopThr=1e-7, numItermax=1000, log=False, ): From 3168387e063cb2dbd3d3d52f6f2e6dac3f254c5a Mon Sep 17 00:00:00 2001 From: fegounna Date: Tue, 4 Feb 2025 16:49:24 +0100 Subject: [PATCH 7/8] update default value --- ot/low_rank/_factor_relaxation.py | 39 ------------------------------- 1 file changed, 39 deletions(-) diff --git a/ot/low_rank/_factor_relaxation.py b/ot/low_rank/_factor_relaxation.py index 4bb6a8882..ceafb0e6f 100644 --- a/ot/low_rank/_factor_relaxation.py +++ b/ot/low_rank/_factor_relaxation.py @@ -215,42 +215,3 @@ def solve_balanced_FRLC( return nx.dot(nx.dot(Q_new, X_new), R_new.T) # Shape (n, m) Q, R, T, X = Q_new, R_new, T_new, X_new - - -if __name__ == "__main__": - import torch - - grid_size = 4 - torch.manual_seed(42) - x_vals = torch.linspace(0, 3, grid_size) - y_vals = torch.linspace(0, 3, grid_size) - X, Y = torch.meshgrid(x_vals, y_vals, indexing="ij") - source_points = torch.stack([X.ravel(), Y.ravel()], dim=-1) # (16, 2) - a = torch.ones(len(source_points)) / len(source_points) # Uniform distribution - - # Generate Target Distribution (Gaussian Samples) - mean = torch.tensor([2.0, 2.0]) - cov = torch.tensor([[1.0, 0.5], [0.5, 1.0]]) - target_points = torch.distributions.MultivariateNormal( - mean, covariance_matrix=cov - ).sample((len(source_points),)) # (16, 2) - b = torch.ones(len(target_points)) / len(target_points) # Uniform distribution - - # Compute Cost Matrix (Squared Euclidean Distance) - C = 
torch.cdist(source_points, target_points, p=2) ** 2 - - # Solve OT problem (assuming you have PyTorch versions of these functions) - print(type(a.numpy())) - P = solve_balanced_FRLC( - a.to(torch.float64), - b.to(torch.float64), - C.to(torch.float64), - 10, - tau=1e2, - gamma=1e2, - stopThr=1e-7, - numItermax=100, - log=True, - ) - P = sinkhorn(a, b, C, reg=1) - print(torch.sum(P * C)) From 4a44da2df3eee63f7654507db675253f74ccd0b2 Mon Sep 17 00:00:00 2001 From: fegounna Date: Tue, 11 Mar 2025 01:41:13 +0100 Subject: [PATCH 8/8] fix logical bug --- ot/low_rank/_factor_relaxation.py | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/ot/low_rank/_factor_relaxation.py b/ot/low_rank/_factor_relaxation.py index ceafb0e6f..3560c5784 100644 --- a/ot/low_rank/_factor_relaxation.py +++ b/ot/low_rank/_factor_relaxation.py @@ -178,6 +178,7 @@ def solve_balanced_FRLC( a=a, b=g_Q, M=grad_Q, + c=Q, reg=1 / gamma_k, reg_m=[float("inf"), tau], method="sinkhorn_stabilized", @@ -187,6 +188,7 @@ def solve_balanced_FRLC( a=b, b=g_R, M=grad_R, + c=R, reg=1 / gamma_k, reg_m=[float("inf"), tau], method="sinkhorn_stabilized", @@ -199,8 +201,14 @@ def solve_balanced_FRLC( gamma_T = gamma / nx.max(nx.abs(grad_T)) - T_new = sinkhorn( - g_R, g_Q, grad_T, reg=1 / gamma_T, method="sinkhorn_log" + T_new = sinkhorn_unbalanced( + M=grad_T, + a=g_Q, + b=g_R, + reg=1 / gamma_T, + c=T, + reg_m=[float("inf"), float("inf")], + method="sinkhorn_stabilized", ) # Shape (r, r) X_new = nx.dot(nx.dot(nx.diag(1 / g_Q), T_new), nx.diag(1 / g_R)) # Shape (r,r)