Skip to content

Validate loss_fun parameter in Gromov-Wasserstein utils #513

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Aug 26, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 10 additions & 2 deletions ot/gromov/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@ def init_matrix(C1, C2, p, q, loss_fun='square_loss', nx=None):
Name of loss function to use: either 'square_loss' or 'kl_loss' (default='square_loss')
nx : backend, optional
If let to its default value None, a backend test will be conducted.

Returns
-------
constC : array-like, shape (ns, nt)
Expand Down Expand Up @@ -118,6 +119,8 @@ def h1(a):

def h2(b):
return nx.log(b + 1e-15)
else:
raise ValueError(f"Unknown `loss_fun='{loss_fun}'`. Use one of: {'square_loss', 'kl_loss'}.")

constC1 = nx.dot(
nx.dot(f1(C1), nx.reshape(p, (-1, 1))),
Expand Down Expand Up @@ -402,11 +405,12 @@ def init_matrix_semirelaxed(C1, C2, p, loss_fun='square_loss', nx=None):
Metric cost matrix in the source space
C2 : array-like, shape (nt, nt)
Metric cost matrix in the target space
T : array-like, shape (ns, nt)
Coupling between source and target spaces
p : array-like, shape (ns,)
loss_fun : str, optional
Name of loss function to use: either 'square_loss' or 'kl_loss' (default='square_loss')
nx : backend, optional
If let to its default value None, a backend test will be conducted.

Returns
-------
constC : array-like, shape (ns, nt)
Expand Down Expand Up @@ -446,6 +450,10 @@ def h1(a):

def h2(b):
return 2 * b
elif loss_fun == 'kl_loss':
raise NotImplementedError()
else:
raise ValueError(f"Unknown `loss_fun='{loss_fun}'`. Only 'square_loss' is supported.")

constC = nx.dot(nx.dot(f1(C1), nx.reshape(p, (-1, 1))),
nx.ones((1, C2.shape[0]), type_as=p))
Expand Down
35 changes: 35 additions & 0 deletions test/test_gromov.py
Original file line number Diff line number Diff line change
Expand Up @@ -291,6 +291,24 @@ def line_search(cost, G, deltaG, Mi, cost_G):
np.testing.assert_allclose(res, Gb, atol=1e-06)


@pytest.mark.parametrize('loss_fun', [
    'square_loss',
    'kl_loss',
    pytest.param('unknown_loss', marks=pytest.mark.xfail(raises=ValueError)),
])
def test_gw_helper_validation(loss_fun):
    """init_matrix accepts both supported losses and raises ValueError otherwise."""
    n = 20  # number of samples per distribution
    center = np.array([0, 0])
    identity_cov = np.array([[1, 0], [0, 1]])
    # Two independent 2D Gaussian clouds (fixed seeds for determinism).
    source = ot.datasets.make_2D_samples_gauss(n, center, identity_cov, random_state=0)
    target = ot.datasets.make_2D_samples_gauss(n, center, identity_cov, random_state=1)
    weights_src = ot.unif(n)
    weights_tgt = ot.unif(n)
    cost_src = ot.dist(source, source)
    cost_tgt = ot.dist(target, target)
    # Only the loss_fun validation is under test; the result is discarded.
    ot.gromov.init_matrix(cost_src, cost_tgt, weights_src, weights_tgt,
                          loss_fun=loss_fun)


@pytest.skip_backend("jax", reason="test very slow with jax backend")
@pytest.skip_backend("tf", reason="test very slow with tf backend")
def test_entropic_gromov(nx):
Expand Down Expand Up @@ -2026,6 +2044,23 @@ def line_search(cost, G, deltaG, Mi, cost_G):
np.testing.assert_allclose(res, Gb, atol=1e-06)


@pytest.mark.parametrize('loss_fun', [
    'square_loss',
    pytest.param('kl_loss', marks=pytest.mark.xfail(raises=NotImplementedError)),
    pytest.param('unknown_loss', marks=pytest.mark.xfail(raises=ValueError)),
])
def test_gw_semirelaxed_helper_validation(loss_fun):
    """init_matrix_semirelaxed accepts 'square_loss' and rejects everything else."""
    n = 20  # number of samples per distribution
    center = np.array([0, 0])
    identity_cov = np.array([[1, 0], [0, 1]])
    # Two independent 2D Gaussian clouds (fixed seeds for determinism).
    source = ot.datasets.make_2D_samples_gauss(n, center, identity_cov, random_state=0)
    target = ot.datasets.make_2D_samples_gauss(n, center, identity_cov, random_state=1)
    weights_src = ot.unif(n)  # semi-relaxed variant only needs source weights
    cost_src = ot.dist(source, source)
    cost_tgt = ot.dist(target, target)
    # Only the loss_fun validation is under test; the result is discarded.
    ot.gromov.init_matrix_semirelaxed(cost_src, cost_tgt, weights_src,
                                      loss_fun=loss_fun)


def test_semirelaxed_fgw(nx):
rng = np.random.RandomState(0)
list_n = [16, 8]
Expand Down