Skip to content

[DA] Sinkhorn L1L2 transport to work on JAX #587

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 18 commits into from
Dec 22, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions RELEASES.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,10 @@
# Releases

## Next Release

#### New features
+ Domain adaptation method `SinkhornL1l2Transport` now supports JAX backend (PR #587)

## 0.9.2dev

#### New features
Expand Down
27 changes: 27 additions & 0 deletions ot/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -1043,6 +1043,14 @@ def matmul(self, a, b):
"""
raise NotImplementedError()

def nan_to_num(self, x, copy=True, nan=0.0, posinf=None, neginf=None):
    r"""
    Replace NaN with zero and infinity with large finite numbers or with the numbers defined by the user.

    This base version is a placeholder that always raises.

    Parameters
    ----------
    x : array-like
        Input data.
    copy : bool, optional
        Whether to work on a copy of `x` (default) or to replace values in place.
    nan : float, optional
        Value used to replace NaN entries (default=0.0).
    posinf : float, optional
        Value used to replace positive infinity entries.
    neginf : float, optional
        Value used to replace negative infinity entries.

    Raises
    ------
    NotImplementedError
        Always.

    See: https://numpy.org/doc/stable/reference/generated/numpy.nan_to_num.html#numpy.nan_to_num
    """
    raise NotImplementedError()


class NumpyBackend(Backend):
"""
Expand Down Expand Up @@ -1392,6 +1400,9 @@ def detach(self, *args):
def matmul(self, a, b):
return np.matmul(a, b)

def nan_to_num(self, x, copy=True, nan=0.0, posinf=None, neginf=None):
    r"""
    Replace NaN and infinite entries of `x` using :func:`numpy.nan_to_num`.

    All arguments are forwarded unchanged to NumPy.

    See: https://numpy.org/doc/stable/reference/generated/numpy.nan_to_num.html#numpy.nan_to_num
    """
    return np.nan_to_num(x, copy=copy, nan=nan, posinf=posinf, neginf=neginf)


_register_backend_implementation(NumpyBackend)

Expand Down Expand Up @@ -1762,6 +1773,9 @@ def detach(self, *args):
def matmul(self, a, b):
return jnp.matmul(a, b)

def nan_to_num(self, x, copy=True, nan=0.0, posinf=None, neginf=None):
    r"""
    Replace NaN and infinite entries of `x` using :func:`jax.numpy.nan_to_num`.

    All arguments are forwarded unchanged to JAX.

    See: https://numpy.org/doc/stable/reference/generated/numpy.nan_to_num.html#numpy.nan_to_num
    """
    return jnp.nan_to_num(x, copy=copy, nan=nan, posinf=posinf, neginf=neginf)


if jax:
# Only register jax backend if it is installed
Expand Down Expand Up @@ -2250,6 +2264,10 @@ def detach(self, *args):
def matmul(self, a, b):
return torch.matmul(a, b)

def nan_to_num(self, x, copy=True, nan=0.0, posinf=None, neginf=None):
    r"""
    Replace NaN and infinite entries of `x` using :func:`torch.nan_to_num`.

    Parameters
    ----------
    x : torch.Tensor
        Input tensor.
    copy : bool, optional
        If True (default) the result is a new tensor; otherwise `x` is
        passed as the output tensor so the replacement is written into it.
    nan, posinf, neginf : float, optional
        Replacement values, forwarded unchanged to torch.

    See: https://numpy.org/doc/stable/reference/generated/numpy.nan_to_num.html#numpy.nan_to_num
    """
    # torch has no `copy` flag: the in-place variant is obtained by
    # writing the result into `x` through the `out` argument
    out = None if copy else x
    return torch.nan_to_num(x, nan=nan, posinf=posinf, neginf=neginf, out=out)


if torch:
# Only register torch backend if it is installed
Expand Down Expand Up @@ -2647,6 +2665,9 @@ def detach(self, *args):
def matmul(self, a, b):
return cp.matmul(a, b)

def nan_to_num(self, x, copy=True, nan=0.0, posinf=None, neginf=None):
    r"""
    Replace NaN and infinite entries of `x` using ``cp.nan_to_num``.

    All arguments are forwarded unchanged to the ``cp`` module.

    See: https://numpy.org/doc/stable/reference/generated/numpy.nan_to_num.html#numpy.nan_to_num
    """
    return cp.nan_to_num(x, copy=copy, nan=nan, posinf=posinf, neginf=neginf)


if cp:
# Only register cp backend if it is installed
Expand Down Expand Up @@ -3070,6 +3091,12 @@ def detach(self, *args):
def matmul(self, a, b):
return tnp.matmul(a, b)

# todo(okachaiev): replace this with a more reasonable implementation
def nan_to_num(self, x, copy=True, nan=0.0, posinf=None, neginf=None):
    r"""
    Replace NaN and infinite entries of `x` through a NumPy round trip.

    The input is converted with ``self.to_numpy``, cleaned with
    :func:`numpy.nan_to_num`, and converted back with ``self.from_numpy``.

    See: https://numpy.org/doc/stable/reference/generated/numpy.nan_to_num.html#numpy.nan_to_num
    """
    # `copy` is applied to the intermediate NumPy array, not to `x` itself
    x = self.to_numpy(x)
    x = np.nan_to_num(x, copy=copy, nan=nan, posinf=posinf, neginf=neginf)
    return self.from_numpy(x)


if tf:
# Only register tensorflow backend if it is installed
Expand Down
73 changes: 24 additions & 49 deletions ot/da.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from .bregman import sinkhorn, jcpot_barycenter
from .lp import emd
from .utils import unif, dist, kernel, cost_normalization, label_normalization, laplacian, dots
from .utils import list_to_array, check_params, BaseEstimator, deprecated
from .utils import BaseEstimator, check_params, deprecated, labels_to_masks, list_to_array
from .unbalanced import sinkhorn_unbalanced
from .gaussian import empirical_bures_wasserstein_mapping, empirical_gaussian_gromov_wasserstein_mapping
from .optim import cg
Expand Down Expand Up @@ -499,18 +499,12 @@ class label
if self.limit_max != np.infty:
self.limit_max = self.limit_max * nx.max(self.cost_)

# assumes labeled source samples occupy the first rows
# and labeled target samples occupy the first columns
classes = [c for c in nx.unique(ys) if c != -1]
for c in classes:
idx_s = nx.where((ys != c) & (ys != -1))
idx_t = nx.where(yt == c)

# all the coefficients corresponding to a source sample
# and a target sample :
# with different labels get an infinite cost (limit_max)
for j in idx_t[0]:
self.cost_[idx_s[0], j] = self.limit_max
# zeros where source label is missing (masked with -1)
missing_labels = ys + nx.ones(ys.shape, type_as=ys)
missing_labels = nx.repeat(missing_labels[:, None], ys.shape[0], 1)
# zeros where labels match
label_match = ys[:, None] - yt[None, :]
self.cost_ = nx.maximum(self.cost_, nx.abs(label_match) * nx.abs(missing_labels) * self.limit_max)

# distribution estimation
self.mu_s = self.distribution_estimation(Xs)
Expand Down Expand Up @@ -581,12 +575,11 @@ class label
if check_params(Xs=Xs):

if nx.array_equal(self.xs_, Xs):

# perform standard barycentric mapping
transp = self.coupling_ / nx.sum(self.coupling_, axis=1)[:, None]

# set nans to 0
transp[~ nx.isfinite(transp)] = 0
transp = nx.nan_to_num(transp, nan=0, posinf=0, neginf=0)

# compute transported samples
transp_Xs = nx.dot(transp, self.xt_)
Expand All @@ -604,9 +597,8 @@ class label
idx = nx.argmin(D0, axis=1)

# transport the source samples
transp = self.coupling_ / nx.sum(
self.coupling_, axis=1)[:, None]
transp[~ nx.isfinite(transp)] = 0
transp = self.coupling_ / nx.sum(self.coupling_, axis=1)[:, None]
transp = nx.nan_to_num(transp, nan=0, posinf=0, neginf=0)
transp_Xs_ = nx.dot(transp, self.xt_)

# define the transported points
Expand Down Expand Up @@ -645,23 +637,16 @@ def transform_labels(self, ys=None):

# check the necessary inputs parameters are here
if check_params(ys=ys):

ysTemp = label_normalization(nx.copy(ys))
classes = nx.unique(ysTemp)
n = len(classes)
D1 = nx.zeros((n, len(ysTemp)), type_as=self.coupling_)

# perform label propagation
transp = self.coupling_ / nx.sum(self.coupling_, axis=0)[None, :]

# set nans to 0
transp[~ nx.isfinite(transp)] = 0

for c in classes:
D1[int(c), ysTemp == c] = 1
transp = nx.nan_to_num(transp, nan=0, posinf=0, neginf=0)

# compute propagated labels
transp_ys = nx.dot(D1, transp)
labels = label_normalization(ys)
masks = labels_to_masks(labels, nx=nx, type_as=transp)
transp_ys = nx.dot(masks.T, transp)

return transp_ys.T

Expand Down Expand Up @@ -697,12 +682,11 @@ class label
if check_params(Xt=Xt):

if nx.array_equal(self.xt_, Xt):

# perform standard barycentric mapping
transp_ = self.coupling_.T / nx.sum(self.coupling_, 0)[:, None]

# set nans to 0
transp_[~ nx.isfinite(transp_)] = 0
transp_ = nx.nan_to_num(transp_, nan=0, posinf=0, neginf=0)

# compute transported samples
transp_Xt = nx.dot(transp_, self.xs_)
Expand All @@ -719,9 +703,8 @@ class label
idx = nx.argmin(D0, axis=1)

# transport the target samples
transp_ = self.coupling_.T / nx.sum(
self.coupling_, 0)[:, None]
transp_[~ nx.isfinite(transp_)] = 0
transp_ = self.coupling_.T / nx.sum(self.coupling_, 0)[:, None]
transp_ = nx.nan_to_num(transp_, nan=0, posinf=0, neginf=0)
transp_Xt_ = nx.dot(transp_, self.xs_)

# define the transported points
Expand Down Expand Up @@ -750,23 +733,15 @@ def inverse_transform_labels(self, yt=None):

# check the necessary inputs parameters are here
if check_params(yt=yt):

ytTemp = label_normalization(nx.copy(yt))
classes = nx.unique(ytTemp)
n = len(classes)
D1 = nx.zeros((n, len(ytTemp)), type_as=self.coupling_)

# perform label propagation
transp = self.coupling_ / nx.sum(self.coupling_, 1)[:, None]

# set nans to 0
transp[~ nx.isfinite(transp)] = 0
transp = nx.nan_to_num(transp, nan=0, posinf=0, neginf=0)

for c in classes:
D1[int(c), ytTemp == c] = 1

# compute propagated samples
transp_ys = nx.dot(D1, transp.T)
# compute propagated labels
labels = label_normalization(yt)
masks = labels_to_masks(labels, nx=nx, type_as=transp)
transp_ys = nx.dot(masks.T, transp.T)

return transp_ys.T

Expand Down Expand Up @@ -2151,7 +2126,7 @@ def transform_labels(self, ys=None):
type_as=ys[0]
)
for i in range(len(ys)):
ysTemp = label_normalization(nx.copy(ys[i]))
ysTemp = label_normalization(ys[i])
classes = nx.unique(ysTemp)
n = len(classes)
ns = len(ysTemp)
Expand Down Expand Up @@ -2194,7 +2169,7 @@ def inverse_transform_labels(self, yt=None):
# check the necessary inputs parameters are here
if check_params(yt=yt):
transp_ys = []
ytTemp = label_normalization(nx.copy(yt))
ytTemp = label_normalization(yt)
classes = nx.unique(ytTemp)
n = len(classes)
D1 = nx.zeros((n, len(ytTemp)), type_as=self.coupling_[0])
Expand Down
45 changes: 35 additions & 10 deletions ot/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -390,7 +390,7 @@ def is_all_finite(*args):
return all(not nx.any(~nx.isfinite(arg)) for arg in args)


