Skip to content

Commit 526b72f

Browse files
kachayev and rflamary authored
[Fix] Prevent line search from evaluating cost outside of the interpolation range (#504)
* Explicitly check that SinkhornL1l2Transport.fit works with no warnings
* Default value for alpha_min is set to 0
* Fix random_state for SinkhornL1l2Transport test
* Mention changes in releases

---------

Co-authored-by: Rémi Flamary <remi.flamary@gmail.com>
1 parent 9e74f2e commit 526b72f

File tree

3 files changed

+20
-9
lines changed

3 files changed

+20
-9
lines changed

RELEASES.md

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -6,6 +6,7 @@
66
+ Tweaked `get_backend` to ignore `None` inputs (PR #525)
77

88
#### Closed issues
9+
- Fix line search evaluating cost outside of the interpolation range (Issue #502, PR #504)
910

1011

1112
## 0.9.1

ot/optim.py

Lines changed: 12 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -27,7 +27,7 @@
2727

2828
def line_search_armijo(
2929
f, xk, pk, gfk, old_fval, args=(), c1=1e-4,
30-
alpha0=0.99, alpha_min=None, alpha_max=None, nx=None, **kwargs
30+
alpha0=0.99, alpha_min=0., alpha_max=None, nx=None, **kwargs
3131
):
3232
r"""
3333
Armijo linesearch function that works with matrices
@@ -56,7 +56,7 @@ def line_search_armijo(
5656
:math:`c_1` const in armijo rule (>0)
5757
alpha0 : float, optional
5858
initial step (>0)
59-
alpha_min : float, optional
59+
alpha_min : float, default=0.
6060
minimum value for alpha
6161
alpha_max : float, optional
6262
maximum value for alpha
@@ -89,6 +89,14 @@ def line_search_armijo(
8989
fc = [0]
9090

9191
def phi(alpha1):
92+
# it's necessary to check boundary condition here for the coefficient
93+
# as the callback could be evaluated for negative value of alpha by
94+
# `scalar_search_armijo` function here:
95+
#
96+
# https://github.com/scipy/scipy/blob/11509c4a98edded6c59423ac44ca1b7f28fba1fd/scipy/optimize/linesearch.py#L686
97+
#
98+
# see more details https://github.com/PythonOT/POT/issues/502
99+
alpha1 = np.clip(alpha1, alpha_min, alpha_max)
92100
# The callable function operates on nx backend
93101
fc[0] += 1
94102
alpha10 = nx.from_numpy(alpha1)
@@ -109,13 +117,12 @@ def phi(alpha1):
109117

110118
derphi0 = np.sum(pk * gfk) # Quickfix for matrices
111119
alpha, phi1 = scalar_search_armijo(
112-
phi, phi0, derphi0, c1=c1, alpha0=alpha0)
120+
phi, phi0, derphi0, c1=c1, alpha0=alpha0, amin=alpha_min)
113121

114122
if alpha is None:
115123
return 0., fc[0], nx.from_numpy(phi0, type_as=xk0)
116124
else:
117-
if alpha_min is not None or alpha_max is not None:
118-
alpha = np.clip(alpha, alpha_min, alpha_max)
125+
alpha = np.clip(alpha, alpha_min, alpha_max)
119126
return nx.from_numpy(alpha, type_as=xk0), fc[0], nx.from_numpy(phi1, type_as=xk0)
120127

121128

test/test_da.py

Lines changed: 7 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -7,6 +7,7 @@
77
import numpy as np
88
from numpy.testing import assert_allclose, assert_equal
99
import pytest
10+
import warnings
1011

1112
import ot
1213
from ot.datasets import make_data_classif
@@ -158,15 +159,17 @@ def test_sinkhorn_l1l2_transport_class(nx):
158159
ns = 50
159160
nt = 50
160161

161-
Xs, ys = make_data_classif('3gauss', ns)
162-
Xt, yt = make_data_classif('3gauss2', nt)
162+
Xs, ys = make_data_classif('3gauss', ns, random_state=42)
163+
Xt, yt = make_data_classif('3gauss2', nt, random_state=43)
163164

164165
Xs, ys, Xt, yt = nx.from_numpy(Xs, ys, Xt, yt)
165166

166-
otda = ot.da.SinkhornL1l2Transport()
167+
otda = ot.da.SinkhornL1l2Transport(max_inner_iter=500)
167168

168169
# test its computed
169-
otda.fit(Xs=Xs, ys=ys, Xt=Xt)
170+
with warnings.catch_warnings():
171+
warnings.simplefilter("error")
172+
otda.fit(Xs=Xs, ys=ys, Xt=Xt)
170173
assert hasattr(otda, "cost_")
171174
assert hasattr(otda, "coupling_")
172175
assert hasattr(otda, "log_")

0 commit comments

Comments
 (0)