Commit 1682b60

add documentation and warnings in (f)gw cg solvers when integers are provided (#560)
1 parent 1ece2d8 commit 1682b60

3 files changed: 105 additions (+), 13 deletions (-)
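For context, a minimal sketch of the behaviour this commit documents (illustrative only, not part of the commit; assumes a POT build that includes this change):

    import numpy as np
    import ot

    rng = np.random.RandomState(0)
    xs = rng.randn(10, 2)
    C1 = ot.dist(xs, xs).astype(np.int32)  # integer structure matrix
    C2 = ot.dist(xs, xs)                   # floating point structure matrix

    # The cg-based solver now emits a UserWarning and casts the transport
    # plan to the dtype of C1, so an int32 input yields an integer plan
    # (small entries such as 1/(n*n) round to zero).
    G = ot.gromov.gromov_wasserstein(C1, C2, ot.unif(10), ot.unif(10), 'square_loss')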

RELEASES.md

Lines changed: 1 addition & 1 deletion
@@ -20,7 +20,7 @@
 #### Closed issues
 - Fix line search evaluating cost outside of the interpolation range (Issue #502, PR #504)
 - Lazily instantiate backends to avoid unnecessary GPU memory pre-allocations on package import (Issue #516, PR #520)
-
+- Handle documentation and warnings when integers are provided to (f)gw solvers based on cg (Issue #530, PR #559)

 ## 0.9.1
 *August 2023*

ot/gromov/_gw.py

Lines changed: 46 additions & 11 deletions
@@ -12,6 +12,7 @@
 # License: MIT License

 import numpy as np
+import warnings


 from ..utils import dist, UndefinedParameter, list_to_array
@@ -53,6 +54,10 @@ def gromov_wasserstein(C1, C2, p=None, q=None, loss_fun='square_loss', symmetric
         which can lead to copy overhead on GPU arrays.
     .. note:: All computations in the conjugate gradient solver are done with
         numpy to limit memory overhead.
+    .. note:: This function will cast the computed transport plan to the data
+        type of the provided input :math:`\mathbf{C}_1`. Casting to an integer
+        tensor might result in a loss of precision. If this behaviour is
+        unwanted, please make sure to provide a floating point input.

     Parameters
     ----------
@@ -122,7 +127,7 @@ def gromov_wasserstein(C1, C2, p=None, q=None, loss_fun='square_loss', symmetric
     if q is not None:
         arr.append(list_to_array(q))
     else:
-        q = unif(C2.shape[0], type_as=C2)
+        q = unif(C2.shape[0], type_as=C1)
     if G0 is not None:
         G0_ = G0
         arr.append(G0)
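The switch from type_as=C2 to type_as=C1 above keeps the default marginals on the same backend and dtype as the matrix that now determines the output type. A hedged illustration of how unif's type_as argument behaves (the float32 input is an assumption for the example, not from the diff):

    import numpy as np
    import ot

    C1 = np.random.rand(5, 5).astype(np.float32)
    q = ot.unif(5, type_as=C1)  # uniform weights converted to C1's backend/dtype
    print(q.dtype)              # float32, matching C1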
@@ -171,6 +176,16 @@ def line_search(cost, G, deltaG, Mi, cost_G, **kwargs):
     else:
         def line_search(cost, G, deltaG, Mi, cost_G, **kwargs):
             return solve_gromov_linesearch(G, deltaG, cost_G, hC1, hC2, M=0., reg=1., nx=np_, **kwargs)
+
+    if not nx.is_floating_point(C10):
+        warnings.warn(
+            "Input structure matrix consists of integers. The transport plan will be "
+            "cast accordingly, possibly resulting in a loss of precision. "
+            "If this behaviour is unwanted, please make sure your input "
+            "structure matrix consists of floating point elements.",
+            stacklevel=2
+        )
+
     if log:
         res, log = cg(p, q, 0., 1., f, df, G0, line_search, log=True, numItermax=max_iter, stopThr=tol_rel, stopThr2=tol_abs, **kwargs)
         log['gw_dist'] = nx.from_numpy(log['loss'][-1], type_as=C10)
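A short usage sketch (illustrative, not part of the diff): the message goes through the standard warnings machinery, so callers who deliberately pass integer structure matrices can silence it, while the precision loss itself is avoided by casting to float beforehand. C1_int, C2, p and q below are placeholder inputs:

    import warnings
    import numpy as np
    import ot

    C1_float = C1_int.astype(np.float64)  # preferred: no warning, no precision loss

    with warnings.catch_warnings():
        warnings.simplefilter("ignore", UserWarning)  # keep integer behaviour knowingly
        G = ot.gromov.gromov_wasserstein(C1_int, C2, p, q, 'square_loss')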
@@ -216,6 +231,10 @@ def gromov_wasserstein2(C1, C2, p=None, q=None, loss_fun='square_loss', symmetri
         which can lead to copy overhead on GPU arrays.
     .. note:: All computations in the conjugate gradient solver are done with
         numpy to limit memory overhead.
+    .. note:: This function will cast the computed transport plan to the data
+        type of the provided input :math:`\mathbf{C}_1`. Casting to an integer
+        tensor might result in a loss of precision. If this behaviour is
+        unwanted, please make sure to provide a floating point input.

     Parameters
     ----------
@@ -286,7 +305,7 @@ def gromov_wasserstein2(C1, C2, p=None, q=None, loss_fun='square_loss', symmetri
     if p is None:
         p = unif(C1.shape[0], type_as=C1)
     if q is None:
-        q = unif(C2.shape[0], type_as=C2)
+        q = unif(C2.shape[0], type_as=C1)

     T, log_gw = gromov_wasserstein(
         C1, C2, p, q, loss_fun, symmetric, log=True, armijo=armijo, G0=G0,
@@ -344,6 +363,10 @@ def fused_gromov_wasserstein(M, C1, C2, p=None, q=None, loss_fun='square_loss',
         which can lead to copy overhead on GPU arrays.
     .. note:: All computations in the conjugate gradient solver are done with
         numpy to limit memory overhead.
+    .. note:: This function will cast the computed transport plan to the data
+        type of the provided input :math:`\mathbf{M}`. Casting to an integer
+        tensor might result in a loss of precision. If this behaviour is
+        unwanted, please make sure to provide a floating point input.


     Parameters
@@ -409,11 +432,11 @@ def fused_gromov_wasserstein(M, C1, C2, p=None, q=None, loss_fun='square_loss',
     if p is not None:
         arr.append(list_to_array(p))
     else:
-        p = unif(C1.shape[0], type_as=C1)
+        p = unif(C1.shape[0], type_as=M)
     if q is not None:
         arr.append(list_to_array(q))
     else:
-        q = unif(C2.shape[0], type_as=C2)
+        q = unif(C2.shape[0], type_as=M)
     if G0 is not None:
         G0_ = G0
         arr.append(G0)
@@ -465,14 +488,22 @@ def line_search(cost, G, deltaG, Mi, cost_G, **kwargs):
     else:
         def line_search(cost, G, deltaG, Mi, cost_G, **kwargs):
             return solve_gromov_linesearch(G, deltaG, cost_G, hC1, hC2, M=(1 - alpha) * M, reg=alpha, nx=np_, **kwargs)
+    if not nx.is_floating_point(M0):
+        warnings.warn(
+            "Input feature matrix consists of integers. The transport plan will be "
+            "cast accordingly, possibly resulting in a loss of precision. "
+            "If this behaviour is unwanted, please make sure your input "
+            "feature matrix consists of floating point elements.",
+            stacklevel=2
+        )
     if log:
         res, log = cg(p, q, (1 - alpha) * M, alpha, f, df, G0, line_search, log=True, numItermax=max_iter, stopThr=tol_rel, stopThr2=tol_abs, **kwargs)
-        log['fgw_dist'] = nx.from_numpy(log['loss'][-1], type_as=C10)
-        log['u'] = nx.from_numpy(log['u'], type_as=C10)
-        log['v'] = nx.from_numpy(log['v'], type_as=C10)
-        return nx.from_numpy(res, type_as=C10), log
+        log['fgw_dist'] = nx.from_numpy(log['loss'][-1], type_as=M0)
+        log['u'] = nx.from_numpy(log['u'], type_as=M0)
+        log['v'] = nx.from_numpy(log['v'], type_as=M0)
+        return nx.from_numpy(res, type_as=M0), log
     else:
-        return nx.from_numpy(cg(p, q, (1 - alpha) * M, alpha, f, df, G0, line_search, log=False, numItermax=max_iter, stopThr=tol_rel, stopThr2=tol_abs, **kwargs), type_as=C10)
+        return nx.from_numpy(cg(p, q, (1 - alpha) * M, alpha, f, df, G0, line_search, log=False, numItermax=max_iter, stopThr=tol_rel, stopThr2=tol_abs, **kwargs), type_as=M0)


 def fused_gromov_wasserstein2(M, C1, C2, p=None, q=None, loss_fun='square_loss', symmetric=None, alpha=0.5,
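The fused solver mirrors this, with M's dtype driving the cast. A minimal sketch (illustrative, not from the commit) of the warning triggered by an integer feature-distance matrix:

    import numpy as np
    import ot

    rng = np.random.RandomState(0)
    n = 10
    C1 = rng.rand(n, n)
    C1 = C1 + C1.T  # symmetric structure matrix
    C2 = C1.copy()
    M = ot.dist(rng.rand(n, 1), rng.rand(n, 1)).astype(np.int64)  # integer features

    # warns that the feature matrix is integer and casts the plan accordingly
    G = ot.gromov.fused_gromov_wasserstein(
        M, C1, C2, ot.unif(n), ot.unif(n), 'square_loss', alpha=0.5)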
@@ -510,6 +541,10 @@ def fused_gromov_wasserstein2(M, C1, C2, p=None, q=None, loss_fun='square_loss',
         which can lead to copy overhead on GPU arrays.
     .. note:: All computations in the conjugate gradient solver are done with
         numpy to limit memory overhead.
+    .. note:: This function will cast the computed transport plan to the data
+        type of the provided input :math:`\mathbf{M}`. Casting to an integer
+        tensor might result in a loss of precision. If this behaviour is
+        unwanted, please make sure to provide a floating point input.

     Parameters
     ----------
@@ -578,9 +613,9 @@ def fused_gromov_wasserstein2(M, C1, C2, p=None, q=None, loss_fun='square_loss',

     # init marginals if set as None
     if p is None:
-        p = unif(C1.shape[0], type_as=C1)
+        p = unif(C1.shape[0], type_as=M)
     if q is None:
-        q = unif(C2.shape[0], type_as=C2)
+        q = unif(C2.shape[0], type_as=M)

     T, log_fgw = fused_gromov_wasserstein(
         M, C1, C2, p, q, loss_fun, symmetric, alpha, armijo, G0, log=True,

test/test_gromov.py

Lines changed: 58 additions & 1 deletion
@@ -122,6 +122,37 @@ def test_asymmetric_gromov(nx):
     np.testing.assert_allclose(logb['gw_dist'], 0., atol=1e-04)


+def test_gromov_integer_warnings(nx):
+    n_samples = 10  # nb samples
+    mu_s = np.array([0, 0])
+    cov_s = np.array([[1, 0], [0, 1]])
+
+    xs = ot.datasets.make_2D_samples_gauss(n_samples, mu_s, cov_s, random_state=1)
+    xt = xs[::-1].copy()
+
+    p = ot.unif(n_samples)
+    q = ot.unif(n_samples)
+    G0 = p[:, None] * q[None, :]
+
+    C1 = ot.dist(xs, xs)
+    C2 = ot.dist(xt, xt)
+
+    C1 /= C1.max()
+    C2 /= C2.max()
+    C1 = C1.astype(np.int32)
+    C1b, C2b, pb, qb, G0b = nx.from_numpy(C1, C2, p, q, G0)
+
+    G = ot.gromov.gromov_wasserstein(
+        C1, C2, None, q, 'square_loss', G0=G0, verbose=True,
+        alpha_min=0., alpha_max=1.)
+    Gb = nx.to_numpy(ot.gromov.gromov_wasserstein(
+        C1b, C2b, pb, None, 'square_loss', symmetric=True, G0=G0b, verbose=True))
+
+    # check constraints
+    np.testing.assert_allclose(G, Gb, atol=1e-06)
+    np.testing.assert_allclose(G, 0., atol=1e-09)
+
+
 def test_gromov_dtype_device(nx):
     # setup
     n_samples = 20  # nb samples
@@ -1145,7 +1176,7 @@ def test_fgw(nx):


 def test_asymmetric_fgw(nx):
-    n_samples = 50  # nb samples
+    n_samples = 20  # nb samples
     rng = np.random.RandomState(0)
     C1 = rng.uniform(low=0., high=10, size=(n_samples, n_samples))
     idx = np.arange(n_samples)
@@ -1221,6 +1252,32 @@
     np.testing.assert_allclose(logb['fgw_dist'], 0., atol=1e-04)


+def test_fgw_integer_warnings(nx):
+    n_samples = 20  # nb samples
+    rng = np.random.RandomState(0)
+    C1 = rng.uniform(low=0., high=10, size=(n_samples, n_samples))
+    idx = np.arange(n_samples)
+    rng.shuffle(idx)
+    C2 = C1[idx, :][:, idx]
+
+    # add features
+    F1 = rng.uniform(low=0., high=10, size=(n_samples, 1))
+    F2 = F1[idx, :]
+    p = ot.unif(n_samples)
+    q = ot.unif(n_samples)
+    G0 = p[:, None] * q[None, :]
+
+    M = ot.dist(F1, F2).astype(np.int32)
+    Mb, C1b, C2b, pb, qb, G0b = nx.from_numpy(M, C1, C2, p, q, G0)
+
+    G, log = ot.gromov.fused_gromov_wasserstein(M, C1, C2, p, q, 'square_loss', alpha=0.5, G0=G0, log=True, symmetric=False, verbose=True)
+    Gb, logb = ot.gromov.fused_gromov_wasserstein(Mb, C1b, C2b, pb, qb, 'square_loss', alpha=0.5, log=True, symmetric=None, G0=G0b, verbose=True)
+    Gb = nx.to_numpy(Gb)
+    # check constraints
+    np.testing.assert_allclose(G, Gb, atol=1e-06)
+    np.testing.assert_allclose(G, 0., atol=1e-06)
+
+
 def test_fgw2_gradients():
     n_samples = 20  # nb samples

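The new tests above check the numerics of the cast plans; as a complementary sketch (not part of this commit's tests), the warning itself could be asserted explicitly with pytest:

    import numpy as np
    import pytest
    import ot

    def test_integer_input_warns():
        n = 10
        rng = np.random.RandomState(0)
        C = ot.dist(rng.rand(n, 2), rng.rand(n, 2))
        with pytest.warns(UserWarning):
            ot.gromov.gromov_wasserstein(
                C.astype(np.int32), C, ot.unif(n), ot.unif(n), 'square_loss')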