def label_normalization(y, start=0):
def label_normalization(y, start=0, nx=None):
r""" Transform labels to start at a given value

Parameters
Expand All @@ -399,18 +399,45 @@ def label_normalization(y, start=0):
The vector of labels to be normalized.
start : int
Desired value for the smallest label in :math:`\mathbf{y}` (default=0)
nx : Backend, optional
Backend to perform computations on. If omitted, the backend defaults to that of `y`.

Returns
-------
y : array-like, shape (`n1`, )
The input vector of labels normalized according to given start value.
"""
nx = get_backend(y)
if nx is None:
nx = get_backend(y)
diff = nx.min(y) - start
return y if diff == 0 else (y - diff)


def labels_to_masks(y, type_as=None, nx=None):
    r"""Transforms (n_samples,) vector of labels into a (n_samples, n_labels) matrix of masks.

    Parameters
    ----------
    y : array-like, shape (n_samples, )
        The vector of labels.
    type_as : array_like
        Array of the same type as the expected output. If omitted, `y` is used.
    nx : Backend, optional
        Backend to perform computations on. If omitted, the backend defaults to that of `y`.

    Returns
    -------
    masks : array-like, shape (n_samples, n_labels)
        The (n_samples, n_labels) matrix of label masks.
    """
    if nx is None:
        nx = get_backend(y)
    if type_as is None:
        type_as = y
    # unique labels, and for each sample the position of its label among them
    labels_u, labels_idx = nx.unique(y, return_inverse=True)
    n_labels = labels_u.shape[0]
    # picking rows of the identity matrix by label position yields one mask row per sample
    masks = nx.eye(n_labels, type_as=type_as)[labels_idx]
    return masks


