From f4ca25a68a206a00ffe9de6ad2cb8ae9d2b2878b Mon Sep 17 00:00:00 2001
From: Oleksii Kachaiev
Date: Sun, 20 Aug 2023 18:28:21 +0200
Subject: [PATCH 01/10] Vectorize per-label computations for sinkhorn l1l2

---
 ot/backend.py   | 46 ++++++++++++++++++++++++----------------------
 ot/da.py        | 14 ++++++--------
 test/test_da.py | 27 +++++++++++++++++++++++++++
 3 files changed, 57 insertions(+), 30 deletions(-)

diff --git a/ot/backend.py b/ot/backend.py
index 7b2fe875f..3694ba350 100644
--- a/ot/backend.py
+++ b/ot/backend.py
@@ -407,7 +407,7 @@ def power(self, a, exponents):
         """
         raise NotImplementedError()
 
-    def norm(self, a):
+    def norm(self, a, axis=None):
         r"""
         Computes the matrix frobenius norm.
 
@@ -627,7 +627,7 @@ def diag(self, a, k=0):
         """
         raise NotImplementedError()
 
-    def unique(self, a):
+    def unique(self, a, return_inverse=False):
         r"""
         Finds unique elements of given tensor.
 
@@ -1087,8 +1087,8 @@ def sqrt(self, a):
     def power(self, a, exponents):
         return np.power(a, exponents)
 
-    def norm(self, a):
-        return np.sqrt(np.sum(np.square(a)))
+    def norm(self, a, axis=None):
+        return np.linalg.norm(a, axis=axis)
 
     def any(self, a):
         return np.any(a)
@@ -1164,8 +1164,8 @@ def meshgrid(self, a, b):
     def diag(self, a, k=0):
         return np.diag(a, k)
 
-    def unique(self, a):
-        return np.unique(a)
+    def unique(self, a, return_inverse=False):
+        return np.unique(a, return_inverse=return_inverse)
 
     def logsumexp(self, a, axis=None):
         return special.logsumexp(a, axis=axis)
@@ -1461,8 +1461,8 @@ def sqrt(self, a):
     def power(self, a, exponents):
         return jnp.power(a, exponents)
 
-    def norm(self, a):
-        return jnp.sqrt(jnp.sum(jnp.square(a)))
+    def norm(self, a, axis=None):
+        return jnp.linalg.norm(a, axis=axis)
 
     def any(self, a):
         return jnp.any(a)
@@ -1535,8 +1535,8 @@ def meshgrid(self, a, b):
     def diag(self, a, k=0):
         return jnp.diag(a, k)
 
-    def unique(self, a):
-        return jnp.unique(a)
+    def unique(self, a, return_inverse=False):
+        return jnp.unique(a, return_inverse=return_inverse)
 
     def logsumexp(self, a, axis=None):
         return jspecial.logsumexp(a, axis=axis)
@@ -1881,8 +1881,8 @@ def sqrt(self, a):
     def power(self, a, exponents):
         return torch.pow(a, exponents)
 
-    def norm(self, a):
-        return torch.sqrt(torch.sum(torch.square(a)))
+    def norm(self, a, axis=None):
+        return torch.linalg.norm(a, dim=axis)
 
     def any(self, a):
         return torch.any(a)
@@ -1986,8 +1986,8 @@ def meshgrid(self, a, b):
     def diag(self, a, k=0):
         return torch.diag(a, diagonal=k)
 
-    def unique(self, a):
-        return torch.unique(a)
+    def unique(self, a, return_inverse=False):
+        return torch.unique(a, return_inverse=return_inverse)
 
     def logsumexp(self, a, axis=None):
         if axis is not None:
@@ -2306,8 +2306,8 @@ def power(self, a, exponents):
     def dot(self, a, b):
         return cp.dot(a, b)
 
-    def norm(self, a):
-        return cp.sqrt(cp.sum(cp.square(a)))
+    def norm(self, a, axis=None):
+        return cp.linalg.norm(a, axis=axis)
 
     def any(self, a):
         return cp.any(a)
@@ -2383,8 +2383,8 @@ def meshgrid(self, a, b):
     def diag(self, a, k=0):
         return cp.diag(a, k)
 
-    def unique(self, a):
-        return cp.unique(a)
+    def unique(self, a, return_inverse=False):
+        return cp.unique(a, return_inverse=return_inverse)
 
     def logsumexp(self, a, axis=None):
         # Taken from
@@ -2717,8 +2717,8 @@ def sqrt(self, a):
     def power(self, a, exponents):
         return tnp.power(a, exponents)
 
-    def norm(self, a):
-        return tf.math.reduce_euclidean_norm(a)
+    def norm(self, a, axis=None):
+        return tf.math.reduce_euclidean_norm(a, axis=axis)
 
     def any(self, a):
         return tnp.any(a)
@@ -2790,8 +2790,10 @@ def meshgrid(self, a, b):
     def diag(self, a, k=0):
         return tnp.diag(a, k)
 
-    def unique(self, a):
-        return tf.sort(tf.unique(tf.reshape(a, [-1]))[0])
+    def unique(self, a, return_inverse=False):
+        y, idx = tf.unique(tf.reshape(a, [-1]))
+        sort_idx = tf.argsort(y)
+        return y[sort_idx] if not return_inverse else (y[sort_idx], idx[sort_idx])
 
     def logsumexp(self, a, axis=None):
         return tf.math.reduce_logsumexp(a, axis=axis)
diff --git a/ot/da.py b/ot/da.py
index 5d55f53c7..91d2f6bf7 100644
--- a/ot/da.py
+++ b/ot/da.py
@@ -235,16 +235,14 @@ def sinkhorn_l1l2_gl(a, labels_a, b, M, reg, eta=0.1, numItermax=10,
     a, labels_a, b, M = list_to_array(a, labels_a, b, M)
     nx = get_backend(a, labels_a, b, M)
 
-    lstlab = nx.unique(labels_a)
-
+    labels_u, labels_idx = nx.unique(labels_a, return_inverse=True)
+    n_labels = labels_u.shape[0]
+    unroll_labels_idx = nx.eye(n_labels, type_as=labels_u)[None, labels_idx]
     def f(G):
-        res = 0
-        for i in range(G.shape[1]):
-            for lab in lstlab:
-                temp = G[labels_a == lab, i]
-                res += nx.norm(temp)
-        return res
+        G_split = nx.repeat(G.T[:, :, None], n_labels, axis=2)
+        return nx.norm(G_split * unroll_labels_idx, axis=1).sum()
 
+    lstlab = nx.unique(labels_a)
     def df(G):
         W = nx.zeros(G.shape, type_as=G)
         for i in range(G.shape[1]):
diff --git a/test/test_da.py b/test/test_da.py
index c95d48850..347f93978 100644
--- a/test/test_da.py
+++ b/test/test_da.py
@@ -800,3 +800,30 @@ def test_emd_laplace_class(nx):
     transp_ys = otda.inverse_transform_labels(yt)
     assert_equal(transp_ys.shape[0], ys.shape[0])
     assert_equal(transp_ys.shape[1], len(np.unique(nx.to_numpy(yt))))
+
+
+def test_sinkhorn_l1l2_gl_cost_vectorized():
+    n_samples, n_labels = 150, 3
+    rng = np.random.RandomState(42)
+    G = rng.rand(n_samples, n_samples)
+    labels_a = rng.randint(n_labels, size=(n_samples,))
+
+    # previously used implementation for the cost estimator
+    lstlab = np.unique(labels_a)
+    def f(G):
+        res = 0
+        for i in range(G.shape[1]):
+            for lab in lstlab:
+                temp = G[labels_a == lab, i]
+                res += np.linalg.norm(temp)
+        return res
+
+    # new vectorized implementation for the cost estimator
+    lstlab, lstlab_idx = np.unique(labels_a, return_inverse=True)
+    n_samples = lstlab.shape[0]
+    midx = np.eye(n_samples, dtype='int32')[None, lstlab_idx]
+    def f2(G):
+        G_split = np.repeat(G.T[:, :, None], n_samples, axis=2)
+        return np.linalg.norm(G_split * midx, axis=1).sum()
+
+    assert np.allclose(f(G), f2(G))
\ No newline at end of file
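
The key device in the patch above is the one-hot "unroll" mask built by `nx.eye(n_labels, type_as=labels_u)[None, labels_idx]`: it broadcasts each column of the coupling `G` into one slice per label, so a single batched norm replaces the per-column, per-label Python loop. A minimal NumPy sketch of the trick (toy sizes chosen for illustration only):

    import numpy as np

    labels_a = np.array([0, 1, 0, 2])                    # ns = 4 source samples
    labels_u, labels_idx = np.unique(labels_a, return_inverse=True)
    n_labels = labels_u.shape[0]                         # 3 distinct labels
    # (1, ns, n_labels) one-hot mask: [0, j, l] == 1 iff labels_a[j] == l
    unroll_labels_idx = np.eye(n_labels)[None, labels_idx]

    G = np.random.RandomState(0).rand(4, 5)              # toy coupling, (ns, nt)
    # (nt, ns, n_labels): each column of G replicated once per label
    G_split = np.repeat(G.T[:, :, None], n_labels, axis=2)
    # norm over the sample axis gives ||G[labels_a == l, i]|| for every (i, l)
    cost = np.linalg.norm(G_split * unroll_labels_idx, axis=1).sum()

    # matches the old double loop
    ref = sum(np.linalg.norm(G[labels_a == lab, i])
              for i in range(G.shape[1]) for lab in labels_u)
    assert np.isclose(cost, ref)
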
From dcb19c66878a530b4d975413be81360b64cfb16d Mon Sep 17 00:00:00 2001
From: Oleksii Kachaiev
Date: Sun, 20 Aug 2023 18:37:09 +0200
Subject: [PATCH 02/10] Fix flake

---
 ot/da.py        | 2 ++
 test/test_da.py | 4 +++-
 2 files changed, 5 insertions(+), 1 deletion(-)

diff --git a/ot/da.py b/ot/da.py
index 91d2f6bf7..f92fc535b 100644
--- a/ot/da.py
+++ b/ot/da.py
@@ -238,11 +238,13 @@ def sinkhorn_l1l2_gl(a, labels_a, b, M, reg, eta=0.1, numItermax=10,
     labels_u, labels_idx = nx.unique(labels_a, return_inverse=True)
     n_labels = labels_u.shape[0]
     unroll_labels_idx = nx.eye(n_labels, type_as=labels_u)[None, labels_idx]
+
     def f(G):
         G_split = nx.repeat(G.T[:, :, None], n_labels, axis=2)
         return nx.norm(G_split * unroll_labels_idx, axis=1).sum()
 
     lstlab = nx.unique(labels_a)
+
     def df(G):
         W = nx.zeros(G.shape, type_as=G)
         for i in range(G.shape[1]):
diff --git a/test/test_da.py b/test/test_da.py
index 347f93978..ad4c661b1 100644
--- a/test/test_da.py
+++ b/test/test_da.py
@@ -810,6 +810,7 @@ def test_sinkhorn_l1l2_gl_cost_vectorized():
 
     # previously used implementation for the cost estimator
     lstlab = np.unique(labels_a)
+
     def f(G):
         res = 0
         for i in range(G.shape[1]):
@@ -822,8 +823,9 @@ def f(G):
     lstlab, lstlab_idx = np.unique(labels_a, return_inverse=True)
     n_samples = lstlab.shape[0]
     midx = np.eye(n_samples, dtype='int32')[None, lstlab_idx]
+
     def f2(G):
         G_split = np.repeat(G.T[:, :, None], n_samples, axis=2)
         return np.linalg.norm(G_split * midx, axis=1).sum()
 
-    assert np.allclose(f(G), f2(G))
\ No newline at end of file
+    assert np.allclose(f(G), f2(G))

From 6dd98770709a1e33ef4ae5309800b1a11f6097e0 Mon Sep 17 00:00:00 2001
From: Oleksii Kachaiev
Date: Sun, 20 Aug 2023 18:37:34 +0200
Subject: [PATCH 03/10] Convert to double before computing torch norm

---
 ot/backend.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/ot/backend.py b/ot/backend.py
index 3694ba350..0a583c45a 100644
--- a/ot/backend.py
+++ b/ot/backend.py
@@ -1882,7 +1882,7 @@ def power(self, a, exponents):
         return torch.pow(a, exponents)
 
     def norm(self, a, axis=None):
-        return torch.linalg.norm(a, dim=axis)
+        return torch.linalg.norm(a.double(), dim=axis)
 
     def any(self, a):
         return torch.any(a)
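
The `a.double()` conversion above is presumably there because accumulating the sum of squares in float32 can drift from the float64 reference enough to break tight cross-backend test tolerances; computing the norm in float64 sidesteps that. A rough sketch of the effect (the size and the visibility of the gap are illustrative, not guaranteed):

    import torch

    a = torch.rand(1_000_000)                 # float32 by default
    n32 = torch.linalg.norm(a)                # accumulates in float32
    n64 = torch.linalg.norm(a.double())       # accumulates in float64
    print((n32.double() - n64).abs().item())  # small, but typically nonzero
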
From eb0ca02048bf119b5a79a69c1b2af52d04e4f0e8 Mon Sep 17 00:00:00 2001
From: Oleksii Kachaiev
Date: Sun, 20 Aug 2023 21:45:29 +0200
Subject: [PATCH 04/10] Vectorize gradient for sinkhorn l1l2

---
 ot/backend.py   | 22 +++++++++++-----------
 ot/da.py        | 22 +++++++++-------------
 test/test_da.py | 34 ++++++++++++++++++++++++++--------
 3 files changed, 46 insertions(+), 32 deletions(-)

diff --git a/ot/backend.py b/ot/backend.py
index 0a583c45a..5f18ee1ca 100644
--- a/ot/backend.py
+++ b/ot/backend.py
@@ -407,7 +407,7 @@ def power(self, a, exponents):
         """
         raise NotImplementedError()
 
-    def norm(self, a, axis=None):
+    def norm(self, a, axis=None, keepdims=False):
         r"""
         Computes the matrix frobenius norm.
 
@@ -1087,8 +1087,8 @@ def sqrt(self, a):
     def power(self, a, exponents):
         return np.power(a, exponents)
 
-    def norm(self, a, axis=None):
-        return np.linalg.norm(a, axis=axis)
+    def norm(self, a, axis=None, keepdims=False):
+        return np.linalg.norm(a, axis=axis, keepdims=keepdims)
 
     def any(self, a):
         return np.any(a)
@@ -1461,8 +1461,8 @@ def sqrt(self, a):
     def power(self, a, exponents):
         return jnp.power(a, exponents)
 
-    def norm(self, a, axis=None):
-        return jnp.linalg.norm(a, axis=axis)
+    def norm(self, a, axis=None, keepdims=False):
+        return jnp.linalg.norm(a, axis=axis, keepdims=keepdims)
 
     def any(self, a):
         return jnp.any(a)
@@ -1881,8 +1881,8 @@ def sqrt(self, a):
     def power(self, a, exponents):
         return torch.pow(a, exponents)
 
-    def norm(self, a, axis=None):
-        return torch.linalg.norm(a.double(), dim=axis)
+    def norm(self, a, axis=None, keepdims=False):
+        return torch.linalg.norm(a.double(), dim=axis, keepdims=keepdims)
 
     def any(self, a):
         return torch.any(a)
@@ -2306,8 +2306,8 @@ def power(self, a, exponents):
     def dot(self, a, b):
         return cp.dot(a, b)
 
-    def norm(self, a, axis=None):
-        return cp.linalg.norm(a, axis=axis)
+    def norm(self, a, axis=None, keepdims=False):
+        return cp.linalg.norm(a, axis=axis, keepdims=keepdims)
 
     def any(self, a):
         return cp.any(a)
@@ -2717,8 +2717,8 @@ def sqrt(self, a):
     def power(self, a, exponents):
         return tnp.power(a, exponents)
 
-    def norm(self, a, axis=None):
-        return tf.math.reduce_euclidean_norm(a, axis=axis)
+    def norm(self, a, axis=None, keepdims=False):
+        return tf.math.reduce_euclidean_norm(a, axis=axis, keepdims=keepdims)
 
     def any(self, a):
         return tnp.any(a)
diff --git a/ot/da.py b/ot/da.py
index f92fc535b..e1668aceb 100644
--- a/ot/da.py
+++ b/ot/da.py
@@ -148,8 +148,8 @@ def sinkhorn_lpl1_mm(a, labels_a, b, M, reg, eta=0.1, numItermax=10,
 
 
 def sinkhorn_l1l2_gl(a, labels_a, b, M, reg, eta=0.1, numItermax=10,
-                     numInnerItermax=200, stopInnerThr=1e-9, verbose=False,
-                     log=False):
+                     numInnerItermax=200, stopInnerThr=1e-9, eps=1e-12,
+                     verbose=False, log=False):
     r"""
     Solve the entropic regularization optimal transport problem with group
     lasso regularization
@@ -202,6 +202,8 @@ def sinkhorn_l1l2_gl(a, labels_a, b, M, reg, eta=0.1, numItermax=10,
         Max number of iterations (inner sinkhorn solver)
     stopInnerThr : float, optional
         Stop threshold on error (inner sinkhorn solver) (>0)
+    eps: float, optional (default=1e-12)
+        Small value to avoid division by zero
     verbose : bool, optional
         Print information along iterations
     log : bool, optional
@@ -241,19 +243,13 @@ def sinkhorn_l1l2_gl(a, labels_a, b, M, reg, eta=0.1, numItermax=10,
 
     def f(G):
         G_split = nx.repeat(G.T[:, :, None], n_labels, axis=2)
-        return nx.norm(G_split * unroll_labels_idx, axis=1).sum()
-
-    lstlab = nx.unique(labels_a)
+        return nx.sum(nx.norm(G_split * unroll_labels_idx, axis=1))
 
     def df(G):
-        W = nx.zeros(G.shape, type_as=G)
-        for i in range(G.shape[1]):
-            for lab in lstlab:
-                temp = G[labels_a == lab, i]
-                n = nx.norm(temp)
-                if n:
-                    W[labels_a == lab, i] = temp / n
-        return W
+        G_split = nx.repeat(G.T[:, :, None], n_labels, axis=2) * unroll_labels_idx
+        W = nx.norm(G_split * unroll_labels_idx, axis=1, keepdims=True)
+        G_norm = G_split / nx.clip(W, eps, None)
+        return nx.sum(G_norm, axis=2).T
 
     return gcg(a, b, M, reg, eta, f, df, G0=None, numItermax=numItermax,
                numInnerItermax=numInnerItermax, stopThr=stopInnerThr,
diff --git a/test/test_da.py b/test/test_da.py
index ad4c661b1..7c067ae64 100644
--- a/test/test_da.py
+++ b/test/test_da.py
@@ -802,30 +802,48 @@ def test_emd_laplace_class(nx):
     assert_equal(transp_ys.shape[1], len(np.unique(nx.to_numpy(yt))))
 
 
-def test_sinkhorn_l1l2_gl_cost_vectorized():
+def test_sinkhorn_l1l2_gl_cost_vectorized(nx):
     n_samples, n_labels = 150, 3
     rng = np.random.RandomState(42)
     G = rng.rand(n_samples, n_samples)
     labels_a = rng.randint(n_labels, size=(n_samples,))
+    G, labels_a = nx.from_numpy(G), nx.from_numpy(labels_a)
 
     # previously used implementation for the cost estimator
-    lstlab = np.unique(labels_a)
+    lstlab = nx.unique(labels_a)
 
     def f(G):
         res = 0
         for i in range(G.shape[1]):
             for lab in lstlab:
                 temp = G[labels_a == lab, i]
-                res += np.linalg.norm(temp)
+                res += nx.norm(temp)
         return res
 
+    def df(G):
+        W = nx.zeros(G.shape, type_as=G)
+        for i in range(G.shape[1]):
+            for lab in lstlab:
+                temp = G[labels_a == lab, i]
+                n = nx.norm(temp)
+                if n:
+                    W[labels_a == lab, i] = temp / n
+        return W
+
     # new vectorized implementation for the cost estimator
-    lstlab, lstlab_idx = np.unique(labels_a, return_inverse=True)
-    n_samples = lstlab.shape[0]
-    midx = np.eye(n_samples, dtype='int32')[None, lstlab_idx]
+    labels_u, labels_idx = nx.unique(labels_a, return_inverse=True)
+    n_labels = labels_u.shape[0]
+    unroll_labels_idx = nx.eye(n_labels, type_as=labels_u)[None, labels_idx]
 
     def f2(G):
-        G_split = np.repeat(G.T[:, :, None], n_samples, axis=2)
-        return np.linalg.norm(G_split * midx, axis=1).sum()
+        G_split = nx.repeat(G.T[:, :, None], n_labels, axis=2)
+        return nx.sum(nx.norm(G_split * unroll_labels_idx, axis=1))
+
+    def df2(G):
+        G_split = nx.repeat(G.T[:, :, None], n_labels, axis=2) * unroll_labels_idx
+        W = nx.norm(G_split * unroll_labels_idx, axis=1, keepdims=True)
+        G_norm = G_split / nx.clip(W, 1e-12, None)
+        return nx.sum(G_norm, axis=2).T
 
     assert np.allclose(f(G), f2(G))
+    assert np.allclose(df(G), df2(G))
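
The vectorized `df` above relies on the block-wise identity d||g||/dg = g / ||g|| and swaps the old `if n:` guard for a clip on the norm: when a (column, label) block is all zeros, the numerator is zero too, so dividing by the clipped norm still yields zeros with no branch. A one-block sketch (the eps value mirrors the patch default):

    import numpy as np

    block = np.zeros(3)                       # a (column, label) block, norm 0
    n = np.linalg.norm(block)                 # old code: `if n:` skips division
    grad = block / np.clip(n, 1e-12, None)    # new code: 0 / 1e-12 == 0
    assert np.allclose(grad, 0) and not np.any(np.isnan(grad))
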
From 5b3d116fe068487594e60470d8a9294e2234625f Mon Sep 17 00:00:00 2001
From: Oleksii Kachaiev
Date: Mon, 21 Aug 2023 15:34:00 +0200
Subject: [PATCH 05/10] Skip TF and JAX backends (as it is done for other DA tests)

---
 test/test_da.py | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/test/test_da.py b/test/test_da.py
index 7c067ae64..58a58c17e 100644
--- a/test/test_da.py
+++ b/test/test_da.py
@@ -802,6 +802,8 @@ def test_emd_laplace_class(nx):
     assert_equal(transp_ys.shape[1], len(np.unique(nx.to_numpy(yt))))
 
 
+@pytest.skip_backend("jax")
+@pytest.skip_backend("tf")
 def test_sinkhorn_l1l2_gl_cost_vectorized(nx):
     n_samples, n_labels = 150, 3
     rng = np.random.RandomState(42)

From c249be060ff2cbcfe14dc57da3c1e3b2731e5f76 Mon Sep 17 00:00:00 2001
From: Oleksii Kachaiev
Date: Mon, 21 Aug 2023 15:59:28 +0200
Subject: [PATCH 06/10] Additional tests for norm() and unique()

---
 test/test_backend.py | 14 ++++++++++++++
 1 file changed, 14 insertions(+)

diff --git a/test/test_backend.py b/test/test_backend.py
index f0571471c..b33f935c3 100644
--- a/test/test_backend.py
+++ b/test/test_backend.py
@@ -411,6 +411,14 @@ def test_func_backends(nx):
     lst_b.append(nx.to_numpy(A))
     lst_name.append('norm')
 
+    A = nx.norm(Mb, axis=1)
+    lst_b.append(nx.to_numpy(A))
+    lst_name.append('norm(M,axis=1)')
+
+    A = nx.norm(Mb, axis=1, keepdims=True)
+    lst_b.append(nx.to_numpy(A))
+    lst_name.append('norm(M,axis=1,keepdims=True)')
+
     A = nx.any(vb > 0)
     lst_b.append(nx.to_numpy(A))
     lst_name.append('any')
@@ -517,6 +525,12 @@ def test_func_backends(nx):
     lst_b.append(nx.to_numpy(A))
     lst_name.append('unique')
 
+    A, A2 = nx.unique(nx.from_numpy(np.stack([M, M]).reshape(-1)), return_inverse=True)
+    lst_b.append(nx.to_numpy(A))
+    lst_name.append('unique(M,return_inverse=True)[0]')
+    lst_b.append(nx.to_numpy(A2))
+    lst_name.append('unique(M,return_inverse=True)[1]')
+
     A = nx.logsumexp(Mb)
     lst_b.append(nx.to_numpy(A))
     lst_name.append('logsumexp')
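
The new backend tests pin down the contract that `unique(a, return_inverse=True)` is expected to satisfy across all backends, namely NumPy's: values come back sorted, and the inverse indices rebuild the input. The TF fixes that follow chase exactly this contract:

    import numpy as np

    a = np.array([2, 0, 2, 1])
    u, inv = np.unique(a, return_inverse=True)
    assert (u == np.array([0, 1, 2])).all()   # unique values, sorted
    assert (u[inv] == a).all()                # inverse indices rebuild input
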
From 5ac5553046da70e43f01479f83a6486834e98741 Mon Sep 17 00:00:00 2001
From: Oleksii Kachaiev
Date: Mon, 21 Aug 2023 16:30:40 +0200
Subject: [PATCH 07/10] TF implementation of unique should inverse sort permutation for indices

---
 ot/backend.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/ot/backend.py b/ot/backend.py
index 5f18ee1ca..09fd7341c 100644
--- a/ot/backend.py
+++ b/ot/backend.py
@@ -2793,7 +2793,7 @@ def diag(self, a, k=0):
     def unique(self, a, return_inverse=False):
         y, idx = tf.unique(tf.reshape(a, [-1]))
         sort_idx = tf.argsort(y)
-        return y[sort_idx] if not return_inverse else (y[sort_idx], idx[sort_idx])
+        return y[sort_idx] if not return_inverse else (y[sort_idx], sort_idx[idx])
 
     def logsumexp(self, a, axis=None):
         return tf.math.reduce_logsumexp(a, axis=axis)

From ae84df5fcf5d03843b011f7446b7f4377bebf820 Mon Sep 17 00:00:00 2001
From: Oleksii Kachaiev
Date: Mon, 21 Aug 2023 17:20:46 +0200
Subject: [PATCH 08/10] Fix sort indices for unique(), switched to tf.gather

---
 ot/backend.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/ot/backend.py b/ot/backend.py
index 09fd7341c..c1da39b70 100644
--- a/ot/backend.py
+++ b/ot/backend.py
@@ -2793,7 +2793,8 @@ def diag(self, a, k=0):
     def unique(self, a, return_inverse=False):
         y, idx = tf.unique(tf.reshape(a, [-1]))
         sort_idx = tf.argsort(y)
-        return y[sort_idx] if not return_inverse else (y[sort_idx], sort_idx[idx])
+        y_prime = tf.gather(y, sort_idx)
+        return y_prime if not return_inverse else (y_prime, tf.gather(y, idx))
 
     def logsumexp(self, a, axis=None):
         return tf.math.reduce_logsumexp(a, axis=axis)

From ba688e4d78e7babf52073681276a380f45a50b01 Mon Sep 17 00:00:00 2001
From: Oleksii Kachaiev
Date: Mon, 21 Aug 2023 17:56:01 +0200
Subject: [PATCH 09/10] tf.math.invert_permutation

---
 ot/backend.py | 6 +++++-
 1 file changed, 5 insertions(+), 1 deletion(-)

diff --git a/ot/backend.py b/ot/backend.py
index c1da39b70..359903167 100644
--- a/ot/backend.py
+++ b/ot/backend.py
@@ -2794,7 +2794,11 @@ def unique(self, a, return_inverse=False):
         y, idx = tf.unique(tf.reshape(a, [-1]))
         sort_idx = tf.argsort(y)
         y_prime = tf.gather(y, sort_idx)
-        return y_prime if not return_inverse else (y_prime, tf.gather(y, idx))
+        if return_inverse:
+            inv_sort_idx = tf.math.invert_permutation(sort_idx)
+            return y_prime, tf.gather(inv_sort_idx, idx)
+        else:
+            return y_prime
 
     def logsumexp(self, a, axis=None):
         return tf.math.reduce_logsumexp(a, axis=axis)
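
A NumPy stand-in for why this final version works where the earlier attempts did not; `tf.unique` returns values in first-appearance order (not sorted), so the inverse indices have to be routed through the inverted sort permutation:

    import numpy as np

    a = np.array([2, 0, 2, 1])
    y = np.array([2, 0, 1])          # tf.unique values: first-appearance order
    idx = np.array([0, 1, 0, 2])     # tf.unique inverse: a == y[idx]

    sort_idx = np.argsort(y)         # [1, 2, 0], the permutation sorting y
    y_prime = y[sort_idx]            # [0, 1, 2], sorted unique values
    # idx[sort_idx] (PATCH 01) and sort_idx[idx] (PATCH 07) index the wrong
    # array; tf.gather(y, idx) (PATCH 08) returns values, not indices.
    inv_sort_idx = np.argsort(sort_idx)       # == tf.math.invert_permutation
    assert (y_prime[inv_sort_idx[idx]] == a).all()
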
From 76288e553887b8e22f4717d590329c17bcd08a91 Mon Sep 17 00:00:00 2001
From: Oleksii Kachaiev
Date: Wed, 20 Sep 2023 16:30:31 +0200
Subject: [PATCH 10/10] Mention changes in changelog

---
 RELEASES.md | 6 ++++++
 1 file changed, 6 insertions(+)

diff --git a/RELEASES.md b/RELEASES.md
index d0209e233..639ae4d1e 100644
--- a/RELEASES.md
+++ b/RELEASES.md
@@ -1,5 +1,11 @@
 # Releases
 
+## 0.9.2
+
+#### New features
+- Callbacks for generalized conditional gradient in `ot.da.sinkhorn_l1l2_gl` are now vectorized to improve performance (PR #507)
+
+
 ## 0.9.1
 
 *August 2023*