Skip to content

Commit 9c15e96

Browse files
committed
Format code
1 parent d13e1c2 commit 9c15e96

6 files changed

Lines changed: 261 additions & 204 deletions

File tree

.github/workflows/ci.yml

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,15 @@ on:
77
branches: [ master ]
88

99
jobs:
10+
format-check:
11+
name: Format check
12+
runs-on: ubuntu-latest
13+
steps:
14+
- uses: actions/checkout@v4
15+
- name: Install black
16+
run: pip install black
17+
- name: Run black
18+
run: black --check --diff --line-length=120 .
1019
tests-ubuntu:
1120
name: Tests (Ubuntu)
1221
runs-on: ubuntu-latest

numopt/cg_hager_zhang.py

Lines changed: 30 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -46,13 +46,12 @@ def val(self, x, must_be_cached=False):
4646
return self.vals.eval(x, lambda y: self.f(self.x0 + y * self.dir))
4747

4848
def full_gradient(self, x, must_be_cached=False):
49-
return self.grads.eval(
50-
x, lambda y: self.fprime(
51-
self.x0 + y * self.dir))
49+
return self.grads.eval(x, lambda y: self.fprime(self.x0 + y * self.dir))
5250

5351
def der(self, x):
5452
return np.dot(self.dir, self.full_gradient(x))
5553

54+
5655
# Finds mnimum of the function f using .
5756
# Arguments:
5857
# f - callable for function f.
@@ -77,18 +76,24 @@ def der(self, x):
7776
# search.
7877

7978

80-
def minimize_hz(f, x0, fprime, maxiter=1000, gtol=1e-4,
81-
delta=0.1,
82-
sigma=0.9,
83-
eps=1e-6,
84-
theta=0.5,
85-
gamma=0.66,
86-
eta=0.01,
87-
rho=5.0,
88-
psi_0=0.01,
89-
psi_1=0.1,
90-
psi_2=2.0,
91-
quad_step=True):
79+
def minimize_hz(
80+
f,
81+
x0,
82+
fprime,
83+
maxiter=1000,
84+
gtol=1e-4,
85+
delta=0.1,
86+
sigma=0.9,
87+
eps=1e-6,
88+
theta=0.5,
89+
gamma=0.66,
90+
eta=0.01,
91+
rho=5.0,
92+
psi_0=0.01,
93+
psi_1=0.1,
94+
psi_2=2.0,
95+
quad_step=True,
96+
):
9297
x_k = x0
9398
f_k = f(x_k)
9499
g_k = fprime(x_k)
@@ -122,19 +127,14 @@ def minimize_hz(f, x0, fprime, maxiter=1000, gtol=1e-4,
122127
phi_a0 = phi.val(a0)
123128
q_koef = phi_a0 - phi_0 - a0 * derphi_0
124129
if phi_a0 <= phi_0 and q_koef > 0:
125-
c = - 0.5 * (derphi_0 * a0**2) / q_koef
130+
c = -0.5 * (derphi_0 * a0**2) / q_koef
126131
if c is None:
127132
c = psi_2 * a_km1
128133

129134
# Line search.
130-
a_k = hager_zhang_line_search(phi,
131-
initial_guess=c,
132-
delta=delta,
133-
sigma=sigma,
134-
eps=eps,
135-
theta=theta,
136-
gamma=gamma,
137-
rho=rho)
135+
a_k = hager_zhang_line_search(
136+
phi, initial_guess=c, delta=delta, sigma=sigma, eps=eps, theta=theta, gamma=gamma, rho=rho
137+
)
138138

139139
# Evalutaing new direction.
140140
x_kp1 = x_k + a_k * d_k
@@ -158,18 +158,11 @@ def minimize_hz(f, x0, fprime, maxiter=1000, gtol=1e-4,
158158
d_k = d_kp1
159159
f_k = f_kp1
160160

161-
#print('Done %d iterations.' % k)
161+
# print('Done %d iterations.' % k)
162162
return x_k
163163

164164

165-
def hager_zhang_line_search(phi,
166-
initial_guess=1.0,
167-
delta=0.1,
168-
sigma=0.9,
169-
eps=1e-6,
170-
theta=0.5,
171-
gamma=0.66,
172-
rho=5.0):
165+
def hager_zhang_line_search(phi, initial_guess=1.0, delta=0.1, sigma=0.9, eps=1e-6, theta=0.5, gamma=0.66, rho=5.0):
173166
phi_0 = phi.val(0)
174167
derphi_0 = phi.der(0)
175168
eps_k = eps * abs(phi_0)
@@ -185,7 +178,7 @@ def interval_update_u3(a, b):
185178
a = d
186179
else:
187180
b = d
188-
print('Warning. Iterations exceeded in interval_update_u3.')
181+
print("Warning. Iterations exceeded in interval_update_u3.")
189182
return a, b
190183

191184
# Given inital guess [0, c], reduces it to [a,b], for which opposite slope
@@ -202,11 +195,11 @@ def bracket(c):
202195
if phi.val(c_j) <= phi_0 + eps_k:
203196
c_i = c_j
204197
c_j = rho * c_j
205-
print('Warning. Iterations exceeded in bracket.')
198+
print("Warning. Iterations exceeded in bracket.")
206199
return c_i, c_j
207200

208201
def interval_update(a, b, c):
209-
assert (a < b)
202+
assert a < b
210203
if c <= a or c >= b:
211204
return a, b
212205
if phi.der(c) >= 0:
@@ -247,8 +240,7 @@ def double_secant(a, b):
247240
break
248241

249242
# L0. Check T2 (approx Wolfe):
250-
if (2 * delta - 1) * derphi_0 >= derphi_a and derphi_a >= sigma * \
251-
derphi_0 and phi_a <= phi_0 + eps_k:
243+
if (2 * delta - 1) * derphi_0 >= derphi_a and derphi_a >= sigma * derphi_0 and phi_a <= phi_0 + eps_k:
252244
break
253245

254246
# L1.

0 commit comments

Comments
 (0)