From 4ba2e901e69743a2fa33b256c63b2fb671bf9dae Mon Sep 17 00:00:00 2001 From: Oleksii Kachaiev Date: Sat, 2 Dec 2023 22:51:03 +0100 Subject: [PATCH 01/14] Draft sinkhorn_l1l2_transport to work on JAX --- ot/da.py | 59 +++++++++++++++++++++---------------------------- test/test_da.py | 1 - 2 files changed, 25 insertions(+), 35 deletions(-) diff --git a/ot/da.py b/ot/da.py index 3628db51e..a900e4e8c 100644 --- a/ot/da.py +++ b/ot/da.py @@ -499,18 +499,13 @@ class label if self.limit_max != np.infty: self.limit_max = self.limit_max * nx.max(self.cost_) - # assumes labeled source samples occupy the first rows - # and labeled target samples occupy the first columns - classes = [c for c in nx.unique(ys) if c != -1] - for c in classes: - idx_s = nx.where((ys != c) & (ys != -1)) - idx_t = nx.where(yt == c) - - # all the coefficients corresponding to a source sample - # and a target sample : - # with different labels get a infinite - for j in idx_t[0]: - self.cost_[idx_s[0], j] = self.limit_max + # xxx(okachaiev): add "ones_like"? + missing_labels = ys + nx.ones(ys.shape, type_as=ys) + # xxx(okachaiev): i guess we need better tests for the use case of -1 labels + missing_labels = nx.repeat(missing_labels[:, None], ys.shape[0], 1) + label_match = nx.repeat(ys[:, None], ys.shape[0], 1) - nx.repeat(yt[None, :], yt.shape[0], 0) + # xxx(okachaiev): can we have negative cost? + self.cost_ = nx.maximum(self.cost_, nx.abs(label_match) * nx.abs(missing_labels) * self.limit_max) # distribution estimation self.mu_s = self.distribution_estimation(Xs) @@ -586,6 +581,7 @@ class label transp = self.coupling_ / nx.sum(self.coupling_, axis=1)[:, None] # set nans to 0 + # xxx(okachaiev): replace with nan_to_num (add backend function) transp[~ nx.isfinite(transp)] = 0 # compute transported samples @@ -606,6 +602,7 @@ class label # transport the source samples transp = self.coupling_ / nx.sum( self.coupling_, axis=1)[:, None] + # xxx(okachaiev): replace with nan_to_num (add backend function) transp[~ nx.isfinite(transp)] = 0 transp_Xs_ = nx.dot(transp, self.xt_) @@ -645,26 +642,24 @@ def transform_labels(self, ys=None): # check the necessary inputs parameters are here if check_params(ys=ys): - - ysTemp = label_normalization(nx.copy(ys)) - classes = nx.unique(ysTemp) - n = len(classes) - D1 = nx.zeros((n, len(ysTemp)), type_as=self.coupling_) - # perform label propagation transp = self.coupling_ / nx.sum(self.coupling_, axis=0)[None, :] # set nans to 0 + # xxx(okachaiev): replace with nan_to_nums transp[~ nx.isfinite(transp)] = 0 - for c in classes: - D1[int(c), ysTemp == c] = 1 + ysTemp = label_normalization(nx.copy(ys)) + labels_u, labels_idx = nx.unique(ysTemp, return_inverse=True) + n_labels = labels_u.shape[0] + unroll_labels_idx = nx.eye(n_labels, type_as=transp)[None, labels_idx].squeeze(0) # compute propagated labels - transp_ys = nx.dot(D1, transp) + transp_ys = nx.dot(unroll_labels_idx.T, transp) return transp_ys.T + # xxx(okachaiev): seems like a lot of code duplication def inverse_transform(self, Xs=None, ys=None, Xt=None, yt=None, batch_size=128): r"""Transports target samples :math:`\mathbf{X_t}` onto source samples :math:`\mathbf{X_s}` @@ -697,11 +692,11 @@ class label if check_params(Xt=Xt): if nx.array_equal(self.xt_, Xt): - # perform standard barycentric mapping transp_ = self.coupling_.T / nx.sum(self.coupling_, 0)[:, None] # set nans to 0 + # xxx(okachaiev): replace with nan_to_nums transp_[~ nx.isfinite(transp_)] = 0 # compute transported samples @@ -721,6 +716,7 @@ class label # transport the target samples transp_ = self.coupling_.T / nx.sum( self.coupling_, 0)[:, None] + # xxx(okachaiev): replace with nan_to_nums transp_[~ nx.isfinite(transp_)] = 0 transp_Xt_ = nx.dot(transp_, self.xs_) @@ -750,25 +746,20 @@ def inverse_transform_labels(self, yt=None): # check the necessary inputs parameters are here if check_params(yt=yt): - - ytTemp = label_normalization(nx.copy(yt)) - classes = nx.unique(ytTemp) - n = len(classes) - D1 = nx.zeros((n, len(ytTemp)), type_as=self.coupling_) - # perform label propagation transp = self.coupling_ / nx.sum(self.coupling_, 1)[:, None] - # set nans to 0 transp[~ nx.isfinite(transp)] = 0 - for c in classes: - D1[int(c), ytTemp == c] = 1 + ytTemp = label_normalization(nx.copy(yt)) + # xxx(okachaiev): move this to a helper? + labels_u, labels_idx = nx.unique(ytTemp, return_inverse=True) + n_labels = labels_u.shape[0] + unroll_labels_idx = nx.eye(n_labels, type_as=transp)[None, labels_idx].squeeze(0) # compute propagated samples - transp_ys = nx.dot(D1, transp.T) - - return transp_ys.T + transp_ys = nx.dot(unroll_labels_idx, transp) + return transp_ys class LinearTransport(BaseTransport): diff --git a/test/test_da.py b/test/test_da.py index 8f248c484..4dfbff14d 100644 --- a/test/test_da.py +++ b/test/test_da.py @@ -158,7 +158,6 @@ def test_sinkhorn_lpl1_transport_class(nx): assert mass_semi == 0, "semisupervised mode not working" -@pytest.skip_backend("jax") @pytest.skip_backend("tf") def test_sinkhorn_l1l2_transport_class(nx): """test_sinkhorn_transport From f4e008100f81eb96c435c27696f825c3a4038f83 Mon Sep 17 00:00:00 2001 From: Oleksii Kachaiev Date: Sun, 3 Dec 2023 12:25:23 +0100 Subject: [PATCH 02/14] Move label_to_masks in utils --- ot/da.py | 30 ++++++++++++------------------ ot/utils.py | 27 +++++++++++++++++++++++++++ 2 files changed, 39 insertions(+), 18 deletions(-) diff --git a/ot/da.py b/ot/da.py index a900e4e8c..9d4525ea4 100644 --- a/ot/da.py +++ b/ot/da.py @@ -18,7 +18,7 @@ from .bregman import sinkhorn, jcpot_barycenter from .lp import emd from .utils import unif, dist, kernel, cost_normalization, label_normalization, laplacian, dots -from .utils import list_to_array, check_params, BaseEstimator, deprecated +from .utils import BaseEstimator, check_params, deprecated, labels_to_masks, list_to_array from .unbalanced import sinkhorn_unbalanced from .gaussian import empirical_bures_wasserstein_mapping, empirical_gaussian_gromov_wasserstein_mapping from .optim import cg @@ -499,12 +499,12 @@ class label if self.limit_max != np.infty: self.limit_max = self.limit_max * nx.max(self.cost_) - # xxx(okachaiev): add "ones_like"? + # zeros where source label is missing (masked with -1) missing_labels = ys + nx.ones(ys.shape, type_as=ys) # xxx(okachaiev): i guess we need better tests for the use case of -1 labels missing_labels = nx.repeat(missing_labels[:, None], ys.shape[0], 1) - label_match = nx.repeat(ys[:, None], ys.shape[0], 1) - nx.repeat(yt[None, :], yt.shape[0], 0) - # xxx(okachaiev): can we have negative cost? + # zeros where labels match + label_match = ys[:, None] - yt[None, :] self.cost_ = nx.maximum(self.cost_, nx.abs(label_match) * nx.abs(missing_labels) * self.limit_max) # distribution estimation @@ -649,13 +649,10 @@ def transform_labels(self, ys=None): # xxx(okachaiev): replace with nan_to_nums transp[~ nx.isfinite(transp)] = 0 - ysTemp = label_normalization(nx.copy(ys)) - labels_u, labels_idx = nx.unique(ysTemp, return_inverse=True) - n_labels = labels_u.shape[0] - unroll_labels_idx = nx.eye(n_labels, type_as=transp)[None, labels_idx].squeeze(0) - # compute propagated labels - transp_ys = nx.dot(unroll_labels_idx.T, transp) + labels = label_normalization(nx.copy(ys)) + masks = labels_to_masks(labels, nx=nx, type_as=transp) + transp_ys = nx.dot(masks.T, transp) return transp_ys.T @@ -751,15 +748,12 @@ def inverse_transform_labels(self, yt=None): # set nans to 0 transp[~ nx.isfinite(transp)] = 0 - ytTemp = label_normalization(nx.copy(yt)) - # xxx(okachaiev): move this to a helper? - labels_u, labels_idx = nx.unique(ytTemp, return_inverse=True) - n_labels = labels_u.shape[0] - unroll_labels_idx = nx.eye(n_labels, type_as=transp)[None, labels_idx].squeeze(0) - # compute propagated samples - transp_ys = nx.dot(unroll_labels_idx, transp) - return transp_ys + labels = label_normalization(nx.copy(yt)) + masks = labels_to_masks(labels, nx=nx, type_as=transp).T + transp_ys = nx.dot(masks, transp.T) + + return transp_ys.T class LinearTransport(BaseTransport): diff --git a/ot/utils.py b/ot/utils.py index cb29b21c9..469483256 100644 --- a/ot/utils.py +++ b/ot/utils.py @@ -413,6 +413,33 @@ def label_normalization(y, start=0): return y +def labels_to_masks(y, nx=None, type_as=None): + r"""Transforms (n_samples,) vector of labels into a (n_samples, n_labels) matrix of masks. + + Parameters + ---------- + y : array-like, shape (n_samples, ) + The vector of labels. + nx : Backend, optional + Backend to perform computations on. If omitted, the backend defaults to that of `y`. + type_as : array_like + Array of the same type of the expected output. + + Returns + ------- + masks : array-like, shape (n_samples, n_labels) + The (n_samples, n_labels) matrix of label masks. + """ + if nx is None: + nx = get_backend(y) + if type_as is None: + type_as = y + labels_u, labels_idx = nx.unique(y, return_inverse=True) + n_labels = labels_u.shape[0] + masks = nx.eye(n_labels, type_as=type_as)[None, labels_idx].squeeze(0) + return masks + + def parmap(f, X, nprocs="default"): r""" parallel map for multiprocessing. The function has been deprecated and only performs a regular map. From 86961c0c7198d3f60cea7566537d3c030c93d51d Mon Sep 17 00:00:00 2001 From: Oleksii Kachaiev Date: Sun, 3 Dec 2023 12:45:34 +0100 Subject: [PATCH 03/14] Move nan_to_num to backend --- ot/backend.py | 27 +++++++++++++++++++++++++++ ot/da.py | 29 ++++++++++------------------- 2 files changed, 37 insertions(+), 19 deletions(-) diff --git a/ot/backend.py b/ot/backend.py index 36ea51373..7645c4237 100644 --- a/ot/backend.py +++ b/ot/backend.py @@ -1043,6 +1043,14 @@ def matmul(self, a, b): """ raise NotImplementedError() + def nan_to_num(self, x, copy=True, nan=0.0, posinf=None, neginf=None): + r""" + Replace NaN with zero and infinity with large finite numbers or with the numbers defined by the user. + + See: https://numpy.org/doc/stable/reference/generated/numpy.nan_to_num.html#numpy.nan_to_num + """ + raise NotImplementedError() + class NumpyBackend(Backend): """ @@ -1392,6 +1400,9 @@ def detach(self, *args): def matmul(self, a, b): return np.matmul(a, b) + def nan_to_num(self, x, copy=True, nan=0.0, posinf=None, neginf=None): + return np.nan_to_num(x, copy=copy, nan=nan, posinf=posinf, neginf=neginf) + _register_backend_implementation(NumpyBackend) @@ -1762,6 +1773,9 @@ def detach(self, *args): def matmul(self, a, b): return jnp.matmul(a, b) + def nan_to_num(self, x, copy=True, nan=0.0, posinf=None, neginf=None): + return jnp.nan_to_num(x, copy=copy, nan=nan, posinf=posinf, neginf=neginf) + if jax: # Only register jax backend if it is installed @@ -2250,6 +2264,10 @@ def detach(self, *args): def matmul(self, a, b): return torch.matmul(a, b) + def nan_to_num(self, x, copy=True, nan=0.0, posinf=None, neginf=None): + out = None if copy else x + return torch.nan_to_num(x, nan=nan, posinf=posinf, neginf=neginf, out=out) + if torch: # Only register torch backend if it is installed @@ -2647,6 +2665,9 @@ def detach(self, *args): def matmul(self, a, b): return cp.matmul(a, b) + def nan_to_num(self, x, copy=True, nan=0.0, posinf=None, neginf=None): + return cp.nan_to_num(x, copy=copy, nan=nan, posinf=posinf, neginf=neginf) + if cp: # Only register cp backend if it is installed @@ -3070,6 +3091,12 @@ def detach(self, *args): def matmul(self, a, b): return tnp.matmul(a, b) + # todo(okachaiev): replace this with a more reasonable implementation + def nan_to_num(self, x, copy=True, nan=0.0, posinf=None, neginf=None): + x = self.to_numpy(x) + x = np.nan_to_num(x, copy=copy, nan=nan, posinf=posinf, neginf=neginf) + return self.from_numpy(x) + if tf: # Only register tensorflow backend if it is installed diff --git a/ot/da.py b/ot/da.py index 9d4525ea4..687bec944 100644 --- a/ot/da.py +++ b/ot/da.py @@ -576,13 +576,11 @@ class label if check_params(Xs=Xs): if nx.array_equal(self.xs_, Xs): - # perform standard barycentric mapping transp = self.coupling_ / nx.sum(self.coupling_, axis=1)[:, None] # set nans to 0 - # xxx(okachaiev): replace with nan_to_num (add backend function) - transp[~ nx.isfinite(transp)] = 0 + transp = nx.nan_to_num(transp, nan=0, posinf=0, neginf=0) # compute transported samples transp_Xs = nx.dot(transp, self.xt_) @@ -600,10 +598,8 @@ class label idx = nx.argmin(D0, axis=1) # transport the source samples - transp = self.coupling_ / nx.sum( - self.coupling_, axis=1)[:, None] - # xxx(okachaiev): replace with nan_to_num (add backend function) - transp[~ nx.isfinite(transp)] = 0 + transp = self.coupling_ / nx.sum(self.coupling_, axis=1)[:, None] + transp = nx.nan_to_num(transp, nan=0, posinf=0, neginf=0) transp_Xs_ = nx.dot(transp, self.xt_) # define the transported points @@ -646,8 +642,7 @@ def transform_labels(self, ys=None): transp = self.coupling_ / nx.sum(self.coupling_, axis=0)[None, :] # set nans to 0 - # xxx(okachaiev): replace with nan_to_nums - transp[~ nx.isfinite(transp)] = 0 + transp = nx.nan_to_num(transp, nan=0, posinf=0, neginf=0) # compute propagated labels labels = label_normalization(nx.copy(ys)) @@ -656,7 +651,6 @@ def transform_labels(self, ys=None): return transp_ys.T - # xxx(okachaiev): seems like a lot of code duplication def inverse_transform(self, Xs=None, ys=None, Xt=None, yt=None, batch_size=128): r"""Transports target samples :math:`\mathbf{X_t}` onto source samples :math:`\mathbf{X_s}` @@ -693,8 +687,7 @@ class label transp_ = self.coupling_.T / nx.sum(self.coupling_, 0)[:, None] # set nans to 0 - # xxx(okachaiev): replace with nan_to_nums - transp_[~ nx.isfinite(transp_)] = 0 + transp_ = nx.nan_to_num(transp_, nan=0, posinf=0, neginf=0) # compute transported samples transp_Xt = nx.dot(transp_, self.xs_) @@ -711,10 +704,8 @@ class label idx = nx.argmin(D0, axis=1) # transport the target samples - transp_ = self.coupling_.T / nx.sum( - self.coupling_, 0)[:, None] - # xxx(okachaiev): replace with nan_to_nums - transp_[~ nx.isfinite(transp_)] = 0 + transp_ = self.coupling_.T / nx.sum(self.coupling_, 0)[:, None] + transp_ = nx.nan_to_num(transp_, nan=0, posinf=0, neginf=0) transp_Xt_ = nx.dot(transp_, self.xs_) # define the transported points @@ -746,12 +737,12 @@ def inverse_transform_labels(self, yt=None): # perform label propagation transp = self.coupling_ / nx.sum(self.coupling_, 1)[:, None] # set nans to 0 - transp[~ nx.isfinite(transp)] = 0 + transp = nx.nan_to_num(transp, nan=0, posinf=0, neginf=0) # compute propagated samples labels = label_normalization(nx.copy(yt)) - masks = labels_to_masks(labels, nx=nx, type_as=transp).T - transp_ys = nx.dot(masks, transp.T) + masks = labels_to_masks(labels, nx=nx, type_as=transp) + transp_ys = nx.dot(masks.T, transp.T) return transp_ys.T From 2132f03b40d03f11966642e0f8e2707687968216 Mon Sep 17 00:00:00 2001 From: Oleksii Kachaiev Date: Sun, 3 Dec 2023 13:00:40 +0100 Subject: [PATCH 04/14] Proper test case for semi-supervised DA --- ot/da.py | 1 - test/test_da.py | 13 +++++++------ 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/ot/da.py b/ot/da.py index 687bec944..7bd71c5f4 100644 --- a/ot/da.py +++ b/ot/da.py @@ -501,7 +501,6 @@ class label # zeros where source label is missing (masked with -1) missing_labels = ys + nx.ones(ys.shape, type_as=ys) - # xxx(okachaiev): i guess we need better tests for the use case of -1 labels missing_labels = nx.repeat(missing_labels[:, None], ys.shape[0], 1) # zeros where labels match label_match = ys[:, None] - yt[None, :] diff --git a/test/test_da.py b/test/test_da.py index 4dfbff14d..b5e5a838a 100644 --- a/test/test_da.py +++ b/test/test_da.py @@ -168,8 +168,11 @@ def test_sinkhorn_l1l2_transport_class(nx): Xs, ys = make_data_classif('3gauss', ns, random_state=42) Xt, yt = make_data_classif('3gauss2', nt, random_state=43) + # prepare semi-supervised labels + yt_semi = np.copy(yt) + yt_semi[np.arange(0, nt, 2)] = -1 - Xs, ys, Xt, yt = nx.from_numpy(Xs, ys, Xt, yt) + Xs, ys, Xt, yt, yt_semi = nx.from_numpy(Xs, ys, Xt, yt, yt_semi) otda = ot.da.SinkhornL1l2Transport(max_inner_iter=500) @@ -233,7 +236,7 @@ def test_sinkhorn_l1l2_transport_class(nx): n_unsup = nx.sum(otda_unsup.cost_) otda_semi = ot.da.SinkhornL1l2Transport() - otda_semi.fit(Xs=Xs, ys=ys, Xt=Xt, yt=yt) + otda_semi.fit(Xs=Xs, ys=ys, Xt=Xt, yt=yt_semi) assert_equal(otda_semi.cost_.shape, ((Xs.shape[0], Xt.shape[0]))) n_semisup = nx.sum(otda_semi.cost_) @@ -242,11 +245,9 @@ def test_sinkhorn_l1l2_transport_class(nx): # check that the coupling forbids mass transport between labeled source # and labeled target samples - mass_semi = nx.sum( - otda_semi.coupling_[otda_semi.cost_ == otda_semi.limit_max]) + mass_semi = nx.sum(otda_semi.coupling_[otda_semi.cost_ == otda_semi.limit_max]) mass_semi = otda_semi.coupling_[otda_semi.cost_ == otda_semi.limit_max] - assert_allclose(nx.to_numpy(mass_semi), np.zeros(list(mass_semi.shape)), - rtol=1e-9, atol=1e-9) + assert_allclose(nx.to_numpy(mass_semi), np.zeros_like(mass_semi), rtol=1e-9, atol=1e-9) # check everything runs well with log=True otda = ot.da.SinkhornL1l2Transport(log=True) From f6b5042a7bcfea1ac5ee27abf378368fbefb7f42 Mon Sep 17 00:00:00 2001 From: Oleksii Kachaiev Date: Sun, 3 Dec 2023 13:06:28 +0100 Subject: [PATCH 05/14] Test case for label to mask computation --- ot/utils.py | 2 +- test/test_utils.py | 14 ++++++++++++++ 2 files changed, 15 insertions(+), 1 deletion(-) diff --git a/ot/utils.py b/ot/utils.py index 469483256..865cb1cac 100644 --- a/ot/utils.py +++ b/ot/utils.py @@ -415,7 +415,7 @@ def label_normalization(y, start=0): def labels_to_masks(y, nx=None, type_as=None): r"""Transforms (n_samples,) vector of labels into a (n_samples, n_labels) matrix of masks. - + Parameters ---------- y : array-like, shape (n_samples, ) diff --git a/test/test_utils.py b/test/test_utils.py index 258a1c742..b4154ee13 100644 --- a/test/test_utils.py +++ b/test/test_utils.py @@ -583,3 +583,17 @@ def test_lowrank_LazyTensor(nx): T = ot.utils.get_lowrank_lazytensor(X1, X2, diag_d, nx=nx) np.testing.assert_allclose(nx.to_numpy(T[:]), nx.to_numpy(T0)) + + +def test_label_to_mask_helper(nx): + y = np.array([1, 0, 2, 2, 1]) + out = np.array([ + [0, 1, 0], + [1, 0, 0], + [0, 0, 1], + [0, 0, 1], + [0, 1, 0], + ]) + y = nx.from_numpy(y) + masks = ot.utils.labels_to_masks(y) + np.testing.assert_array_equal(out, masks) From 4f1249580611e21e9d8920da5331402430a8eec5 Mon Sep 17 00:00:00 2001 From: Oleksii Kachaiev Date: Sun, 3 Dec 2023 13:12:59 +0100 Subject: [PATCH 06/14] Simplified axis operations for labels --- ot/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ot/utils.py b/ot/utils.py index 865cb1cac..baf34b49d 100644 --- a/ot/utils.py +++ b/ot/utils.py @@ -436,7 +436,7 @@ def labels_to_masks(y, nx=None, type_as=None): type_as = y labels_u, labels_idx = nx.unique(y, return_inverse=True) n_labels = labels_u.shape[0] - masks = nx.eye(n_labels, type_as=type_as)[None, labels_idx].squeeze(0) + masks = nx.eye(n_labels, type_as=type_as)[labels_idx] return masks From d969b24eb6b759dcc68bde36d52efb24ee2fcd17 Mon Sep 17 00:00:00 2001 From: Oleksii Kachaiev Date: Sun, 3 Dec 2023 13:23:58 +0100 Subject: [PATCH 07/14] Allow JAX backend for BaseEstimator --- ot/utils.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/ot/utils.py b/ot/utils.py index baf34b49d..d712abe57 100644 --- a/ot/utils.py +++ b/ot/utils.py @@ -782,10 +782,8 @@ def _get_backend(self, *arrays): nx = get_backend( *[input_ for input_ in arrays if input_ is not None] ) - if nx.__name__ in ("jax", "tf"): - raise TypeError( - """JAX or TF arrays have been received but domain - adaptation does not support those backend.""") + if nx.__name__ in ("tf",): + raise TypeError("Domain adaptation does not support TF backend.") self.nx = nx return nx From fe8a7f07bf26aa267bce45a4cd6b85dff09a04a7 Mon Sep 17 00:00:00 2001 From: Oleksii Kachaiev Date: Sun, 3 Dec 2023 13:31:16 +0100 Subject: [PATCH 08/14] Label normalization performs copy only when necessary --- ot/da.py | 8 ++++---- ot/utils.py | 20 ++++++++++---------- test/test_utils.py | 13 ++++++++++++- 3 files changed, 26 insertions(+), 15 deletions(-) diff --git a/ot/da.py b/ot/da.py index 7bd71c5f4..ce48b47b1 100644 --- a/ot/da.py +++ b/ot/da.py @@ -644,7 +644,7 @@ def transform_labels(self, ys=None): transp = nx.nan_to_num(transp, nan=0, posinf=0, neginf=0) # compute propagated labels - labels = label_normalization(nx.copy(ys)) + labels = label_normalization(ys) masks = labels_to_masks(labels, nx=nx, type_as=transp) transp_ys = nx.dot(masks.T, transp) @@ -739,7 +739,7 @@ def inverse_transform_labels(self, yt=None): transp = nx.nan_to_num(transp, nan=0, posinf=0, neginf=0) # compute propagated samples - labels = label_normalization(nx.copy(yt)) + labels = label_normalization(yt) masks = labels_to_masks(labels, nx=nx, type_as=transp) transp_ys = nx.dot(masks.T, transp.T) @@ -2126,7 +2126,7 @@ def transform_labels(self, ys=None): type_as=ys[0] ) for i in range(len(ys)): - ysTemp = label_normalization(nx.copy(ys[i])) + ysTemp = label_normalization(ys[i]) classes = nx.unique(ysTemp) n = len(classes) ns = len(ysTemp) @@ -2169,7 +2169,7 @@ def inverse_transform_labels(self, yt=None): # check the necessary inputs parameters are here if check_params(yt=yt): transp_ys = [] - ytTemp = label_normalization(nx.copy(yt)) + ytTemp = label_normalization(yt) classes = nx.unique(ytTemp) n = len(classes) D1 = nx.zeros((n, len(ytTemp)), type_as=self.coupling_[0]) diff --git a/ot/utils.py b/ot/utils.py index d712abe57..9de6e06cd 100644 --- a/ot/utils.py +++ b/ot/utils.py @@ -390,7 +390,7 @@ def is_all_finite(*args): return all(not nx.any(~nx.isfinite(arg)) for arg in args) -def label_normalization(y, start=0): +def label_normalization(y, start=0, nx=None): r""" Transform labels to start at a given value Parameters @@ -399,31 +399,31 @@ def label_normalization(y, start=0): The vector of labels to be normalized. start : int Desired value for the smallest label in :math:`\mathbf{y}` (default=0) + nx : Backend, optional + Backend to perform computations on. If omitted, the backend defaults to that of `y`. Returns ------- y : array-like, shape (`n1`, ) The input vector of labels normalized according to given start value. """ - nx = get_backend(y) - + if nx is None: + nx = get_backend(y) diff = nx.min(nx.unique(y)) - start - if diff != 0: - y -= diff - return y + return y if diff == 0 else (y - diff) -def labels_to_masks(y, nx=None, type_as=None): - r"""Transforms (n_samples,) vector of labels into a (n_samples, n_labels) matrix of masks. +def labels_to_masks(y, type_as=None, nx=None): + r"""Transforms (n_samples,) vector of labels into a (n_samples, n_labels) matrix of masks. Parameters ---------- y : array-like, shape (n_samples, ) The vector of labels. - nx : Backend, optional - Backend to perform computations on. If omitted, the backend defaults to that of `y`. type_as : array_like Array of the same type of the expected output. + nx : Backend, optional + Backend to perform computations on. If omitted, the backend defaults to that of `y`. Returns ------- diff --git a/test/test_utils.py b/test/test_utils.py index b4154ee13..6cdb7ead7 100644 --- a/test/test_utils.py +++ b/test/test_utils.py @@ -585,7 +585,7 @@ def test_lowrank_LazyTensor(nx): np.testing.assert_allclose(nx.to_numpy(T[:]), nx.to_numpy(T0)) -def test_label_to_mask_helper(nx): +def test_labels_to_mask_helper(nx): y = np.array([1, 0, 2, 2, 1]) out = np.array([ [0, 1, 0], @@ -597,3 +597,14 @@ def test_label_to_mask_helper(nx): y = nx.from_numpy(y) masks = ot.utils.labels_to_masks(y) np.testing.assert_array_equal(out, masks) + + +def test_label_normalization(nx): + y = nx.from_numpy(np.arange(5) + 1) + out = np.arange(5) + # labels are shifted + y_normalized = ot.utils.label_normalization(y) + np.testing.assert_array_equal(out, y_normalized) + # labels are shifted but the shift if expected + y_normalized_start = ot.utils.label_normalization(y, start=1) + np.testing.assert_array_equal(y, y_normalized_start) From 02bdecfd6e3b2a1add407173e278cf9f6ee678ab Mon Sep 17 00:00:00 2001 From: Oleksii Kachaiev Date: Sun, 3 Dec 2023 13:31:39 +0100 Subject: [PATCH 09/14] Fix comment regarding label transformation --- ot/da.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ot/da.py b/ot/da.py index ce48b47b1..bb43623c4 100644 --- a/ot/da.py +++ b/ot/da.py @@ -738,7 +738,7 @@ def inverse_transform_labels(self, yt=None): # set nans to 0 transp = nx.nan_to_num(transp, nan=0, posinf=0, neginf=0) - # compute propagated samples + # compute propagated labels labels = label_normalization(yt) masks = labels_to_masks(labels, nx=nx, type_as=transp) transp_ys = nx.dot(masks.T, transp.T) From 61bbb3c2927f75bedb0dc964f04c3820dc8822d8 Mon Sep 17 00:00:00 2001 From: Oleksii Kachaiev Date: Sun, 3 Dec 2023 13:46:28 +0100 Subject: [PATCH 10/14] Update RELEASES --- RELEASES.md | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/RELEASES.md b/RELEASES.md index 3c428c521..07b9998c0 100644 --- a/RELEASES.md +++ b/RELEASES.md @@ -1,5 +1,10 @@ # Releases +## Next Release + +#### New features ++ Domain adaptation method `SinkhornL1l2Transport` now supports JAX backend (PR #587) + ## 0.9.2dev #### New features From 8625b30b0f7bf177d7739e0306dac05749315094 Mon Sep 17 00:00:00 2001 From: Oleksii Kachaiev Date: Sun, 3 Dec 2023 17:09:39 +0100 Subject: [PATCH 11/14] Additional backend tests for nan_to_num --- test/test_backend.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/test/test_backend.py b/test/test_backend.py index 605e30ad8..3bc1e5480 100644 --- a/test/test_backend.py +++ b/test/test_backend.py @@ -264,6 +264,8 @@ def test_empty_backend(): nx.detach(M) with pytest.raises(NotImplementedError): nx.matmul(M, M.T) + with pytest.raises(NotImplementedError): + nx.nan_to_num(M) def test_func_backends(nx): @@ -667,6 +669,11 @@ def test_func_backends(nx): lst_b.append(nx.to_numpy(A)) lst_name.append("matmul broadcast") + vec = nx.from_numpy(np.array([1, np.nan, -1])) + vec = nx.nan_to_num(vec, nan=0) + lst_b.append(nx.to_numpy(vec)) + lst_name.append("nan_to_num") + assert not nx.array_equal(Mb, vb), "array_equal (shape)" assert nx.array_equal(Mb, Mb), "array_equal (elements) - expected true" assert not nx.array_equal( From dfc79f05e07c1e0466b240376fe6d47c94f26c88 Mon Sep 17 00:00:00 2001 From: Oleksii Kachaiev Date: Sun, 3 Dec 2023 17:12:05 +0100 Subject: [PATCH 12/14] min(unique(y)) === min(y) --- ot/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ot/utils.py b/ot/utils.py index 9de6e06cd..19e61f1fe 100644 --- a/ot/utils.py +++ b/ot/utils.py @@ -409,7 +409,7 @@ def label_normalization(y, start=0, nx=None): """ if nx is None: nx = get_backend(y) - diff = nx.min(nx.unique(y)) - start + diff = nx.min(y) - start return y if diff == 0 else (y - diff) From 308a5b4790418808dd4f0d479429f6b276f6e64e Mon Sep 17 00:00:00 2001 From: Oleksii Kachaiev Date: Thu, 14 Dec 2023 09:04:46 +0100 Subject: [PATCH 13/14] Avoid catching all warnings as JAX throws deprecation --- test/test_da.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/test/test_da.py b/test/test_da.py index b5e5a838a..a7364e70b 100644 --- a/test/test_da.py +++ b/test/test_da.py @@ -175,11 +175,9 @@ def test_sinkhorn_l1l2_transport_class(nx): Xs, ys, Xt, yt, yt_semi = nx.from_numpy(Xs, ys, Xt, yt, yt_semi) otda = ot.da.SinkhornL1l2Transport(max_inner_iter=500) + otda.fit(Xs=Xs, ys=ys, Xt=Xt) # test its computed - with warnings.catch_warnings(): - warnings.simplefilter("error") - otda.fit(Xs=Xs, ys=ys, Xt=Xt) assert hasattr(otda, "cost_") assert hasattr(otda, "coupling_") assert hasattr(otda, "log_") From 7c168747767e9d14c6e6c9e74bba682867bfffb0 Mon Sep 17 00:00:00 2001 From: Oleksii Kachaiev Date: Thu, 14 Dec 2023 09:05:22 +0100 Subject: [PATCH 14/14] No need to import warnings module --- test/test_da.py | 1 - 1 file changed, 1 deletion(-) diff --git a/test/test_da.py b/test/test_da.py index a7364e70b..0ef5db79e 100644 --- a/test/test_da.py +++ b/test/test_da.py @@ -7,7 +7,6 @@ import numpy as np from numpy.testing import assert_allclose, assert_equal import pytest -import warnings import ot from ot.datasets import make_data_classif