Skip to content

Commit cb29080

Browse files
authored
Merge branch 'master' into pr_r0.9.2
2 parents e845c11 + 9ddb690 commit cb29080

File tree

11 files changed

+138
-76
lines changed

11 files changed

+138
-76
lines changed

LICENSE

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
MIT License
22

3-
Copyright (c) 2016 Rémi Flamary
3+
Copyright (c) 2016-2023 POT contributors
44

55
Permission is hereby granted, free of charge, to any person obtaining a copy
66
of this software and associated documentation files (the "Software"), to deal

README.md

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -185,11 +185,16 @@ The examples folder contain several examples and use case for the library. The f
185185

186186
## Acknowledgements
187187

188-
This toolbox has been created and is maintained by
188+
This toolbox has been created by
189189

190-
* [Rémi Flamary](http://remi.flamary.com/)
190+
* [Rémi Flamary](https://remi.flamary.com/)
191191
* [Nicolas Courty](http://people.irisa.fr/Nicolas.Courty/)
192192

193+
It is currently maintained by
194+
195+
* [Rémi Flamary](https://remi.flamary.com/)
196+
* [Cédric Vincent-Cuaz](https://cedricvincentcuaz.github.io/)
197+
193198
The numerous contributors to this library are listed [here](CONTRIBUTORS.md).
194199

195200
POT has benefited from the financing or manpower of the following partners:

RELEASES.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,7 @@ We also fixed a number of issues, the most pressing being a problem of GPU memor
6767
+ Add `fixed_structure` and `fixed_features` to entropic fgw barycenter solver (PR #578)
6868
+ Add new BAPG solvers with KL projections for GW and FGW (PR #581)
6969
+ Add Bures-Wasserstein barycenter in `ot.gaussian` and example (PR #582, PR #584)
70+
+ Domain adaptation method `SinkhornL1l2Transport` now supports JAX backend (PR #587)
7071
+ Added support for [Low-Rank Sinkhorn Factorization](https://arxiv.org/pdf/2103.04737.pdf) (PR #568)
7172

7273

docs/source/conf.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -99,8 +99,8 @@ def __getattr__(cls, name):
9999

100100
# General information about the project.
101101
project = u'POT Python Optimal Transport'
102-
copyright = u'2016-2021, Rémi Flamary, Nicolas Courty'
103-
author = u'Rémi Flamary, Nicolas Courty'
102+
copyright = u'2016-2023, POT Contributors'
103+
author = u'Rémi Flamary, POT Contributors'
104104

105105
# The version info for the project you're documenting, acts as replacement for
106106
# |version| and |release|, also used in various other places throughout the

ot/backend.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1043,6 +1043,14 @@ def matmul(self, a, b):
10431043
"""
10441044
raise NotImplementedError()
10451045

1046+
def nan_to_num(self, x, copy=True, nan=0.0, posinf=None, neginf=None):
1047+
r"""
1048+
Replace NaN with zero and infinity with large finite numbers or with the numbers defined by the user.
1049+
1050+
See: https://numpy.org/doc/stable/reference/generated/numpy.nan_to_num.html#numpy.nan_to_num
1051+
"""
1052+
raise NotImplementedError()
1053+
10461054

10471055
class NumpyBackend(Backend):
10481056
"""
@@ -1392,6 +1400,9 @@ def detach(self, *args):
13921400
def matmul(self, a, b):
13931401
return np.matmul(a, b)
13941402

1403+
def nan_to_num(self, x, copy=True, nan=0.0, posinf=None, neginf=None):
1404+
return np.nan_to_num(x, copy=copy, nan=nan, posinf=posinf, neginf=neginf)
1405+
13951406

13961407
_register_backend_implementation(NumpyBackend)
13971408

@@ -1762,6 +1773,9 @@ def detach(self, *args):
17621773
def matmul(self, a, b):
17631774
return jnp.matmul(a, b)
17641775

1776+
def nan_to_num(self, x, copy=True, nan=0.0, posinf=None, neginf=None):
1777+
return jnp.nan_to_num(x, copy=copy, nan=nan, posinf=posinf, neginf=neginf)
1778+
17651779

17661780
if jax:
17671781
# Only register jax backend if it is installed
@@ -2250,6 +2264,10 @@ def detach(self, *args):
22502264
def matmul(self, a, b):
22512265
return torch.matmul(a, b)
22522266

2267+
def nan_to_num(self, x, copy=True, nan=0.0, posinf=None, neginf=None):
2268+
out = None if copy else x
2269+
return torch.nan_to_num(x, nan=nan, posinf=posinf, neginf=neginf, out=out)
2270+
22532271

22542272
if torch:
22552273
# Only register torch backend if it is installed
@@ -2647,6 +2665,9 @@ def detach(self, *args):
26472665
def matmul(self, a, b):
26482666
return cp.matmul(a, b)
26492667

2668+
def nan_to_num(self, x, copy=True, nan=0.0, posinf=None, neginf=None):
2669+
return cp.nan_to_num(x, copy=copy, nan=nan, posinf=posinf, neginf=neginf)
2670+
26502671

26512672
if cp:
26522673
# Only register cp backend if it is installed
@@ -3070,6 +3091,12 @@ def detach(self, *args):
30703091
def matmul(self, a, b):
30713092
return tnp.matmul(a, b)
30723093

3094+
# todo(okachaiev): replace this with a more reasonable implementation
3095+
def nan_to_num(self, x, copy=True, nan=0.0, posinf=None, neginf=None):
3096+
x = self.to_numpy(x)
3097+
x = np.nan_to_num(x, copy=copy, nan=nan, posinf=posinf, neginf=neginf)
3098+
return self.from_numpy(x)
3099+
30733100

30743101
if tf:
30753102
# Only register tensorflow backend if it is installed

ot/da.py

Lines changed: 24 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
from .bregman import sinkhorn, jcpot_barycenter
1919
from .lp import emd
2020
from .utils import unif, dist, kernel, cost_normalization, label_normalization, laplacian, dots
21-
from .utils import list_to_array, check_params, BaseEstimator, deprecated
21+
from .utils import BaseEstimator, check_params, deprecated, labels_to_masks, list_to_array
2222
from .unbalanced import sinkhorn_unbalanced
2323
from .gaussian import empirical_bures_wasserstein_mapping, empirical_gaussian_gromov_wasserstein_mapping
2424
from .optim import cg
@@ -499,18 +499,12 @@ class label
499499
if self.limit_max != np.infty:
500500
self.limit_max = self.limit_max * nx.max(self.cost_)
501501

502-
# assumes labeled source samples occupy the first rows
503-
# and labeled target samples occupy the first columns
504-
classes = [c for c in nx.unique(ys) if c != -1]
505-
for c in classes:
506-
idx_s = nx.where((ys != c) & (ys != -1))
507-
idx_t = nx.where(yt == c)
508-
509-
# all the coefficients corresponding to a source sample
510-
# and a target sample :
511-
# with different labels get a infinite
512-
for j in idx_t[0]:
513-
self.cost_[idx_s[0], j] = self.limit_max
502+
# zeros where source label is missing (masked with -1)
503+
missing_labels = ys + nx.ones(ys.shape, type_as=ys)
504+
# Tile the per-source "label is present" indicator once per *target* sample so
# that the mask has the same (len(ys), len(yt)) shape as `label_match` below
# (built as ys[:, None] - yt[None, :]). Repeating ys.shape[0] times instead
# yields a square (len(ys), len(ys)) array that only lines up with
# `label_match` when source and target happen to have the same sample count.
missing_labels = nx.repeat(missing_labels[:, None], yt.shape[0], 1)
505+
# zeros where labels match
506+
label_match = ys[:, None] - yt[None, :]
507+
self.cost_ = nx.maximum(self.cost_, nx.abs(label_match) * nx.abs(missing_labels) * self.limit_max)
514508

515509
# distribution estimation
516510
self.mu_s = self.distribution_estimation(Xs)
@@ -581,12 +575,11 @@ class label
581575
if check_params(Xs=Xs):
582576

583577
if nx.array_equal(self.xs_, Xs):
584-
585578
# perform standard barycentric mapping
586579
transp = self.coupling_ / nx.sum(self.coupling_, axis=1)[:, None]
587580

588581
# set nans to 0
589-
transp[~ nx.isfinite(transp)] = 0
582+
transp = nx.nan_to_num(transp, nan=0, posinf=0, neginf=0)
590583

591584
# compute transported samples
592585
transp_Xs = nx.dot(transp, self.xt_)
@@ -604,9 +597,8 @@ class label
604597
idx = nx.argmin(D0, axis=1)
605598

606599
# transport the source samples
607-
transp = self.coupling_ / nx.sum(
608-
self.coupling_, axis=1)[:, None]
609-
transp[~ nx.isfinite(transp)] = 0
600+
transp = self.coupling_ / nx.sum(self.coupling_, axis=1)[:, None]
601+
transp = nx.nan_to_num(transp, nan=0, posinf=0, neginf=0)
610602
transp_Xs_ = nx.dot(transp, self.xt_)
611603

612604
# define the transported points
@@ -645,23 +637,16 @@ def transform_labels(self, ys=None):
645637

646638
# check the necessary inputs parameters are here
647639
if check_params(ys=ys):
648-
649-
ysTemp = label_normalization(nx.copy(ys))
650-
classes = nx.unique(ysTemp)
651-
n = len(classes)
652-
D1 = nx.zeros((n, len(ysTemp)), type_as=self.coupling_)
653-
654640
# perform label propagation
655641
transp = self.coupling_ / nx.sum(self.coupling_, axis=0)[None, :]
656642

657643
# set nans to 0
658-
transp[~ nx.isfinite(transp)] = 0
659-
660-
for c in classes:
661-
D1[int(c), ysTemp == c] = 1
644+
transp = nx.nan_to_num(transp, nan=0, posinf=0, neginf=0)
662645

663646
# compute propagated labels
664-
transp_ys = nx.dot(D1, transp)
647+
labels = label_normalization(ys)
648+
masks = labels_to_masks(labels, nx=nx, type_as=transp)
649+
transp_ys = nx.dot(masks.T, transp)
665650

666651
return transp_ys.T
667652

@@ -697,12 +682,11 @@ class label
697682
if check_params(Xt=Xt):
698683

699684
if nx.array_equal(self.xt_, Xt):
700-
701685
# perform standard barycentric mapping
702686
transp_ = self.coupling_.T / nx.sum(self.coupling_, 0)[:, None]
703687

704688
# set nans to 0
705-
transp_[~ nx.isfinite(transp_)] = 0
689+
transp_ = nx.nan_to_num(transp_, nan=0, posinf=0, neginf=0)
706690

707691
# compute transported samples
708692
transp_Xt = nx.dot(transp_, self.xs_)
@@ -719,9 +703,8 @@ class label
719703
idx = nx.argmin(D0, axis=1)
720704

721705
# transport the target samples
722-
transp_ = self.coupling_.T / nx.sum(
723-
self.coupling_, 0)[:, None]
724-
transp_[~ nx.isfinite(transp_)] = 0
706+
transp_ = self.coupling_.T / nx.sum(self.coupling_, 0)[:, None]
707+
transp_ = nx.nan_to_num(transp_, nan=0, posinf=0, neginf=0)
725708
transp_Xt_ = nx.dot(transp_, self.xs_)
726709

727710
# define the transported points
@@ -750,23 +733,15 @@ def inverse_transform_labels(self, yt=None):
750733

751734
# check the necessary inputs parameters are here
752735
if check_params(yt=yt):
753-
754-
ytTemp = label_normalization(nx.copy(yt))
755-
classes = nx.unique(ytTemp)
756-
n = len(classes)
757-
D1 = nx.zeros((n, len(ytTemp)), type_as=self.coupling_)
758-
759736
# perform label propagation
760737
transp = self.coupling_ / nx.sum(self.coupling_, 1)[:, None]
761-
762738
# set nans to 0
763-
transp[~ nx.isfinite(transp)] = 0
739+
transp = nx.nan_to_num(transp, nan=0, posinf=0, neginf=0)
764740

765-
for c in classes:
766-
D1[int(c), ytTemp == c] = 1
767-
768-
# compute propagated samples
769-
transp_ys = nx.dot(D1, transp.T)
741+
# compute propagated labels
742+
labels = label_normalization(yt)
743+
masks = labels_to_masks(labels, nx=nx, type_as=transp)
744+
transp_ys = nx.dot(masks.T, transp.T)
770745

771746
return transp_ys.T
772747

@@ -2151,7 +2126,7 @@ def transform_labels(self, ys=None):
21512126
type_as=ys[0]
21522127
)
21532128
for i in range(len(ys)):
2154-
ysTemp = label_normalization(nx.copy(ys[i]))
2129+
ysTemp = label_normalization(ys[i])
21552130
classes = nx.unique(ysTemp)
21562131
n = len(classes)
21572132
ns = len(ysTemp)
@@ -2194,7 +2169,7 @@ def inverse_transform_labels(self, yt=None):
21942169
# check the necessary inputs parameters are here
21952170
if check_params(yt=yt):
21962171
transp_ys = []
2197-
ytTemp = label_normalization(nx.copy(yt))
2172+
ytTemp = label_normalization(yt)
21982173
classes = nx.unique(ytTemp)
21992174
n = len(classes)
22002175
D1 = nx.zeros((n, len(ytTemp)), type_as=self.coupling_[0])

ot/utils.py

Lines changed: 35 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -390,7 +390,7 @@ def is_all_finite(*args):
390390
return all(not nx.any(~nx.isfinite(arg)) for arg in args)
391391

392392

393-
def label_normalization(y, start=0):
393+
def label_normalization(y, start=0, nx=None):
394394
r""" Transform labels to start at a given value
395395
396396
Parameters
@@ -399,18 +399,45 @@ def label_normalization(y, start=0):
399399
The vector of labels to be normalized.
400400
start : int
401401
Desired value for the smallest label in :math:`\mathbf{y}` (default=0)
402+
nx : Backend, optional
403+
Backend to perform computations on. If omitted, the backend defaults to that of `y`.
402404
403405
Returns
404406
-------
405407
y : array-like, shape (`n1`, )
406408
The input vector of labels normalized according to given start value.
407409
"""
408-
nx = get_backend(y)
410+
if nx is None:
411+
nx = get_backend(y)
412+
diff = nx.min(y) - start
413+
return y if diff == 0 else (y - diff)
414+
415+
416+
def labels_to_masks(y, type_as=None, nx=None):
417+
r"""Transforms (n_samples,) vector of labels into a (n_samples, n_labels) matrix of masks.
418+
419+
Parameters
420+
----------
421+
y : array-like, shape (n_samples, )
422+
The vector of labels.
423+
type_as : array_like
424+
Array of the same type as the expected output.
425+
nx : Backend, optional
426+
Backend to perform computations on. If omitted, the backend defaults to that of `y`.
409427
410-
diff = nx.min(nx.unique(y)) - start
411-
if diff != 0:
412-
y -= diff
413-
return y
428+
Returns
429+
-------
430+
masks : array-like, shape (n_samples, n_labels)
431+
The (n_samples, n_labels) matrix of label masks.
432+
"""
433+
if nx is None:
434+
nx = get_backend(y)
435+
if type_as is None:
436+
type_as = y
437+
labels_u, labels_idx = nx.unique(y, return_inverse=True)
438+
n_labels = labels_u.shape[0]
439+
masks = nx.eye(n_labels, type_as=type_as)[labels_idx]
440+
return masks
414441

415442

416443
def parmap(f, X, nprocs="default"):
@@ -755,10 +782,8 @@ def _get_backend(self, *arrays):
755782
nx = get_backend(
756783
*[input_ for input_ in arrays if input_ is not None]
757784
)
758-
if nx.__name__ in ("jax", "tf"):
759-
raise TypeError(
760-
"""JAX or TF arrays have been received but domain
761-
adaptation does not support those backend.""")
785+
if nx.__name__ in ("tf",):
786+
raise TypeError("Domain adaptation does not support TF backend.")
762787
self.nx = nx
763788
return nx
764789

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@
5151
description='Python Optimal Transport Library',
5252
long_description=README,
5353
long_description_content_type='text/markdown',
54-
author=u'Remi Flamary, Nicolas Courty',
54+
author=u'Remi Flamary, Nicolas Courty, POT Contributors',
5555
author_email='remi.flamary@gmail.com, ncourty@gmail.com',
5656
url='https://github.com/PythonOT/POT',
5757
packages=find_packages(exclude=["benchmarks"]),

test/test_backend.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -264,6 +264,8 @@ def test_empty_backend():
264264
nx.detach(M)
265265
with pytest.raises(NotImplementedError):
266266
nx.matmul(M, M.T)
267+
with pytest.raises(NotImplementedError):
268+
nx.nan_to_num(M)
267269

268270

269271
def test_func_backends(nx):
@@ -667,6 +669,11 @@ def test_func_backends(nx):
667669
lst_b.append(nx.to_numpy(A))
668670
lst_name.append("matmul broadcast")
669671

672+
vec = nx.from_numpy(np.array([1, np.nan, -1]))
673+
vec = nx.nan_to_num(vec, nan=0)
674+
lst_b.append(nx.to_numpy(vec))
675+
lst_name.append("nan_to_num")
676+
670677
assert not nx.array_equal(Mb, vb), "array_equal (shape)"
671678
assert nx.array_equal(Mb, Mb), "array_equal (elements) - expected true"
672679
assert not nx.array_equal(

0 commit comments

Comments
 (0)