diff --git a/RELEASES.md b/RELEASES.md
index 4eeea9c66..4ee33917d 100644
--- a/RELEASES.md
+++ b/RELEASES.md
@@ -4,6 +4,7 @@
 
 #### New features
 + Tweaked `get_backend` to ignore `None` inputs (PR # 525)
++ Callbacks for generalized conditional gradient in `ot.da.sinkhorn_l1l2_gl` are now vectorized to improve performance (PR #507)
 
 #### Closed issues
 - Fix line search evaluating cost outside of the interpolation range (Issue #502, PR #504)
diff --git a/ot/backend.py b/ot/backend.py
index a80c5ae73..288224d7c 100644
--- a/ot/backend.py
+++ b/ot/backend.py
@@ -411,7 +411,7 @@ def power(self, a, exponents):
         """
         raise NotImplementedError()
 
-    def norm(self, a):
+    def norm(self, a, axis=None, keepdims=False):
         r"""
         Computes the matrix frobenius norm.
 
@@ -631,7 +631,7 @@ def diag(self, a, k=0):
         """
         raise NotImplementedError()
 
-    def unique(self, a):
+    def unique(self, a, return_inverse=False):
         r"""
         Finds unique elements of given tensor.
 
@@ -1091,8 +1091,8 @@ def sqrt(self, a):
     def power(self, a, exponents):
         return np.power(a, exponents)
 
-    def norm(self, a):
-        return np.sqrt(np.sum(np.square(a)))
+    def norm(self, a, axis=None, keepdims=False):
+        return np.linalg.norm(a, axis=axis, keepdims=keepdims)
 
     def any(self, a):
         return np.any(a)
@@ -1168,8 +1168,8 @@ def meshgrid(self, a, b):
     def diag(self, a, k=0):
         return np.diag(a, k)
 
-    def unique(self, a):
-        return np.unique(a)
+    def unique(self, a, return_inverse=False):
+        return np.unique(a, return_inverse=return_inverse)
 
     def logsumexp(self, a, axis=None):
         return special.logsumexp(a, axis=axis)
@@ -1465,8 +1465,8 @@ def sqrt(self, a):
     def power(self, a, exponents):
         return jnp.power(a, exponents)
 
-    def norm(self, a):
-        return jnp.sqrt(jnp.sum(jnp.square(a)))
+    def norm(self, a, axis=None, keepdims=False):
+        return jnp.linalg.norm(a, axis=axis, keepdims=keepdims)
 
     def any(self, a):
         return jnp.any(a)
@@ -1539,8 +1539,8 @@ def meshgrid(self, a, b):
     def diag(self, a, k=0):
         return jnp.diag(a, k)
 
-    def unique(self, a):
-        return jnp.unique(a)
+    def unique(self, a, return_inverse=False):
+        return jnp.unique(a, return_inverse=return_inverse)
 
     def logsumexp(self, a, axis=None):
         return jspecial.logsumexp(a, axis=axis)
@@ -1885,8 +1885,8 @@ def sqrt(self, a):
     def power(self, a, exponents):
         return torch.pow(a, exponents)
 
-    def norm(self, a):
-        return torch.sqrt(torch.sum(torch.square(a)))
+    def norm(self, a, axis=None, keepdims=False):
+        return torch.linalg.norm(a.double(), dim=axis, keepdims=keepdims)
 
     def any(self, a):
         return torch.any(a)
@@ -1990,8 +1990,8 @@ def meshgrid(self, a, b):
     def diag(self, a, k=0):
         return torch.diag(a, diagonal=k)
 
-    def unique(self, a):
-        return torch.unique(a)
+    def unique(self, a, return_inverse=False):
+        return torch.unique(a, return_inverse=return_inverse)
 
     def logsumexp(self, a, axis=None):
         if axis is not None:
@@ -2310,8 +2310,8 @@ def power(self, a, exponents):
     def dot(self, a, b):
         return cp.dot(a, b)
 
-    def norm(self, a):
-        return cp.sqrt(cp.sum(cp.square(a)))
+    def norm(self, a, axis=None, keepdims=False):
+        return cp.linalg.norm(a, axis=axis, keepdims=keepdims)
 
     def any(self, a):
         return cp.any(a)
@@ -2387,8 +2387,8 @@ def meshgrid(self, a, b):
     def diag(self, a, k=0):
         return cp.diag(a, k)
 
-    def unique(self, a):
-        return cp.unique(a)
+    def unique(self, a, return_inverse=False):
+        return cp.unique(a, return_inverse=return_inverse)
 
     def logsumexp(self, a, axis=None):
         # Taken from
@@ -2721,8 +2721,8 @@ def sqrt(self, a):
     def power(self, a, exponents):
         return tnp.power(a, exponents)
 
-    def norm(self, a):
-        return tf.math.reduce_euclidean_norm(a)
+    def norm(self, a, axis=None, keepdims=False):
+        return tf.math.reduce_euclidean_norm(a, axis=axis, keepdims=keepdims)
 
     def any(self, a):
         return tnp.any(a)
@@ -2794,8 +2794,15 @@ def meshgrid(self, a, b):
     def diag(self, a, k=0):
         return tnp.diag(a, k)
 
-    def unique(self, a):
-        return tf.sort(tf.unique(tf.reshape(a, [-1]))[0])
+    def unique(self, a, return_inverse=False):
+        y, idx = tf.unique(tf.reshape(a, [-1]))
+        sort_idx = tf.argsort(y)
+        y_prime = tf.gather(y, sort_idx)
+        if return_inverse:
+            inv_sort_idx = tf.math.invert_permutation(sort_idx)
+            return y_prime, tf.gather(inv_sort_idx, idx)
+        else:
+            return y_prime
 
     def logsumexp(self, a, axis=None):
         return tf.math.reduce_logsumexp(a, axis=axis)
diff --git a/ot/da.py b/ot/da.py
index dc6aa70a0..1df9a39af 100644
--- a/ot/da.py
+++ b/ot/da.py
@@ -148,8 +148,8 @@ def sinkhorn_lpl1_mm(a, labels_a, b, M, reg, eta=0.1, numItermax=10,
 
 
 def sinkhorn_l1l2_gl(a, labels_a, b, M, reg, eta=0.1, numItermax=10,
-                     numInnerItermax=200, stopInnerThr=1e-9, verbose=False,
-                     log=False):
+                     numInnerItermax=200, stopInnerThr=1e-9, eps=1e-12,
+                     verbose=False, log=False):
     r"""
     Solve the entropic regularization optimal transport problem with group
     lasso regularization
@@ -202,6 +202,8 @@ def sinkhorn_l1l2_gl(a, labels_a, b, M, reg, eta=0.1, numItermax=10,
         Max number of iterations (inner sinkhorn solver)
     stopInnerThr : float, optional
         Stop threshold on error (inner sinkhorn solver) (>0)
+    eps: float, optional (default=1e-12)
+        Small value to avoid division by zero
     verbose : bool, optional
         Print information along iterations
     log : bool, optional
@@ -235,25 +237,19 @@ def sinkhorn_l1l2_gl(a, labels_a, b, M, reg, eta=0.1, numItermax=10,
     a, labels_a, b, M = list_to_array(a, labels_a, b, M)
     nx = get_backend(a, labels_a, b, M)
 
-    lstlab = nx.unique(labels_a)
+    labels_u, labels_idx = nx.unique(labels_a, return_inverse=True)
+    n_labels = labels_u.shape[0]
+    unroll_labels_idx = nx.eye(n_labels, type_as=labels_u)[None, labels_idx]
 
     def f(G):
-        res = 0
-        for i in range(G.shape[1]):
-            for lab in lstlab:
-                temp = G[labels_a == lab, i]
-                res += nx.norm(temp)
-        return res
+        G_split = nx.repeat(G.T[:, :, None], n_labels, axis=2)
+        return nx.sum(nx.norm(G_split * unroll_labels_idx, axis=1))
 
     def df(G):
-        W = nx.zeros(G.shape, type_as=G)
-        for i in range(G.shape[1]):
-            for lab in lstlab:
-                temp = G[labels_a == lab, i]
-                n = nx.norm(temp)
-                if n:
-                    W[labels_a == lab, i] = temp / n
-        return W
+        G_split = nx.repeat(G.T[:, :, None], n_labels, axis=2) * unroll_labels_idx
+        W = nx.norm(G_split * unroll_labels_idx, axis=1, keepdims=True)
+        G_norm = G_split / nx.clip(W, eps, None)
+        return nx.sum(G_norm, axis=2).T
 
     return gcg(a, b, M, reg, eta, f, df, G0=None, numItermax=numItermax,
                numInnerItermax=numInnerItermax, stopThr=stopInnerThr,
diff --git a/test/test_backend.py b/test/test_backend.py
index bfca6139d..8ab861078 100644
--- a/test/test_backend.py
+++ b/test/test_backend.py
@@ -412,6 +412,14 @@ def test_func_backends(nx):
     lst_b.append(nx.to_numpy(A))
     lst_name.append('norm')
 
+    A = nx.norm(Mb, axis=1)
+    lst_b.append(nx.to_numpy(A))
+    lst_name.append('norm(M,axis=1)')
+
+    A = nx.norm(Mb, axis=1, keepdims=True)
+    lst_b.append(nx.to_numpy(A))
+    lst_name.append('norm(M,axis=1,keepdims=True)')
+
     A = nx.any(vb > 0)
     lst_b.append(nx.to_numpy(A))
     lst_name.append('any')
@@ -518,6 +526,12 @@ def test_func_backends(nx):
     lst_b.append(nx.to_numpy(A))
     lst_name.append('unique')
 
+    A, A2 = nx.unique(nx.from_numpy(np.stack([M, M]).reshape(-1)), return_inverse=True)
+    lst_b.append(nx.to_numpy(A))
+    lst_name.append('unique(M,return_inverse=True)[0]')
+    lst_b.append(nx.to_numpy(A2))
+    lst_name.append('unique(M,return_inverse=True)[1]')
+
     A = nx.logsumexp(Mb)
     lst_b.append(nx.to_numpy(A))
     lst_name.append('logsumexp')
diff --git a/test/test_da.py b/test/test_da.py
index 0a4b10c4a..49df16d00 100644
--- a/test/test_da.py
+++ b/test/test_da.py
@@ -801,3 +801,52 @@ def test_emd_laplace_class(nx):
     transp_ys = otda.inverse_transform_labels(yt)
     assert_equal(transp_ys.shape[0], ys.shape[0])
     assert_equal(transp_ys.shape[1], len(np.unique(nx.to_numpy(yt))))
+
+
+@pytest.skip_backend("jax")
+@pytest.skip_backend("tf")
+def test_sinkhorn_l1l2_gl_cost_vectorized(nx):
+    n_samples, n_labels = 150, 3
+    rng = np.random.RandomState(42)
+    G = rng.rand(n_samples, n_samples)
+    labels_a = rng.randint(n_labels, size=(n_samples,))
+    G, labels_a = nx.from_numpy(G), nx.from_numpy(labels_a)
+
+    # previously used implementation for the cost estimator
+    lstlab = nx.unique(labels_a)
+
+    def f(G):
+        res = 0
+        for i in range(G.shape[1]):
+            for lab in lstlab:
+                temp = G[labels_a == lab, i]
+                res += nx.norm(temp)
+        return res
+
+    def df(G):
+        W = nx.zeros(G.shape, type_as=G)
+        for i in range(G.shape[1]):
+            for lab in lstlab:
+                temp = G[labels_a == lab, i]
+                n = nx.norm(temp)
+                if n:
+                    W[labels_a == lab, i] = temp / n
+        return W
+
+    # new vectorized implementation for the cost estimator
+    labels_u, labels_idx = nx.unique(labels_a, return_inverse=True)
+    n_labels = labels_u.shape[0]
+    unroll_labels_idx = nx.eye(n_labels, type_as=labels_u)[None, labels_idx]
+
+    def f2(G):
+        G_split = nx.repeat(G.T[:, :, None], n_labels, axis=2)
+        return nx.sum(nx.norm(G_split * unroll_labels_idx, axis=1))
+
+    def df2(G):
+        G_split = nx.repeat(G.T[:, :, None], n_labels, axis=2) * unroll_labels_idx
+        W = nx.norm(G_split * unroll_labels_idx, axis=1, keepdims=True)
+        G_norm = G_split / nx.clip(W, 1e-12, None)
+        return nx.sum(G_norm, axis=2).T
+
+    assert np.allclose(f(G), f2(G))
+    assert np.allclose(df(G), df2(G))
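
Note (not part of the patch): a minimal usage sketch of the updated solver, included only for reviewers who want to try the change locally. It relies on the `sinkhorn_l1l2_gl` signature shown in the diff above plus the standard POT helpers `ot.unif` and `ot.dist`; the data sizes, labels, and regularization values are illustrative assumptions, not values taken from the PR.

    import numpy as np
    import ot

    rng = np.random.RandomState(0)
    ns, nt = 50, 40
    Xs = rng.randn(ns, 2)                 # source samples
    ys = rng.randint(3, size=ns)          # source labels feeding the group-lasso term
    Xt = rng.randn(nt, 2) + 1.0           # target samples

    a, b = ot.unif(ns), ot.unif(nt)       # uniform sample weights
    M = ot.dist(Xs, Xt)                   # squared Euclidean cost matrix

    # `eps` is the new keyword added by this patch: it only guards the division
    # by the per-class norms inside the vectorized gradient of the regularizer.
    G = ot.da.sinkhorn_l1l2_gl(a, ys, b, M, reg=1e-1, eta=1e-1, eps=1e-12)
    print(G.shape)                        # coupling matrix of shape (ns, nt)

The backend (NumPy here) is inferred from the inputs via `get_backend`, so the same call works on other backends once the arrays are converted with `nx.from_numpy`, as done in the new test above.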