|
12 | 12 | # License: MIT License
|
13 | 13 |
|
14 | 14 | import numpy as np
|
| 15 | +import warnings |
15 | 16 |
|
16 | 17 |
|
17 | 18 | from ..utils import dist, UndefinedParameter, list_to_array
|
@@ -53,6 +54,10 @@ def gromov_wasserstein(C1, C2, p=None, q=None, loss_fun='square_loss', symmetric
|
53 | 54 | which can lead to copy overhead on GPU arrays.
|
54 | 55 | .. note:: All computations in the conjugate gradient solver are done with
|
55 | 56 | numpy to limit memory overhead.
|
| 57 | + .. note:: This function will cast the computed transport plan to the data |
| 58 | + type of the provided input :math:`\mathbf{C}_1`. Casting to an integer |
| 59 | + tensor might result in a loss of precision. If this behaviour is |
| 60 | + unwanted, please make sure to provide a floating point input. |
56 | 61 |
|
57 | 62 | Parameters
|
58 | 63 | ----------
|
@@ -122,7 +127,7 @@ def gromov_wasserstein(C1, C2, p=None, q=None, loss_fun='square_loss', symmetric
|
122 | 127 | if q is not None:
|
123 | 128 | arr.append(list_to_array(q))
|
124 | 129 | else:
|
125 |
| - q = unif(C2.shape[0], type_as=C2) |
| 130 | + q = unif(C2.shape[0], type_as=C1) |
126 | 131 | if G0 is not None:
|
127 | 132 | G0_ = G0
|
128 | 133 | arr.append(G0)
|
@@ -171,6 +176,16 @@ def line_search(cost, G, deltaG, Mi, cost_G, **kwargs):
|
171 | 176 | else:
|
172 | 177 | def line_search(cost, G, deltaG, Mi, cost_G, **kwargs):
|
173 | 178 | return solve_gromov_linesearch(G, deltaG, cost_G, hC1, hC2, M=0., reg=1., nx=np_, **kwargs)
|
| 179 | + |
| 180 | + if not nx.is_floating_point(C10): |
| 181 | + warnings.warn( |
| 182 | + "Input structure matrix consists of integer. The transport plan will be " |
| 183 | + "casted accordingly, possibly resulting in a loss of precision. " |
| 184 | + "If this behaviour is unwanted, please make sure your input " |
| 185 | + "structure matrix consists of floating point elements.", |
| 186 | + stacklevel=2 |
| 187 | + ) |
| 188 | + |
174 | 189 | if log:
|
175 | 190 | res, log = cg(p, q, 0., 1., f, df, G0, line_search, log=True, numItermax=max_iter, stopThr=tol_rel, stopThr2=tol_abs, **kwargs)
|
176 | 191 | log['gw_dist'] = nx.from_numpy(log['loss'][-1], type_as=C10)
|
@@ -216,6 +231,10 @@ def gromov_wasserstein2(C1, C2, p=None, q=None, loss_fun='square_loss', symmetri
|
216 | 231 | which can lead to copy overhead on GPU arrays.
|
217 | 232 | .. note:: All computations in the conjugate gradient solver are done with
|
218 | 233 | numpy to limit memory overhead.
|
| 234 | + .. note:: This function will cast the computed transport plan to the data |
| 235 | + type of the provided input :math:`\mathbf{C}_1`. Casting to an integer |
| 236 | + tensor might result in a loss of precision. If this behaviour is |
| 237 | + unwanted, please make sure to provide a floating point input. |
219 | 238 |
|
220 | 239 | Parameters
|
221 | 240 | ----------
|
@@ -286,7 +305,7 @@ def gromov_wasserstein2(C1, C2, p=None, q=None, loss_fun='square_loss', symmetri
|
286 | 305 | if p is None:
|
287 | 306 | p = unif(C1.shape[0], type_as=C1)
|
288 | 307 | if q is None:
|
289 |
| - q = unif(C2.shape[0], type_as=C2) |
| 308 | + q = unif(C2.shape[0], type_as=C1) |
290 | 309 |
|
291 | 310 | T, log_gw = gromov_wasserstein(
|
292 | 311 | C1, C2, p, q, loss_fun, symmetric, log=True, armijo=armijo, G0=G0,
|
@@ -344,6 +363,10 @@ def fused_gromov_wasserstein(M, C1, C2, p=None, q=None, loss_fun='square_loss',
|
344 | 363 | which can lead to copy overhead on GPU arrays.
|
345 | 364 | .. note:: All computations in the conjugate gradient solver are done with
|
346 | 365 | numpy to limit memory overhead.
|
| 366 | + .. note:: This function will cast the computed transport plan to the data |
| 367 | + type of the provided input :math:`\mathbf{M}`. Casting to an integer |
| 368 | + tensor might result in a loss of precision. If this behaviour is |
| 369 | + unwanted, please make sure to provide a floating point input. |
347 | 370 |
|
348 | 371 |
|
349 | 372 | Parameters
|
@@ -409,11 +432,11 @@ def fused_gromov_wasserstein(M, C1, C2, p=None, q=None, loss_fun='square_loss',
|
409 | 432 | if p is not None:
|
410 | 433 | arr.append(list_to_array(p))
|
411 | 434 | else:
|
412 |
| - p = unif(C1.shape[0], type_as=C1) |
| 435 | + p = unif(C1.shape[0], type_as=M) |
413 | 436 | if q is not None:
|
414 | 437 | arr.append(list_to_array(q))
|
415 | 438 | else:
|
416 |
| - q = unif(C2.shape[0], type_as=C2) |
| 439 | + q = unif(C2.shape[0], type_as=M) |
417 | 440 | if G0 is not None:
|
418 | 441 | G0_ = G0
|
419 | 442 | arr.append(G0)
|
@@ -465,14 +488,22 @@ def line_search(cost, G, deltaG, Mi, cost_G, **kwargs):
|
465 | 488 | else:
|
466 | 489 | def line_search(cost, G, deltaG, Mi, cost_G, **kwargs):
|
467 | 490 | return solve_gromov_linesearch(G, deltaG, cost_G, hC1, hC2, M=(1 - alpha) * M, reg=alpha, nx=np_, **kwargs)
|
| 491 | + if not nx.is_floating_point(M0): |
| 492 | + warnings.warn( |
| 493 | + "Input feature matrix consists of integer. The transport plan will be " |
| 494 | + "casted accordingly, possibly resulting in a loss of precision. " |
| 495 | + "If this behaviour is unwanted, please make sure your input " |
| 496 | + "feature matrix consists of floating point elements.", |
| 497 | + stacklevel=2 |
| 498 | + ) |
468 | 499 | if log:
|
469 | 500 | res, log = cg(p, q, (1 - alpha) * M, alpha, f, df, G0, line_search, log=True, numItermax=max_iter, stopThr=tol_rel, stopThr2=tol_abs, **kwargs)
|
470 |
| - log['fgw_dist'] = nx.from_numpy(log['loss'][-1], type_as=C10) |
471 |
| - log['u'] = nx.from_numpy(log['u'], type_as=C10) |
472 |
| - log['v'] = nx.from_numpy(log['v'], type_as=C10) |
473 |
| - return nx.from_numpy(res, type_as=C10), log |
| 501 | + log['fgw_dist'] = nx.from_numpy(log['loss'][-1], type_as=M0) |
| 502 | + log['u'] = nx.from_numpy(log['u'], type_as=M0) |
| 503 | + log['v'] = nx.from_numpy(log['v'], type_as=M0) |
| 504 | + return nx.from_numpy(res, type_as=M0), log |
474 | 505 | else:
|
475 |
| - return nx.from_numpy(cg(p, q, (1 - alpha) * M, alpha, f, df, G0, line_search, log=False, numItermax=max_iter, stopThr=tol_rel, stopThr2=tol_abs, **kwargs), type_as=C10) |
| 506 | + return nx.from_numpy(cg(p, q, (1 - alpha) * M, alpha, f, df, G0, line_search, log=False, numItermax=max_iter, stopThr=tol_rel, stopThr2=tol_abs, **kwargs), type_as=M0) |
476 | 507 |
|
477 | 508 |
|
478 | 509 | def fused_gromov_wasserstein2(M, C1, C2, p=None, q=None, loss_fun='square_loss', symmetric=None, alpha=0.5,
|
@@ -510,6 +541,10 @@ def fused_gromov_wasserstein2(M, C1, C2, p=None, q=None, loss_fun='square_loss',
|
510 | 541 | which can lead to copy overhead on GPU arrays.
|
511 | 542 | .. note:: All computations in the conjugate gradient solver are done with
|
512 | 543 | numpy to limit memory overhead.
|
| 544 | + .. note:: This function will cast the computed transport plan to the data |
| 545 | + type of the provided input :math:`\mathbf{M}`. Casting to an integer |
| 546 | + tensor might result in a loss of precision. If this behaviour is |
| 547 | + unwanted, please make sure to provide a floating point input. |
513 | 548 |
|
514 | 549 | Parameters
|
515 | 550 | ----------
|
@@ -578,9 +613,9 @@ def fused_gromov_wasserstein2(M, C1, C2, p=None, q=None, loss_fun='square_loss',
|
578 | 613 |
|
579 | 614 | # init marginals if set as None
|
580 | 615 | if p is None:
|
581 |
| - p = unif(C1.shape[0], type_as=C1) |
| 616 | + p = unif(C1.shape[0], type_as=M) |
582 | 617 | if q is None:
|
583 |
| - q = unif(C2.shape[0], type_as=C2) |
| 618 | + q = unif(C2.shape[0], type_as=M) |
584 | 619 |
|
585 | 620 | T, log_fgw = fused_gromov_wasserstein(
|
586 | 621 | M, C1, C2, p, q, loss_fun, symmetric, alpha, armijo, G0, log=True,
|
|
0 commit comments