@@ -46,13 +46,12 @@ def val(self, x, must_be_cached=False):
4646 return self .vals .eval (x , lambda y : self .f (self .x0 + y * self .dir ))
4747
4848 def full_gradient (self , x , must_be_cached = False ):
49- return self .grads .eval (
50- x , lambda y : self .fprime (
51- self .x0 + y * self .dir ))
49+ return self .grads .eval (x , lambda y : self .fprime (self .x0 + y * self .dir ))
5250
5351 def der (self , x ):
5452 return np .dot (self .dir , self .full_gradient (x ))
5553
54+
5655# Finds mnimum of the function f using .
5756# Arguments:
5857# f - callable for function f.
@@ -77,18 +76,24 @@ def der(self, x):
7776# search.
7877
7978
80- def minimize_hz (f , x0 , fprime , maxiter = 1000 , gtol = 1e-4 ,
81- delta = 0.1 ,
82- sigma = 0.9 ,
83- eps = 1e-6 ,
84- theta = 0.5 ,
85- gamma = 0.66 ,
86- eta = 0.01 ,
87- rho = 5.0 ,
88- psi_0 = 0.01 ,
89- psi_1 = 0.1 ,
90- psi_2 = 2.0 ,
91- quad_step = True ):
79+ def minimize_hz (
80+ f ,
81+ x0 ,
82+ fprime ,
83+ maxiter = 1000 ,
84+ gtol = 1e-4 ,
85+ delta = 0.1 ,
86+ sigma = 0.9 ,
87+ eps = 1e-6 ,
88+ theta = 0.5 ,
89+ gamma = 0.66 ,
90+ eta = 0.01 ,
91+ rho = 5.0 ,
92+ psi_0 = 0.01 ,
93+ psi_1 = 0.1 ,
94+ psi_2 = 2.0 ,
95+ quad_step = True ,
96+ ):
9297 x_k = x0
9398 f_k = f (x_k )
9499 g_k = fprime (x_k )
@@ -122,19 +127,14 @@ def minimize_hz(f, x0, fprime, maxiter=1000, gtol=1e-4,
122127 phi_a0 = phi .val (a0 )
123128 q_koef = phi_a0 - phi_0 - a0 * derphi_0
124129 if phi_a0 <= phi_0 and q_koef > 0 :
125- c = - 0.5 * (derphi_0 * a0 ** 2 ) / q_koef
130+ c = - 0.5 * (derphi_0 * a0 ** 2 ) / q_koef
126131 if c is None :
127132 c = psi_2 * a_km1
128133
129134 # Line search.
130- a_k = hager_zhang_line_search (phi ,
131- initial_guess = c ,
132- delta = delta ,
133- sigma = sigma ,
134- eps = eps ,
135- theta = theta ,
136- gamma = gamma ,
137- rho = rho )
135+ a_k = hager_zhang_line_search (
136+ phi , initial_guess = c , delta = delta , sigma = sigma , eps = eps , theta = theta , gamma = gamma , rho = rho
137+ )
138138
139139 # Evalutaing new direction.
140140 x_kp1 = x_k + a_k * d_k
@@ -158,18 +158,11 @@ def minimize_hz(f, x0, fprime, maxiter=1000, gtol=1e-4,
158158 d_k = d_kp1
159159 f_k = f_kp1
160160
161- #print('Done %d iterations.' % k)
161+ # print('Done %d iterations.' % k)
162162 return x_k
163163
164164
165- def hager_zhang_line_search (phi ,
166- initial_guess = 1.0 ,
167- delta = 0.1 ,
168- sigma = 0.9 ,
169- eps = 1e-6 ,
170- theta = 0.5 ,
171- gamma = 0.66 ,
172- rho = 5.0 ):
165+ def hager_zhang_line_search (phi , initial_guess = 1.0 , delta = 0.1 , sigma = 0.9 , eps = 1e-6 , theta = 0.5 , gamma = 0.66 , rho = 5.0 ):
173166 phi_0 = phi .val (0 )
174167 derphi_0 = phi .der (0 )
175168 eps_k = eps * abs (phi_0 )
@@ -185,7 +178,7 @@ def interval_update_u3(a, b):
185178 a = d
186179 else :
187180 b = d
188- print (' Warning. Iterations exceeded in interval_update_u3.' )
181+ print (" Warning. Iterations exceeded in interval_update_u3." )
189182 return a , b
190183
191184 # Given inital guess [0, c], reduces it to [a,b], for which opposite slope
@@ -202,11 +195,11 @@ def bracket(c):
202195 if phi .val (c_j ) <= phi_0 + eps_k :
203196 c_i = c_j
204197 c_j = rho * c_j
205- print (' Warning. Iterations exceeded in bracket.' )
198+ print (" Warning. Iterations exceeded in bracket." )
206199 return c_i , c_j
207200
208201 def interval_update (a , b , c ):
209- assert ( a < b )
202+ assert a < b
210203 if c <= a or c >= b :
211204 return a , b
212205 if phi .der (c ) >= 0 :
@@ -247,8 +240,7 @@ def double_secant(a, b):
247240 break
248241
249242 # L0. Check T2 (approx Wolfe):
250- if (2 * delta - 1 ) * derphi_0 >= derphi_a and derphi_a >= sigma * \
251- derphi_0 and phi_a <= phi_0 + eps_k :
243+ if (2 * delta - 1 ) * derphi_0 >= derphi_a and derphi_a >= sigma * derphi_0 and phi_a <= phi_0 + eps_k :
252244 break
253245
254246 # L1.
0 commit comments