def parmap(f, X, nprocs="default"):
Expand Down Expand Up @@ -755,10 +782,8 @@ def _get_backend(self, *arrays):
nx = get_backend(
*[input_ for input_ in arrays if input_ is not None]
)
if nx.__name__ in ("jax", "tf"):
raise TypeError(
"""JAX or TF arrays have been received but domain
adaptation does not support those backend.""")
if nx.__name__ in ("tf",):
raise TypeError("Domain adaptation does not support TF backend.")
self.nx = nx
return nx

Expand Down
7 changes: 7 additions & 0 deletions test/test_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -264,6 +264,8 @@ def test_empty_backend():
nx.detach(M)
with pytest.raises(NotImplementedError):
nx.matmul(M, M.T)
with pytest.raises(NotImplementedError):
nx.nan_to_num(M)


def test_func_backends(nx):
Expand Down Expand Up @@ -667,6 +669,11 @@ def test_func_backends(nx):
lst_b.append(nx.to_numpy(A))
lst_name.append("matmul broadcast")

vec = nx.from_numpy(np.array([1, np.nan, -1]))
vec = nx.nan_to_num(vec, nan=0)
lst_b.append(nx.to_numpy(vec))
lst_name.append("nan_to_num")

assert not nx.array_equal(Mb, vb), "array_equal (shape)"
assert nx.array_equal(Mb, Mb), "array_equal (elements) - expected true"
assert not nx.array_equal(
Expand Down
19 changes: 8 additions & 11 deletions test/test_da.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
import numpy as np
from numpy.testing import assert_allclose, assert_equal
import pytest
import warnings

import ot
from ot.datasets import make_data_classif
Expand Down Expand Up @@ -158,7 +157,6 @@ def test_sinkhorn_lpl1_transport_class(nx):
assert mass_semi == 0, "semisupervised mode not working"


@pytest.skip_backend("jax")
@pytest.skip_backend("tf")
def test_sinkhorn_l1l2_transport_class(nx):
"""test_sinkhorn_transport
Expand All @@ -169,15 +167,16 @@ def test_sinkhorn_l1l2_transport_class(nx):

Xs, ys = make_data_classif('3gauss', ns, random_state=42)
Xt, yt = make_data_classif('3gauss2', nt, random_state=43)
# prepare semi-supervised labels
yt_semi = np.copy(yt)
yt_semi[np.arange(0, nt, 2)] = -1

Xs, ys, Xt, yt = nx.from_numpy(Xs, ys, Xt, yt)
Xs, ys, Xt, yt, yt_semi = nx.from_numpy(Xs, ys, Xt, yt, yt_semi)

otda = ot.da.SinkhornL1l2Transport(max_inner_iter=500)
otda.fit(Xs=Xs, ys=ys, Xt=Xt)

# test that it is computed
with warnings.catch_warnings():
warnings.simplefilter("error")
otda.fit(Xs=Xs, ys=ys, Xt=Xt)
assert hasattr(otda, "cost_")
assert hasattr(otda, "coupling_")
assert hasattr(otda, "log_")
Expand Down Expand Up @@ -234,7 +233,7 @@ def test_sinkhorn_l1l2_transport_class(nx):
n_unsup = nx.sum(otda_unsup.cost_)

otda_semi = ot.da.SinkhornL1l2Transport()
otda_semi.fit(Xs=Xs, ys=ys, Xt=Xt, yt=yt)
otda_semi.fit(Xs=Xs, ys=ys, Xt=Xt, yt=yt_semi)
assert_equal(otda_semi.cost_.shape, ((Xs.shape[0], Xt.shape[0])))
n_semisup = nx.sum(otda_semi.cost_)

Expand All @@ -243,11 +242,9 @@ def test_sinkhorn_l1l2_transport_class(nx):

# check that the coupling forbids mass transport between labeled source
# and labeled target samples
mass_semi = nx.sum(
otda_semi.coupling_[otda_semi.cost_ == otda_semi.limit_max])
mass_semi = nx.sum(otda_semi.coupling_[otda_semi.cost_ == otda_semi.limit_max])
mass_semi = otda_semi.coupling_[otda_semi.cost_ == otda_semi.limit_max]
assert_allclose(nx.to_numpy(mass_semi), np.zeros(list(mass_semi.shape)),
rtol=1e-9, atol=1e-9)
assert_allclose(nx.to_numpy(mass_semi), np.zeros_like(mass_semi), rtol=1e-9, atol=1e-9)

# check everything runs well with log=True
otda = ot.da.SinkhornL1l2Transport(log=True)
Expand Down
Loading