Skip to content

Commit 3478d0a

Browse files
committed
Default value for alpha_min is set to 0
1 parent f3324a6 commit 3478d0a

File tree

1 file changed

+12
-5
lines changed

1 file changed

+12
-5
lines changed

ot/optim.py

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@
2727

2828
def line_search_armijo(
2929
f, xk, pk, gfk, old_fval, args=(), c1=1e-4,
30-
alpha0=0.99, alpha_min=None, alpha_max=None, nx=None, **kwargs
30+
alpha0=0.99, alpha_min=0., alpha_max=None, nx=None, **kwargs
3131
):
3232
r"""
3333
Armijo linesearch function that works with matrices
@@ -56,7 +56,7 @@ def line_search_armijo(
5656
:math:`c_1` const in armijo rule (>0)
5757
alpha0 : float, optional
5858
initial step (>0)
59-
alpha_min : float, optional
59+
alpha_min : float, default=0.
6060
minimum value for alpha
6161
alpha_max : float, optional
6262
maximum value for alpha
@@ -89,6 +89,14 @@ def line_search_armijo(
8989
fc = [0]
9090

9191
def phi(alpha1):
92+
# it's necessary to check boundary condition here for the coefficient
93+
# as the callback could be evaluated for negative value of alpha by
94+
# `scalar_search_armijo` function here:
95+
#
96+
# https://github.com/scipy/scipy/blob/11509c4a98edded6c59423ac44ca1b7f28fba1fd/scipy/optimize/linesearch.py#L686
97+
#
98+
# see more details https://github.com/PythonOT/POT/issues/502
99+
alpha1 = np.clip(alpha1, alpha_min, alpha_max)
92100
# The callable function operates on nx backend
93101
fc[0] += 1
94102
alpha10 = nx.from_numpy(alpha1)
@@ -109,13 +117,12 @@ def phi(alpha1):
109117

110118
derphi0 = np.sum(pk * gfk) # Quickfix for matrices
111119
alpha, phi1 = scalar_search_armijo(
112-
phi, phi0, derphi0, c1=c1, alpha0=alpha0)
120+
phi, phi0, derphi0, c1=c1, alpha0=alpha0, amin=alpha_min)
113121

114122
if alpha is None:
115123
return 0., fc[0], nx.from_numpy(phi0, type_as=xk0)
116124
else:
117-
if alpha_min is not None or alpha_max is not None:
118-
alpha = np.clip(alpha, alpha_min, alpha_max)
125+
alpha = np.clip(alpha, alpha_min, alpha_max)
119126
return nx.from_numpy(alpha, type_as=xk0), fc[0], nx.from_numpy(phi1, type_as=xk0)
120127

121128

0 commit comments

Comments
 (0)