
Commit 6f35804

[MRG] Add implicit Sinkhorn gradients (#605)
* add detach function to backend
* debug function
* better detach
* new implementation
* add test for gradient
* better default
* update documentation
1 parent c84ef33 commit 6f35804
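
The new `grad` option is easiest to see end to end before reading the diffs below. A minimal sketch based on the docstring examples added in this commit; the problem sizes and regularization value are illustrative, and a PyTorch install is assumed:

import numpy as np
import torch
import ot

# toy problem (sizes are arbitrary)
x = np.random.randn(30, 2)
y = np.random.randn(20, 2)
M = torch.tensor(ot.dist(x, y), requires_grad=True)  # cost matrix
a = torch.tensor(ot.unif(30), requires_grad=True)    # uniform source weights
b = torch.tensor(ot.unif(20), requires_grad=True)    # uniform target weights

# implicit gradients: only res.value is differentiable, the plan is detached
res = ot.solve(M, a, b, reg=1.0, grad='implicit')
res.value.backward()  # fills M.grad, a.grad, b.grad at a lower memory cost than grad='autodiff'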

File tree: 4 files changed, +130 −40 lines

  ot/backend.py
  ot/solvers.py
  test/test_backend.py
  test/test_solvers.py


ot/backend.py

Lines changed: 29 additions & 33 deletions
@@ -281,6 +281,19 @@ def set_gradients(self, val, inputs, grads):
         """Define the gradients for the value val wrt the inputs """
         raise NotImplementedError()
 
+    def detach(self, *arrays):
+        """Detach the tensors from the computation graph
+
+        See: https://pytorch.org/docs/stable/generated/torch.Tensor.detach.html"""
+        if len(arrays) == 1:
+            return self._detach(arrays[0])
+        else:
+            return [self._detach(array) for array in arrays]
+
+    def _detach(self, a):
+        """Detach the tensor from the computation graph"""
+        raise NotImplementedError()
+
     def zeros(self, shape, type_as=None):
         r"""
         Creates a tensor full of zeros.
@@ -1027,14 +1040,6 @@ def transpose(self, a, axes=None):
         """
         raise NotImplementedError()
 
-    def detach(self, *args):
-        r"""
-        Detach tensors in arguments from the current graph.
-
-        See: https://pytorch.org/docs/stable/generated/torch.Tensor.detach.html
-        """
-        raise NotImplementedError()
-
     def matmul(self, a, b):
         r"""
         Matrix product of two arrays.
@@ -1082,6 +1087,10 @@ def set_gradients(self, val, inputs, grads):
         # No gradients for numpy
         return val
 
+    def _detach(self, a):
+        # No gradients for numpy
+        return a
+
     def zeros(self, shape, type_as=None):
         if type_as is None:
             return np.zeros(shape)
@@ -1392,11 +1401,6 @@ def atan2(self, a, b):
     def transpose(self, a, axes=None):
         return np.transpose(a, axes)
 
-    def detach(self, *args):
-        if len(args) == 1:
-            return args[0]
-        return args
-
     def matmul(self, a, b):
         return np.matmul(a, b)
 
@@ -1462,6 +1466,9 @@ def set_gradients(self, val, inputs, grads):
         val, = jax.tree_map(lambda z: z + aux, (val,))
         return val
 
+    def _detach(self, a):
+        return jax.lax.stop_gradient(a)
+
     def zeros(self, shape, type_as=None):
         if type_as is None:
             return jnp.zeros(shape)
@@ -1765,11 +1772,6 @@ def atan2(self, a, b):
     def transpose(self, a, axes=None):
         return jnp.transpose(a, axes)
 
-    def detach(self, *args):
-        if len(args) == 1:
-            return jax.lax.stop_gradient((args[0],))[0]
-        return [jax.lax.stop_gradient((a,))[0] for a in args]
-
     def matmul(self, a, b):
         return jnp.matmul(a, b)
 
@@ -1851,6 +1853,9 @@ def set_gradients(self, val, inputs, grads):
 
         return res
 
+    def _detach(self, a):
+        return a.detach()
+
     def zeros(self, shape, type_as=None):
         if isinstance(shape, int):
             shape = (shape,)
@@ -2256,11 +2261,6 @@ def transpose(self, a, axes=None):
             axes = tuple(range(a.ndim)[::-1])
         return a.permute(axes)
 
-    def detach(self, *args):
-        if len(args) == 1:
-            return args[0].detach()
-        return [a.detach() for a in args]
-
     def matmul(self, a, b):
         return torch.matmul(a, b)
 
@@ -2312,6 +2312,9 @@ def set_gradients(self, val, inputs, grads):
         # No gradients for cupy
         return val
 
+    def _detach(self, a):
+        return a
+
     def zeros(self, shape, type_as=None):
         if isinstance(shape, (list, tuple)):
             shape = tuple(int(i) for i in shape)
@@ -2657,11 +2660,6 @@ def atan2(self, a, b):
     def transpose(self, a, axes=None):
         return cp.transpose(a, axes)
 
-    def detach(self, *args):
-        if len(args) == 1:
-            return args[0]
-        return args
-
     def matmul(self, a, b):
         return cp.matmul(a, b)
 
@@ -2729,6 +2727,9 @@ def grad(upstream):
             return val, grad
         return tmp(inputs)
 
+    def _detach(self, a):
+        return tf.stop_gradient(a)
+
     def zeros(self, shape, type_as=None):
         if type_as is None:
             return tnp.zeros(shape)
@@ -3083,11 +3084,6 @@ def atan2(self, a, b):
     def transpose(self, a, axes=None):
         return tf.transpose(a, perm=axes)
 
-    def detach(self, *args):
-        if len(args) == 1:
-            return tf.stop_gradient(args[0])
-        return [tf.stop_gradient(a) for a in args]
-
     def matmul(self, a, b):
         return tnp.matmul(a, b)
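
For reference, a minimal sketch (not part of the commit) of how the new detach API is meant to be called; it assumes a PyTorch install, and the shapes are arbitrary. A single input returns one detached tensor, several inputs return a list, and backends without autodiff (NumPy, CuPy) simply return their inputs unchanged:

import torch
from ot.backend import get_backend

M = torch.rand(5, 4, requires_grad=True)
a = torch.rand(5, requires_grad=True)
nx = get_backend(M, a)            # picks the torch backend here

M_d = nx.detach(M)                # one tensor in -> one detached tensor out
M_d2, a_d = nx.detach(M, a)       # several tensors in -> list of detached tensors
print(M_d.requires_grad, a_d.requires_grad)  # False False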

ot/solvers.py

Lines changed: 45 additions & 3 deletions
@@ -29,7 +29,7 @@
 
 def solve(M, a=None, b=None, reg=None, reg_type="KL", unbalanced=None,
           unbalanced_type='KL', method=None, n_threads=1, max_iter=None, plan_init=None,
-          potentials_init=None, tol=None, verbose=False):
+          potentials_init=None, tol=None, verbose=False, grad='autodiff'):
     r"""Solve the discrete optimal transport problem and return :any:`OTResult` object
 
     The function solves the following general optimal transport problem
@@ -79,6 +79,12 @@ def solve(M, a=None, b=None, reg=None, reg_type="KL", unbalanced=None,
         Tolerance for solution precision, by default None (default values in each solvers)
     verbose : bool, optional
         Print information in the solver, by default False
+    grad : str, optional
+        Type of gradient computation, either 'autodiff' or 'implicit', used only for the
+        Sinkhorn solver. By default 'autodiff' provides gradients wrt all
+        outputs (`plan, value, value_linear`) but with a significant memory cost.
+        'implicit' provides gradients only for `value`; the other outputs are
+        detached. This is useful for memory saving when only the value is needed.
 
     Returns
     -------
@@ -134,6 +140,16 @@ def solve(M, a=None, b=None, reg=None, reg_type="KL", unbalanced=None,
     # or for original Sinkhorn paper formulation [2]
     res = ot.solve(M, a, b, reg=1.0, reg_type='entropy')
 
+    # Use implicit differentiation for memory saving
+    res = ot.solve(M, a, b, reg=1.0, grad='implicit')  # M, a, b are torch tensors
+    res.value.backward()  # only the value is differentiable
+
+    Note that by default the Sinkhorn solver uses automatic differentiation to
+    compute the gradients of the values and plan. This can be changed with the
+    `grad` parameter. The `implicit` mode computes the implicit gradients only
+    for the value, and the other outputs are detached. This is useful for
+    memory saving when only the gradient of the value is needed.
+
     - **Quadratic regularized OT [17]** (when ``reg!=None`` and ``reg_type="L2"``):
 
     .. math::
@@ -297,6 +313,10 @@ def solve(M, a=None, b=None, reg=None, reg_type="KL", unbalanced=None,
 
     if reg_type.lower() in ['entropy', 'kl']:
 
+        if grad == 'implicit':  # if implicit then detach the input
+            M0, a0, b0 = M, a, b
+            M, a, b = nx.detach(M, a, b)
+
         # default values for sinkhorn
         if max_iter is None:
             max_iter = 1000
@@ -316,6 +336,11 @@ def solve(M, a=None, b=None, reg=None, reg_type="KL", unbalanced=None,
 
         potentials = (log['log_u'], log['log_v'])
 
+        if grad == 'implicit':  # set the gradient at convergence
+
+            value = nx.set_gradients(value, (M0, a0, b0),
+                                     (plan, reg * (potentials[0] - potentials[0].mean()), reg * (potentials[1] - potentials[1].mean())))
+
     elif reg_type.lower() == 'l2':
 
         if max_iter is None:
@@ -869,7 +894,8 @@ def solve_gromov(Ca, Cb, M=None, a=None, b=None, loss='L2', symmetric=None,
 def solve_sample(X_a, X_b, a=None, b=None, metric='sqeuclidean', reg=None, reg_type="KL",
                  unbalanced=None,
                  unbalanced_type='KL', lazy=False, batch_size=None, method=None, n_threads=1, max_iter=None, plan_init=None, rank=100, scaling=0.95,
-                 potentials_init=None, X_init=None, tol=None, verbose=False):
+                 potentials_init=None, X_init=None, tol=None, verbose=False,
+                 grad='autodiff'):
     r"""Solve the discrete optimal transport problem using the samples in the source and target domains.
 
     The function solves the following general optimal transport problem
@@ -935,6 +961,12 @@ def solve_sample(X_a, X_b, a=None, b=None, metric='sqeuclidean', reg=None, reg_t
         Tolerance for solution precision, by default None (default values in each solvers)
     verbose : bool, optional
         Print information in the solver, by default False
+    grad : str, optional
+        Type of gradient computation, either 'autodiff' or 'implicit', used only for the
+        Sinkhorn solver. By default 'autodiff' provides gradients wrt all
+        outputs (`plan, value, value_linear`) but with a significant memory cost.
+        'implicit' provides gradients only for `value`; the other outputs are
+        detached. This is useful for memory saving when only the value is needed.
 
     Returns
     -------
@@ -1002,6 +1034,16 @@ def solve_sample(X_a, X_b, a=None, b=None, metric='sqeuclidean', reg=None, reg_t
     # lazy OT plan
     lazy_plan = res.lazy_plan
 
+    # Use implicit differentiation for memory saving
+    res = ot.solve_sample(xa, xb, a, b, reg=1.0, grad='implicit')
+    res.value.backward()  # only the value is differentiable
+
+    Note that by default the Sinkhorn solver uses automatic differentiation to
+    compute the gradients of the values and plan. This can be changed with the
+    `grad` parameter. The `implicit` mode computes the implicit gradients only
+    for the value, and the other outputs are detached. This is useful for
+    memory saving when only the gradient of the value is needed.
+
     We also have a very efficient solver with compiled CPU/CUDA code using
     geomloss/PyKeOps that can be used with the following code:
 
@@ -1189,7 +1231,7 @@ def solve_sample(X_a, X_b, a=None, b=None, metric='sqeuclidean', reg=None, reg_t
     # compute cost matrix M and use solve function
     M = dist(X_a, X_b, metric)
 
-    res = solve(M, a, b, reg, reg_type, unbalanced, unbalanced_type, method, n_threads, max_iter, plan_init, potentials_init, tol, verbose)
+    res = solve(M, a, b, reg, reg_type, unbalanced, unbalanced_type, method, n_threads, max_iter, plan_init, potentials_init, tol, verbose, grad)
 
     return res
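
The `set_gradients` call above is the standard envelope (Danskin) argument for entropic OT. A brief sketch of the reasoning, not taken from the commit itself; exact constants depend on the `reg_type` formulation used by the solver:

\[ W(\mathbf{M}, \mathbf{a}, \mathbf{b}) = \min_{\mathbf{T} \in \Pi(\mathbf{a}, \mathbf{b})} \langle \mathbf{T}, \mathbf{M} \rangle + \mathrm{reg} \cdot \Omega(\mathbf{T}) \]

At convergence, the Sinkhorn scalings give the optimal plan and dual potentials,
\[ \mathbf{T}^\star = \operatorname{diag}(\mathbf{u})\, e^{-\mathbf{M}/\mathrm{reg}}\, \operatorname{diag}(\mathbf{v}), \qquad \mathbf{f}^\star = \mathrm{reg}\,\log\mathbf{u}, \qquad \mathbf{g}^\star = \mathrm{reg}\,\log\mathbf{v}, \]
and the envelope theorem yields
\[ \nabla_{\mathbf{M}} W = \mathbf{T}^\star, \qquad \nabla_{\mathbf{a}} W = \mathbf{f}^\star, \qquad \nabla_{\mathbf{b}} W = \mathbf{g}^\star. \]
Since \((\mathbf{f}^\star + c, \mathbf{g}^\star - c)\) is equally optimal, the potentials are defined only up to an additive constant, which is why the code passes the centered quantities \(\mathrm{reg}\,(\log\mathbf{u} - \overline{\log\mathbf{u}})\) and \(\mathrm{reg}\,(\log\mathbf{v} - \overline{\log\mathbf{v}})\) to `nx.set_gradients`.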

test/test_backend.py

Lines changed: 13 additions & 4 deletions
@@ -266,6 +266,14 @@ def test_empty_backend():
         nx.matmul(M, M.T)
     with pytest.raises(NotImplementedError):
         nx.nan_to_num(M)
+    with pytest.raises(NotImplementedError):
+        nx.sign(M)
+    with pytest.raises(NotImplementedError):
+        nx.dtype_device(M)
+    with pytest.raises(NotImplementedError):
+        nx.assert_same_dtype_device(M, M)
+    with pytest.raises(NotImplementedError):
+        nx.eigh(M)
 
 
 def test_func_backends(nx):
@@ -311,6 +319,11 @@ def test_func_backends(nx):
     lst_b.append(nx.to_numpy(A))
     lst_name.append('set_gradients')
 
+    A = nx.detach(Mb)
+    A, B = nx.detach(Mb, Mb)
+    lst_b.append(nx.to_numpy(A))
+    lst_name.append('detach')
+
     A = nx.zeros((10, 3))
     A = nx.zeros((10, 3), type_as=Mb)
     lst_b.append(nx.to_numpy(A))
@@ -652,10 +665,6 @@ def test_func_backends(nx):
     lst_b.append(nx.to_numpy(A))
     lst_name.append("transpose")
 
-    A = nx.detach(Mb)
-    lst_b.append(nx.to_numpy(A))
-    lst_name.append("detach")
-
     A, B = nx.detach(Mb, Mb)
     lst_b.append(nx.to_numpy(A))
     lst_name.append("detach A")

test/test_solvers.py

Lines changed: 43 additions & 0 deletions
@@ -12,6 +12,7 @@
 
 import ot
 from ot.bregman import geomloss
+from ot.backend import torch
 
 lst_reg = [None, 1]
 lst_reg_type = ['KL', 'entropy', 'L2']
@@ -107,6 +108,48 @@ def test_solve(nx):
         sol0 = ot.solve(M, reg=1, reg_type='cryptic divergence')
 
 
+@pytest.mark.skipif(not torch, reason="torch not installed")
+def test_solve_implicit():
+
+    n_samples_s = 10
+    n_samples_t = 7
+    n_features = 2
+    rng = np.random.RandomState(0)
+
+    x = rng.randn(n_samples_s, n_features)
+    y = rng.randn(n_samples_t, n_features)
+    a = ot.utils.unif(n_samples_s)
+    b = ot.utils.unif(n_samples_t)
+    M = ot.dist(x, y)
+
+    a = torch.tensor(a, requires_grad=True)
+    b = torch.tensor(b, requires_grad=True)
+    M = torch.tensor(M, requires_grad=True)
+
+    sol0 = ot.solve(M, a, b, reg=10, grad='implicit')
+    sol0.value.backward()
+
+    gM0 = M.grad.clone()
+    ga0 = a.grad.clone()
+    gb0 = b.grad.clone()
+
+    a = torch.tensor(a, requires_grad=True)
+    b = torch.tensor(b, requires_grad=True)
+    M = torch.tensor(M, requires_grad=True)
+
+    sol = ot.solve(M, a, b, reg=10, grad='autodiff')
+    sol.value.backward()
+
+    gM = M.grad.clone()
+    ga = a.grad.clone()
+    gb = b.grad.clone()
+
+    # Note: gradients are invariant to a change of constant, so we center them
+    assert torch.allclose(gM0, gM)
+    assert torch.allclose(ga0 - ga0.mean(), ga - ga.mean())
+    assert torch.allclose(gb0 - gb0.mean(), gb - gb.mean())
+
+
 @pytest.mark.parametrize("reg,reg_type,unbalanced,unbalanced_type", itertools.product(lst_reg, lst_reg_type, lst_unbalanced, lst_unbalanced_type))
 def test_solve_grid(nx, reg, reg_type, unbalanced, unbalanced_type):
     n_samples_s = 10
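
The new test can be run on its own with `pytest test/test_solvers.py::test_solve_implicit`; per the skipif marker above, it is skipped when PyTorch is not installed.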
