diff --git a/RELEASES.md b/RELEASES.md
index b21e5b0dc..cbb529852 100644
--- a/RELEASES.md
+++ b/RELEASES.md
@@ -1,5 +1,10 @@
 # Releases
 
+## Next Release
+
+#### New features
++ Domain adaptation method `SinkhornL1l2Transport` now supports JAX backend (PR #587)
+
 ## 0.9.2dev
 
 #### New features
diff --git a/ot/backend.py b/ot/backend.py
index 36ea51373..7645c4237 100644
--- a/ot/backend.py
+++ b/ot/backend.py
@@ -1043,6 +1043,14 @@ def matmul(self, a, b):
         """
         raise NotImplementedError()
 
+    def nan_to_num(self, x, copy=True, nan=0.0, posinf=None, neginf=None):
+        r"""
+        Replace NaN with zero and infinity with large finite numbers or with the numbers defined by the user.
+
+        See: https://numpy.org/doc/stable/reference/generated/numpy.nan_to_num.html#numpy.nan_to_num
+        """
+        raise NotImplementedError()
+
 
 class NumpyBackend(Backend):
     """
@@ -1392,6 +1400,9 @@ def detach(self, *args):
     def matmul(self, a, b):
         return np.matmul(a, b)
 
+    def nan_to_num(self, x, copy=True, nan=0.0, posinf=None, neginf=None):
+        return np.nan_to_num(x, copy=copy, nan=nan, posinf=posinf, neginf=neginf)
+
 
 _register_backend_implementation(NumpyBackend)
 
@@ -1762,6 +1773,9 @@ def detach(self, *args):
     def matmul(self, a, b):
         return jnp.matmul(a, b)
 
+    def nan_to_num(self, x, copy=True, nan=0.0, posinf=None, neginf=None):
+        return jnp.nan_to_num(x, copy=copy, nan=nan, posinf=posinf, neginf=neginf)
+
 
 if jax:
     # Only register jax backend if it is installed
@@ -2250,6 +2264,10 @@ def detach(self, *args):
     def matmul(self, a, b):
         return torch.matmul(a, b)
 
+    def nan_to_num(self, x, copy=True, nan=0.0, posinf=None, neginf=None):
+        out = None if copy else x
+        return torch.nan_to_num(x, nan=nan, posinf=posinf, neginf=neginf, out=out)
+
 
 if torch:
     # Only register torch backend if it is installed
@@ -2647,6 +2665,9 @@ def detach(self, *args):
     def matmul(self, a, b):
         return cp.matmul(a, b)
 
+    def nan_to_num(self, x, copy=True, nan=0.0, posinf=None, neginf=None):
+        return cp.nan_to_num(x, copy=copy, nan=nan, posinf=posinf, neginf=neginf)
+
 
 if cp:
     # Only register cp backend if it is installed
@@ -3070,6 +3091,12 @@ def detach(self, *args):
     def matmul(self, a, b):
         return tnp.matmul(a, b)
 
+    # todo(okachaiev): replace this with a more reasonable implementation
+    def nan_to_num(self, x, copy=True, nan=0.0, posinf=None, neginf=None):
+        x = self.to_numpy(x)
+        x = np.nan_to_num(x, copy=copy, nan=nan, posinf=posinf, neginf=neginf)
+        return self.from_numpy(x)
+
 
 if tf:
     # Only register tensorflow backend if it is installed
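For reference, a minimal sketch of how the new `nan_to_num` backend method behaves once this patch is applied. `ot.backend.get_backend` dispatches to the implementation registered for the input array; the NumPy backend is shown because it is always available:

```python
import numpy as np

from ot.backend import get_backend

a = np.array([1.0, np.nan, np.inf])
nx = get_backend(a)

# backend-agnostic, out-of-place replacement for the old in-place pattern
# `transp[~nx.isfinite(transp)] = 0`, which immutable JAX arrays cannot support
print(nx.nan_to_num(a, nan=0.0, posinf=0.0, neginf=0.0))  # [1. 0. 0.]
```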
diff --git a/ot/da.py b/ot/da.py
index 3628db51e..bb43623c4 100644
--- a/ot/da.py
+++ b/ot/da.py
@@ -18,7 +18,7 @@
 from .bregman import sinkhorn, jcpot_barycenter
 from .lp import emd
 from .utils import unif, dist, kernel, cost_normalization, label_normalization, laplacian, dots
-from .utils import list_to_array, check_params, BaseEstimator, deprecated
+from .utils import BaseEstimator, check_params, deprecated, labels_to_masks, list_to_array
 from .unbalanced import sinkhorn_unbalanced
 from .gaussian import empirical_bures_wasserstein_mapping, empirical_gaussian_gromov_wasserstein_mapping
 from .optim import cg
@@ -499,18 +499,12 @@ class label
             if self.limit_max != np.infty:
                 self.limit_max = self.limit_max * nx.max(self.cost_)
 
-            # assumes labeled source samples occupy the first rows
-            # and labeled target samples occupy the first columns
-            classes = [c for c in nx.unique(ys) if c != -1]
-            for c in classes:
-                idx_s = nx.where((ys != c) & (ys != -1))
-                idx_t = nx.where(yt == c)
-
-                # all the coefficients corresponding to a source sample
-                # and a target sample :
-                # with different labels get a infinite
-                for j in idx_t[0]:
-                    self.cost_[idx_s[0], j] = self.limit_max
+            # forbid transport between a labeled source sample and a labeled
+            # target sample when their labels differ; samples with a missing
+            # label (masked with -1) stay unconstrained
+            label_mismatch = (ys[:, None] != yt[None, :]) & (ys[:, None] != -1) & (yt[None, :] != -1)
+            limit_cost = self.limit_max + nx.zeros(self.cost_.shape, type_as=self.cost_)
+            self.cost_ = nx.where(label_mismatch, limit_cost, self.cost_)
 
             # distribution estimation
             self.mu_s = self.distribution_estimation(Xs)
@@ -581,12 +575,11 @@ class label
         if check_params(Xs=Xs):
 
             if nx.array_equal(self.xs_, Xs):
-
                 # perform standard barycentric mapping
                 transp = self.coupling_ / nx.sum(self.coupling_, axis=1)[:, None]
 
                 # set nans to 0
-                transp[~ nx.isfinite(transp)] = 0
+                transp = nx.nan_to_num(transp, nan=0, posinf=0, neginf=0)
 
                 # compute transported samples
                 transp_Xs = nx.dot(transp, self.xt_)
@@ -604,9 +597,8 @@ class label
                 idx = nx.argmin(D0, axis=1)
 
                 # transport the source samples
-                transp = self.coupling_ / nx.sum(
-                    self.coupling_, axis=1)[:, None]
-                transp[~ nx.isfinite(transp)] = 0
+                transp = self.coupling_ / nx.sum(self.coupling_, axis=1)[:, None]
+                transp = nx.nan_to_num(transp, nan=0, posinf=0, neginf=0)
                 transp_Xs_ = nx.dot(transp, self.xt_)
 
                 # define the transported points
@@ -645,23 +637,16 @@ def transform_labels(self, ys=None):
 
         # check the necessary inputs parameters are here
         if check_params(ys=ys):
-
-            ysTemp = label_normalization(nx.copy(ys))
-            classes = nx.unique(ysTemp)
-            n = len(classes)
-            D1 = nx.zeros((n, len(ysTemp)), type_as=self.coupling_)
-
             # perform label propagation
             transp = self.coupling_ / nx.sum(self.coupling_, axis=0)[None, :]
 
             # set nans to 0
-            transp[~ nx.isfinite(transp)] = 0
-
-            for c in classes:
-                D1[int(c), ysTemp == c] = 1
+            transp = nx.nan_to_num(transp, nan=0, posinf=0, neginf=0)
 
             # compute propagated labels
-            transp_ys = nx.dot(D1, transp)
+            labels = label_normalization(ys)
+            masks = labels_to_masks(labels, nx=nx, type_as=transp)
+            transp_ys = nx.dot(masks.T, transp)
 
             return transp_ys.T
 
@@ -697,12 +682,11 @@ class label
         if check_params(Xt=Xt):
 
             if nx.array_equal(self.xt_, Xt):
-
                 # perform standard barycentric mapping
                 transp_ = self.coupling_.T / nx.sum(self.coupling_, 0)[:, None]
 
                 # set nans to 0
-                transp_[~ nx.isfinite(transp_)] = 0
+                transp_ = nx.nan_to_num(transp_, nan=0, posinf=0, neginf=0)
 
                 # compute transported samples
                 transp_Xt = nx.dot(transp_, self.xs_)
@@ -719,9 +703,8 @@ class label
                 idx = nx.argmin(D0, axis=1)
 
                 # transport the target samples
-                transp_ = self.coupling_.T / nx.sum(
-                    self.coupling_, 0)[:, None]
-                transp_[~ nx.isfinite(transp_)] = 0
+                transp_ = self.coupling_.T / nx.sum(self.coupling_, 0)[:, None]
+                transp_ = nx.nan_to_num(transp_, nan=0, posinf=0, neginf=0)
                 transp_Xt_ = nx.dot(transp_, self.xs_)
 
                 # define the transported points
@@ -750,23 +733,15 @@ def inverse_transform_labels(self, yt=None):
 
         # check the necessary inputs parameters are here
         if check_params(yt=yt):
-
-            ytTemp = label_normalization(nx.copy(yt))
-            classes = nx.unique(ytTemp)
-            n = len(classes)
-            D1 = nx.zeros((n, len(ytTemp)), type_as=self.coupling_)
-
             # perform label propagation
             transp = self.coupling_ / nx.sum(self.coupling_, 1)[:, None]
-
             # set nans to 0
-            transp[~ nx.isfinite(transp)] = 0
+            transp = nx.nan_to_num(transp, nan=0, posinf=0, neginf=0)
 
-            for c in classes:
-                D1[int(c), ytTemp == c] = 1
-
-            # compute propagated samples
-            transp_ys = nx.dot(D1, transp.T)
+            # compute propagated labels
+            labels = label_normalization(yt)
+            masks = labels_to_masks(labels, nx=nx, type_as=transp)
+            transp_ys = nx.dot(masks.T, transp.T)
 
             return transp_ys.T
 
@@ -2151,7 +2126,7 @@ def transform_labels(self, ys=None):
                 type_as=ys[0]
             )
             for i in range(len(ys)):
-                ysTemp = label_normalization(nx.copy(ys[i]))
+                ysTemp = label_normalization(ys[i])
                 classes = nx.unique(ysTemp)
                 n = len(classes)
                 ns = len(ysTemp)
@@ -2194,7 +2169,7 @@ def inverse_transform_labels(self, yt=None):
         # check the necessary inputs parameters are here
         if check_params(yt=yt):
             transp_ys = []
-            ytTemp = label_normalization(nx.copy(yt))
+            ytTemp = label_normalization(yt)
             classes = nx.unique(ytTemp)
             n = len(classes)
             D1 = nx.zeros((n, len(ytTemp)), type_as=self.coupling_[0])
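A small NumPy-only check of the label masking performed in `BaseTransport.fit` above, with toy labels chosen purely for illustration: pairs of labeled samples with different labels are pushed to `limit_max`, while samples whose label is missing (encoded as `-1`) keep their original cost:

```python
import numpy as np

ys = np.array([0, 1, -1])   # source labels, last one unlabeled
yt = np.array([1, -1])      # target labels, last one unlabeled
cost = np.zeros((3, 2))     # toy cost matrix (ns=3, nt=2)
limit_max = 10.0

mismatch = (ys[:, None] != yt[None, :]) & (ys[:, None] != -1) & (yt[None, :] != -1)
cost = np.where(mismatch, limit_max, cost)
print(cost)
# [[10.  0.]   label 0 -> label 1 is forbidden; unlabeled target stays free
#  [ 0.  0.]   label 1 -> label 1 matches, so the cost is untouched
#  [ 0.  0.]]  unlabeled source stays unconstrained
```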
diff --git a/ot/utils.py b/ot/utils.py
index cb29b21c9..19e61f1fe 100644
--- a/ot/utils.py
+++ b/ot/utils.py
@@ -390,7 +390,7 @@ def is_all_finite(*args):
     return all(not nx.any(~nx.isfinite(arg)) for arg in args)
 
 
-def label_normalization(y, start=0):
+def label_normalization(y, start=0, nx=None):
     r"""Transform labels to start at a given value
 
     Parameters
@@ -399,18 +399,45 @@ def label_normalization(y, start=0):
         The vector of labels to be normalized.
     start : int
         Desired value for the smallest label in :math:`\mathbf{y}` (default=0)
+    nx : Backend, optional
+        Backend to perform computations on. If omitted, the backend defaults to that of `y`.
 
     Returns
     -------
     y : array-like, shape (`n1`, )
         The input vector of labels normalized according to given start value.
     """
-    nx = get_backend(y)
+    if nx is None:
+        nx = get_backend(y)
+    diff = nx.min(y) - start
+    return y if diff == 0 else (y - diff)
+
+
+def labels_to_masks(y, type_as=None, nx=None):
+    r"""Transforms a (n_samples,) vector of labels into a (n_samples, n_labels) matrix of one-hot masks.
+
+    Parameters
+    ----------
+    y : array-like, shape (n_samples, )
+        The vector of labels.
+    type_as : array_like
+        Array of the same type as the expected output.
+    nx : Backend, optional
+        Backend to perform computations on. If omitted, the backend defaults to that of `y`.
 
-    diff = nx.min(nx.unique(y)) - start
-    if diff != 0:
-        y -= diff
-    return y
+    Returns
+    -------
+    masks : array-like, shape (n_samples, n_labels)
+        The (n_samples, n_labels) matrix of label masks.
+    """
+    if nx is None:
+        nx = get_backend(y)
+    if type_as is None:
+        type_as = y
+    labels_u, labels_idx = nx.unique(y, return_inverse=True)
+    n_labels = labels_u.shape[0]
+    masks = nx.eye(n_labels, type_as=type_as)[labels_idx]
+    return masks
 
 
 def parmap(f, X, nprocs="default"):
@@ -755,10 +782,8 @@ def _get_backend(self, *arrays):
         nx = get_backend(
             *[input_ for input_ in arrays if input_ is not None]
         )
-        if nx.__name__ in ("jax", "tf"):
-            raise TypeError(
-                """JAX or TF arrays have been received but domain
-                adaptation does not support those backend.""")
+        if nx.__name__ in ("tf",):
+            raise TypeError("Domain adaptation does not support TF backend.")
        self.nx = nx
        return nx
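Since the backend guard above now lets JAX arrays through, a hypothetical smoke test of the headline change (assumes `jax` is installed; the dataset sizes are arbitrary):

```python
import numpy as np
import jax.numpy as jnp

import ot

rng = np.random.RandomState(42)
Xs = jnp.asarray(rng.randn(20, 2))
ys = jnp.asarray(rng.randint(0, 3, 20))
Xt = jnp.asarray(rng.randn(20, 2))

otda = ot.da.SinkhornL1l2Transport()
otda.fit(Xs=Xs, ys=ys, Xt=Xt)   # used to raise TypeError for the jax backend
print(type(otda.coupling_))     # a jax array, same backend as the inputs
```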
diff --git a/test/test_backend.py b/test/test_backend.py
index 605e30ad8..3bc1e5480 100644
--- a/test/test_backend.py
+++ b/test/test_backend.py
@@ -264,6 +264,8 @@ def test_empty_backend():
         nx.detach(M)
     with pytest.raises(NotImplementedError):
         nx.matmul(M, M.T)
+    with pytest.raises(NotImplementedError):
+        nx.nan_to_num(M)
 
 
 def test_func_backends(nx):
@@ -667,6 +669,11 @@ def test_func_backends(nx):
         lst_b.append(nx.to_numpy(A))
         lst_name.append("matmul broadcast")
 
+        vec = nx.from_numpy(np.array([1, np.nan, -1]))
+        vec = nx.nan_to_num(vec, nan=0)
+        lst_b.append(nx.to_numpy(vec))
+        lst_name.append("nan_to_num")
+
         assert not nx.array_equal(Mb, vb), "array_equal (shape)"
         assert nx.array_equal(Mb, Mb), "array_equal (elements) - expected true"
         assert not nx.array_equal(
diff --git a/test/test_da.py b/test/test_da.py
index 8f248c484..0ef5db79e 100644
--- a/test/test_da.py
+++ b/test/test_da.py
@@ -7,7 +7,6 @@
 import numpy as np
 from numpy.testing import assert_allclose, assert_equal
 import pytest
-import warnings
 
 import ot
 from ot.datasets import make_data_classif
@@ -158,7 +157,6 @@ def test_sinkhorn_lpl1_transport_class(nx):
     assert mass_semi == 0, "semisupervised mode not working"
 
 
-@pytest.skip_backend("jax")
 @pytest.skip_backend("tf")
 def test_sinkhorn_l1l2_transport_class(nx):
     """test_sinkhorn_transport
@@ -169,15 +167,16 @@ def test_sinkhorn_l1l2_transport_class(nx):
 
     Xs, ys = make_data_classif('3gauss', ns, random_state=42)
     Xt, yt = make_data_classif('3gauss2', nt, random_state=43)
+    # prepare semi-supervised labels
+    yt_semi = np.copy(yt)
+    yt_semi[np.arange(0, nt, 2)] = -1
 
-    Xs, ys, Xt, yt = nx.from_numpy(Xs, ys, Xt, yt)
+    Xs, ys, Xt, yt, yt_semi = nx.from_numpy(Xs, ys, Xt, yt, yt_semi)
 
     otda = ot.da.SinkhornL1l2Transport(max_inner_iter=500)
+    otda.fit(Xs=Xs, ys=ys, Xt=Xt)
 
     # test its computed
-    with warnings.catch_warnings():
-        warnings.simplefilter("error")
-        otda.fit(Xs=Xs, ys=ys, Xt=Xt)
     assert hasattr(otda, "cost_")
     assert hasattr(otda, "coupling_")
     assert hasattr(otda, "log_")
@@ -234,7 +233,7 @@ def test_sinkhorn_l1l2_transport_class(nx):
     n_unsup = nx.sum(otda_unsup.cost_)
 
     otda_semi = ot.da.SinkhornL1l2Transport()
-    otda_semi.fit(Xs=Xs, ys=ys, Xt=Xt, yt=yt)
+    otda_semi.fit(Xs=Xs, ys=ys, Xt=Xt, yt=yt_semi)
     assert_equal(otda_semi.cost_.shape, ((Xs.shape[0], Xt.shape[0])))
     n_semisup = nx.sum(otda_semi.cost_)
 
@@ -243,11 +242,9 @@ def test_sinkhorn_l1l2_transport_class(nx):
 
     # check that the coupling forbids mass transport between labeled source
     # and labeled target samples
-    mass_semi = nx.sum(
-        otda_semi.coupling_[otda_semi.cost_ == otda_semi.limit_max])
+    mass_semi = nx.sum(otda_semi.coupling_[otda_semi.cost_ == otda_semi.limit_max])
     mass_semi = otda_semi.coupling_[otda_semi.cost_ == otda_semi.limit_max]
-    assert_allclose(nx.to_numpy(mass_semi), np.zeros(list(mass_semi.shape)),
-                    rtol=1e-9, atol=1e-9)
+    assert_allclose(nx.to_numpy(mass_semi), np.zeros(mass_semi.shape), rtol=1e-9, atol=1e-9)
 
     # check everything runs well with log=True
     otda = ot.da.SinkhornL1l2Transport(log=True)
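To make the `transform_labels` rewrite in `ot/da.py` above concrete, a NumPy-only sketch of label propagation through a one-hot mask matrix (toy coupling values; `masks.T @ transp` is the product that replaces the old per-class `D1` loop):

```python
import numpy as np

labels = np.array([0, 1, 2, 1])      # normalized source labels (ns=4)
coupling = np.full((4, 3), 1 / 12)   # toy coupling matrix (ns=4, nt=3)

# normalize each column of the coupling, as in transform_labels
transp = coupling / coupling.sum(axis=0, keepdims=True)

masks = np.eye(3)[labels]            # (n_samples, n_labels) one-hot masks
transp_ys = (masks.T @ transp).T     # (nt, n_labels) label scores per target sample
print(transp_ys)
```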
diff --git a/test/test_utils.py b/test/test_utils.py
index 258a1c742..6cdb7ead7 100644
--- a/test/test_utils.py
+++ b/test/test_utils.py
@@ -583,3 +583,28 @@ def test_lowrank_LazyTensor(nx):
 
     T = ot.utils.get_lowrank_lazytensor(X1, X2, diag_d, nx=nx)
     np.testing.assert_allclose(nx.to_numpy(T[:]), nx.to_numpy(T0))
+
+
+def test_labels_to_mask_helper(nx):
+    y = np.array([1, 0, 2, 2, 1])
+    out = np.array([
+        [0, 1, 0],
+        [1, 0, 0],
+        [0, 0, 1],
+        [0, 0, 1],
+        [0, 1, 0],
+    ])
+    y = nx.from_numpy(y)
+    masks = ot.utils.labels_to_masks(y)
+    np.testing.assert_array_equal(out, nx.to_numpy(masks))
+
+
+def test_label_normalization(nx):
+    y = nx.from_numpy(np.arange(5) + 1)
+    out = np.arange(5)
+    # labels are shifted
+    y_normalized = ot.utils.label_normalization(y)
+    np.testing.assert_array_equal(out, nx.to_numpy(y_normalized))
+    # labels are shifted but the shift is expected
+    y_normalized_start = ot.utils.label_normalization(y, start=1)
+    np.testing.assert_array_equal(nx.to_numpy(y), nx.to_numpy(y_normalized_start))
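Finally, the indexing trick behind the new `ot.utils.labels_to_masks` helper, shown standalone: indexing an identity matrix with the `return_inverse` indices of `unique` yields one one-hot row per sample, matching the expected output in `test_labels_to_mask_helper` above:

```python
import numpy as np

y = np.array([1, 0, 2, 2, 1])
labels_u, labels_idx = np.unique(y, return_inverse=True)
masks = np.eye(len(labels_u))[labels_idx]
print(masks)
# [[0. 1. 0.]
#  [1. 0. 0.]
#  [0. 0. 1.]
#  [0. 0. 1.]
#  [0. 1. 0.]]
```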