41
41
# Ensure compatibility with Python 2
42
42
from __future__ import absolute_import , division , print_function , unicode_literals
43
43
44
- import logging
45
44
from math import sqrt
46
45
import numpy as np
46
+ try :
47
+ import trustregion
48
+ USE_FORTRAN = True
49
+ except ImportError :
50
+ # Fall back to Python implementation
51
+ USE_FORTRAN = False
47
52
48
53
49
54
from .util import sumsq , model_value
50
55
51
56
52
57
__all__ = ['trsbox' , 'trsbox_geometry' ]
53
58
54
- # ZERO_THRESH = 1e-14
59
+ ZERO_THRESH = 1e-14
55
60
56
61
57
- def trsbox (xopt , g , H , sl , su , delta ):
62
+ def trsbox (xopt , g , H , sl , su , delta , use_fortran = USE_FORTRAN ):
63
+ if use_fortran :
64
+ return trustregion .solve (g , H , delta ,
65
+ sl = np .minimum (sl - xopt , - ZERO_THRESH ),
66
+ su = np .maximum (su - xopt , ZERO_THRESH ),
67
+ verbose_output = True )
68
+
58
69
n = xopt .size
59
70
assert xopt .shape == (n ,), "xopt has wrong shape (should be vector)"
60
71
assert g .shape == (n ,), "g and xopt have incompatible sizes"
@@ -368,7 +379,7 @@ def d_within_bounds(d, xopt, sl, su, xbdi):
368
379
return d
369
380
370
381
371
- def trsbox_geometry (xbase , c , g , H , lower , upper , Delta ):
382
+ def trsbox_geometry (xbase , c , g , H , lower , upper , Delta , use_fortran = USE_FORTRAN ):
372
383
# Given a Lagrange polynomial defined by: L(x) = c + g' * (x - xbase) + 0.5*(x-xbase)*H*(x-xbase)
373
384
# Maximise |L(x)| in a box + trust region - that is, solve:
374
385
# max_x abs(c + g' * (x - xbase) + 0.5*(x-xbase)*H*(x-xbase))
@@ -378,8 +389,8 @@ def trsbox_geometry(xbase, c, g, H, lower, upper, Delta):
378
389
# max_s abs(c + g' * s + 0.5*s*H*s)
379
390
# s.t. lower <= xbase + s <= upper
380
391
# ||s|| <= Delta
381
- smin , gmin , crvmin = trsbox (xbase , g , H , lower , upper , Delta ) # minimise L(x)
382
- smax , gmax , crvmax = trsbox (xbase , - g , - H , lower , upper , Delta ) # maximise L(x)
392
+ smin , gmin , crvmin = trsbox (xbase , g , H , lower , upper , Delta , use_fortran = use_fortran ) # minimise L(x)
393
+ smax , gmax , crvmax = trsbox (xbase , - g , - H , lower , upper , Delta , use_fortran = use_fortran ) # maximise L(x)
383
394
if abs (c + model_value (g , H , smin )) >= abs (c + model_value (g , H , smax )): # take largest abs value
384
395
return xbase + smin
385
396
else :
0 commit comments