Vectorize cost and gradient for ot.da.sinkhorn_l1l2_gl #507

Merged · 15 commits · Sep 21, 2023
1 change: 1 addition & 0 deletions RELEASES.md
@@ -4,6 +4,7 @@

#### New features
+ Tweaked `get_backend` to ignore `None` inputs (PR #525)
+ Callbacks for generalized conditional gradient in `ot.da.sinkhorn_l1l2_gl` are now vectorized to improve performance (PR #507)

#### Closed issues
- Fix line search evaluating cost outside of the interpolation range (Issue #502, PR #504)
51 changes: 29 additions & 22 deletions ot/backend.py
@@ -411,7 +411,7 @@ def power(self, a, exponents):
"""
raise NotImplementedError()

def norm(self, a):
def norm(self, a, axis=None, keepdims=False):
r"""
Computes the matrix frobenius norm.

@@ -631,7 +631,7 @@ def diag(self, a, k=0):
"""
raise NotImplementedError()

def unique(self, a):
def unique(self, a, return_inverse=False):
r"""
Finds unique elements of given tensor.

@@ -1091,8 +1091,8 @@ def sqrt(self, a):
def power(self, a, exponents):
return np.power(a, exponents)

def norm(self, a):
return np.sqrt(np.sum(np.square(a)))
def norm(self, a, axis=None, keepdims=False):
return np.linalg.norm(a, axis=axis, keepdims=keepdims)

def any(self, a):
return np.any(a)
@@ -1168,8 +1168,8 @@ def meshgrid(self, a, b):
def diag(self, a, k=0):
return np.diag(a, k)

def unique(self, a):
return np.unique(a)
def unique(self, a, return_inverse=False):
return np.unique(a, return_inverse=return_inverse)

def logsumexp(self, a, axis=None):
return special.logsumexp(a, axis=axis)
@@ -1465,8 +1465,8 @@ def sqrt(self, a):
def power(self, a, exponents):
return jnp.power(a, exponents)

def norm(self, a):
return jnp.sqrt(jnp.sum(jnp.square(a)))
def norm(self, a, axis=None, keepdims=False):
return jnp.linalg.norm(a, axis=axis, keepdims=keepdims)

def any(self, a):
return jnp.any(a)
@@ -1539,8 +1539,8 @@ def meshgrid(self, a, b):
def diag(self, a, k=0):
return jnp.diag(a, k)

def unique(self, a):
return jnp.unique(a)
def unique(self, a, return_inverse=False):
return jnp.unique(a, return_inverse=return_inverse)

def logsumexp(self, a, axis=None):
return jspecial.logsumexp(a, axis=axis)
@@ -1885,8 +1885,8 @@ def sqrt(self, a):
def power(self, a, exponents):
return torch.pow(a, exponents)

def norm(self, a):
return torch.sqrt(torch.sum(torch.square(a)))
def norm(self, a, axis=None, keepdims=False):
return torch.linalg.norm(a.double(), dim=axis, keepdims=keepdims)

def any(self, a):
return torch.any(a)
@@ -1990,8 +1990,8 @@ def meshgrid(self, a, b):
def diag(self, a, k=0):
return torch.diag(a, diagonal=k)

def unique(self, a):
return torch.unique(a)
def unique(self, a, return_inverse=False):
return torch.unique(a, return_inverse=return_inverse)

def logsumexp(self, a, axis=None):
if axis is not None:
@@ -2310,8 +2310,8 @@ def power(self, a, exponents):
def dot(self, a, b):
return cp.dot(a, b)

def norm(self, a):
return cp.sqrt(cp.sum(cp.square(a)))
def norm(self, a, axis=None, keepdims=False):
return cp.linalg.norm(a, axis=axis, keepdims=keepdims)

def any(self, a):
return cp.any(a)
@@ -2387,8 +2387,8 @@ def meshgrid(self, a, b):
def diag(self, a, k=0):
return cp.diag(a, k)

def unique(self, a):
return cp.unique(a)
def unique(self, a, return_inverse=False):
return cp.unique(a, return_inverse=return_inverse)

def logsumexp(self, a, axis=None):
# Taken from
@@ -2721,8 +2721,8 @@ def sqrt(self, a):
def power(self, a, exponents):
return tnp.power(a, exponents)

def norm(self, a):
return tf.math.reduce_euclidean_norm(a)
def norm(self, a, axis=None, keepdims=False):
return tf.math.reduce_euclidean_norm(a, axis=axis, keepdims=keepdims)

def any(self, a):
return tnp.any(a)
@@ -2794,8 +2794,15 @@ def meshgrid(self, a, b):
def diag(self, a, k=0):
return tnp.diag(a, k)

def unique(self, a):
return tf.sort(tf.unique(tf.reshape(a, [-1]))[0])
def unique(self, a, return_inverse=False):
y, idx = tf.unique(tf.reshape(a, [-1]))
sort_idx = tf.argsort(y)
y_prime = tf.gather(y, sort_idx)
if return_inverse:
inv_sort_idx = tf.math.invert_permutation(sort_idx)
return y_prime, tf.gather(inv_sort_idx, idx)
else:
return y_prime

def logsumexp(self, a, axis=None):
return tf.math.reduce_logsumexp(a, axis=axis)
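The extended `norm` (with `axis`/`keepdims`) and `unique` (with `return_inverse`) signatures above are mirrored across every backend. A minimal usage sketch, assuming the NumPy backend is instantiated directly (library code would normally obtain `nx` via `get_backend`):

```python
import numpy as np
from ot.backend import NumpyBackend

nx = NumpyBackend()
M = np.arange(12, dtype=float).reshape(3, 4)

# per-row Euclidean norms, kept as a (3, 1) column for broadcasting
row_norms = nx.norm(M, axis=1, keepdims=True)

# sorted unique values plus inverse indices that reconstruct the input
labels = np.array([2, 0, 2, 1])
values, inverse = nx.unique(labels, return_inverse=True)
assert np.all(values[inverse] == labels)
```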
30 changes: 13 additions & 17 deletions ot/da.py
@@ -148,8 +148,8 @@ def sinkhorn_lpl1_mm(a, labels_a, b, M, reg, eta=0.1, numItermax=10,


def sinkhorn_l1l2_gl(a, labels_a, b, M, reg, eta=0.1, numItermax=10,
numInnerItermax=200, stopInnerThr=1e-9, verbose=False,
log=False):
numInnerItermax=200, stopInnerThr=1e-9, eps=1e-12,
verbose=False, log=False):
r"""
Solve the entropic regularization optimal transport problem with group
lasso regularization
@@ -202,6 +202,8 @@ def sinkhorn_l1l2_gl(a, labels_a, b, M, reg, eta=0.1, numItermax=10,
Max number of iterations (inner sinkhorn solver)
stopInnerThr : float, optional
Stop threshold on error (inner sinkhorn solver) (>0)
eps : float, optional (default=1e-12)
Small value to avoid division by zero
verbose : bool, optional
Print information along iterations
log : bool, optional
@@ -235,25 +237,19 @@ def sinkhorn_l1l2_gl(a, labels_a, b, M, reg, eta=0.1, numItermax=10,
a, labels_a, b, M = list_to_array(a, labels_a, b, M)
nx = get_backend(a, labels_a, b, M)

lstlab = nx.unique(labels_a)
labels_u, labels_idx = nx.unique(labels_a, return_inverse=True)
n_labels = labels_u.shape[0]
unroll_labels_idx = nx.eye(n_labels, type_as=labels_u)[None, labels_idx]

def f(G):
res = 0
for i in range(G.shape[1]):
for lab in lstlab:
temp = G[labels_a == lab, i]
res += nx.norm(temp)
return res
G_split = nx.repeat(G.T[:, :, None], n_labels, axis=2)
return nx.sum(nx.norm(G_split * unroll_labels_idx, axis=1))

def df(G):
W = nx.zeros(G.shape, type_as=G)
for i in range(G.shape[1]):
for lab in lstlab:
temp = G[labels_a == lab, i]
n = nx.norm(temp)
if n:
W[labels_a == lab, i] = temp / n
return W
G_split = nx.repeat(G.T[:, :, None], n_labels, axis=2) * unroll_labels_idx
W = nx.norm(G_split * unroll_labels_idx, axis=1, keepdims=True)
G_norm = G_split / nx.clip(W, eps, None)
return nx.sum(G_norm, axis=2).T

return gcg(a, b, M, reg, eta, f, df, G0=None, numItermax=numItermax,
numInnerItermax=numInnerItermax, stopThr=stopInnerThr,
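The rewritten `f`/`df` replace the double loop over target columns and classes with a masked tensor of shape `(nt, ns, n_labels)`: the one-hot mask `unroll_labels_idx` zeroes out entries belonging to other classes, the group norm is taken over the source axis, and the gradient divides each group by its own norm clipped at `eps`. A self-contained NumPy sketch of the same computation (illustrative only; the library routes everything through the `nx` backend):

```python
import numpy as np

rng = np.random.RandomState(0)
ns, nt = 6, 5
G = rng.rand(ns, nt)                     # transport plan (source x target)
labels_a = rng.randint(3, size=ns)       # source class labels

labels_u, labels_idx = np.unique(labels_a, return_inverse=True)
n_labels = labels_u.shape[0]
# one-hot mask of shape (1, ns, n_labels): entry (0, i, c) is 1 iff sample i is in class c
unroll_labels_idx = np.eye(n_labels)[None, labels_idx]

# cost: sum over target columns j and classes c of ||G[labels_a == c, j]||_2
G_split = np.repeat(G.T[:, :, None], n_labels, axis=2)     # (nt, ns, n_labels)
cost = np.sum(np.linalg.norm(G_split * unroll_labels_idx, axis=1))

# gradient: each group divided by its own norm, clipped at eps to avoid 0/0
eps = 1e-12
masked = G_split * unroll_labels_idx
W = np.linalg.norm(masked, axis=1, keepdims=True)           # (nt, 1, n_labels)
grad = np.sum(masked / np.clip(W, eps, None), axis=2).T     # back to (ns, nt)

# loop-based reference for the cost (pre-PR behaviour)
cost_ref = sum(np.linalg.norm(G[labels_a == c, j])
               for j in range(nt) for c in labels_u)
assert np.allclose(cost, cost_ref)
```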
14 changes: 14 additions & 0 deletions test/test_backend.py
@@ -412,6 +412,14 @@ def test_func_backends(nx):
lst_b.append(nx.to_numpy(A))
lst_name.append('norm')

A = nx.norm(Mb, axis=1)
lst_b.append(nx.to_numpy(A))
lst_name.append('norm(M,axis=1)')

A = nx.norm(Mb, axis=1, keepdims=True)
lst_b.append(nx.to_numpy(A))
lst_name.append('norm(M,axis=1,keepdims=True)')

A = nx.any(vb > 0)
lst_b.append(nx.to_numpy(A))
lst_name.append('any')
@@ -518,6 +526,12 @@ def test_func_backends(nx):
lst_b.append(nx.to_numpy(A))
lst_name.append('unique')

A, A2 = nx.unique(nx.from_numpy(np.stack([M, M]).reshape(-1)), return_inverse=True)
lst_b.append(nx.to_numpy(A))
lst_name.append('unique(M,return_inverse=True)[0]')
lst_b.append(nx.to_numpy(A2))
lst_name.append('unique(M,return_inverse=True)[1]')

A = nx.logsumexp(Mb)
lst_b.append(nx.to_numpy(A))
lst_name.append('logsumexp')
49 changes: 49 additions & 0 deletions test/test_da.py
@@ -801,3 +801,52 @@ def test_emd_laplace_class(nx):
transp_ys = otda.inverse_transform_labels(yt)
assert_equal(transp_ys.shape[0], ys.shape[0])
assert_equal(transp_ys.shape[1], len(np.unique(nx.to_numpy(yt))))


@pytest.skip_backend("jax")
@pytest.skip_backend("tf")
def test_sinkhorn_l1l2_gl_cost_vectorized(nx):
n_samples, n_labels = 150, 3
rng = np.random.RandomState(42)
G = rng.rand(n_samples, n_samples)
labels_a = rng.randint(n_labels, size=(n_samples,))
G, labels_a = nx.from_numpy(G), nx.from_numpy(labels_a)

# previously used implementation for the cost estimator
lstlab = nx.unique(labels_a)

def f(G):
res = 0
for i in range(G.shape[1]):
for lab in lstlab:
temp = G[labels_a == lab, i]
res += nx.norm(temp)
return res

def df(G):
W = nx.zeros(G.shape, type_as=G)
for i in range(G.shape[1]):
for lab in lstlab:
temp = G[labels_a == lab, i]
n = nx.norm(temp)
if n:
W[labels_a == lab, i] = temp / n
return W

# new vectorized implementation for the cost estimator
labels_u, labels_idx = nx.unique(labels_a, return_inverse=True)
n_labels = labels_u.shape[0]
unroll_labels_idx = nx.eye(n_labels, type_as=labels_u)[None, labels_idx]

def f2(G):
G_split = nx.repeat(G.T[:, :, None], n_labels, axis=2)
return nx.sum(nx.norm(G_split * unroll_labels_idx, axis=1))

def df2(G):
G_split = nx.repeat(G.T[:, :, None], n_labels, axis=2) * unroll_labels_idx
W = nx.norm(G_split * unroll_labels_idx, axis=1, keepdims=True)
G_norm = G_split / nx.clip(W, 1e-12, None)
return nx.sum(G_norm, axis=2).T

assert np.allclose(f(G), f2(G))
assert np.allclose(df(G), df2(G))
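
For context, a minimal end-to-end sketch of the public function touched by this PR; the data and parameter values below are arbitrary, chosen only so the solver runs quickly, and `eps` is the new clipping constant:

```python
import numpy as np
import ot

rng = np.random.RandomState(0)
ns, nt = 20, 30
Xs = rng.randn(ns, 2)                 # source samples
Xt = rng.randn(nt, 2) + 1.0           # shifted target samples
ys = rng.randint(3, size=ns)          # source class labels

a, b = ot.unif(ns), ot.unif(nt)       # uniform marginals
M = ot.dist(Xs, Xt)                   # squared Euclidean cost matrix
M /= M.max()                          # rescale before entropic regularization

G = ot.da.sinkhorn_l1l2_gl(a, ys, b, M, reg=1e-1, eta=1e-1,
                           numItermax=10, eps=1e-12)
assert G.shape == (ns, nt)
```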