"""
xNES from 'Natural Evolution Strategies' (Wierstra et al., JMLR 2014).

If n_jobs > 1, consider setting "export MKL_NUM_THREADS=1" so that the parallel
workers do not oversubscribe the CPU.

See the bottom of this file (under __main__) for a usage example.
"""
import joblib
import random
import numpy as np
from numpy import (dot, eye, asarray, array, trace, log, exp, sqrt,
                   mean, sum, argsort, square, arange)
from numpy.random import randn
from scipy.stats import multivariate_normal, norm
from scipy.linalg import det, expm
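

# Each call to XNES.step() repeats niter times: draw npop samples
# z = mu + sigma * (s . B) with s ~ N(0, I), evaluate f on them in parallel via
# joblib, rank them (higher f is better), and apply the natural-gradient updates
#   mu    <- mu + eta_mu * sigma * B . dj_delta
#   sigma <- sigma * exp(0.5 * eta_sigma * dj_sigma)
#   B     <- B . expm(0.5 * eta_bmat * dj_bmat)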
class XNES(object):
def __init__(self, f, mu, amat,
eta_mu=1.0, eta_sigma=None, eta_bmat=None,
npop=None, use_fshape=True, use_adasam=False, patience=100, n_jobs=1):
self.f = f
self.mu = mu
self.eta_mu = eta_mu
self.use_adasam = use_adasam
self.n_jobs = n_jobs
dim = len(mu)
sigma = abs(det(amat))**(1.0/dim)
bmat = amat*(1.0/sigma)
self.dim = dim
self.sigma = sigma
self.bmat = bmat
# default population size and learning rates
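        # these defaults follow the heuristics suggested in the xNES paper:
        # npop ~ 4 + 3*ln(d), eta_sigma = eta_bmat = 3*(3 + ln d)/(5*d*sqrt(d))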
npop = int(4 + 3*log(dim)) if npop is None else npop
eta_sigma = 3*(3+log(dim))*(1.0/(5*dim*sqrt(dim))) if eta_sigma is None else eta_sigma
eta_bmat = 3*(3+log(dim))*(1.0/(5*dim*sqrt(dim))) if eta_bmat is None else eta_bmat
self.npop = npop
self.eta_sigma = eta_sigma
self.eta_bmat = eta_bmat
# compute utilities if using fitness shaping
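        # rank-based utilities (they sum to zero); with fitness shaping only the
        # ranking of the offspring matters, not the raw fitness values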
if use_fshape:
a = log(1+0.5*npop)
utilities = array([max(0, a-log(k)) for k in range(1,npop+1)])
utilities /= sum(utilities)
utilities -= 1.0/npop # broadcast
utilities = utilities[::-1] # ascending order
else:
utilities = None
self.use_fshape = use_fshape
self.utilities = utilities
# stuff for adasam
self.eta_sigma_init = eta_sigma
self.sigma_old = None
# logging
self.fitness_best = None
self.mu_best = None
self.done = False
self.counter = 0
self.patience = patience
self.history = {'eta_sigma':[], 'sigma':[], 'fitness':[]}
# do not use these when hill-climbing
if npop == 1:
self.use_fshape = False
self.use_adasam = False

    def step(self, niter):
        """Run niter iterations of xNES."""
f = self.f
mu, sigma, bmat = self.mu, self.sigma, self.bmat
eta_mu, eta_sigma, eta_bmat = self.eta_mu, self.eta_sigma, self.eta_bmat
npop = self.npop
dim = self.dim
sigma_old = self.sigma_old
eyemat = eye(dim)
with joblib.Parallel(n_jobs=self.n_jobs) as parallel:
for i in range(niter):
s_try = randn(npop, dim)
                z_try = mu + sigma * dot(s_try, bmat)  # broadcast mu; z ~ N(mu, sigma^2 * B^T B)
f_try = parallel(joblib.delayed(f)(z) for z in z_try)
f_try = asarray(f_try)
# save if best
fitness = mean(f_try)
                if self.fitness_best is None or fitness - 1e-8 > self.fitness_best:
                    self.fitness_best = fitness
                    self.mu_best = mu.copy()
                    self.counter = 0
                else:
                    self.counter += 1
if self.counter > self.patience:
self.done = True
return
isort = argsort(f_try)
f_try = f_try[isort]
s_try = s_try[isort]
z_try = z_try[isort]
u_try = self.utilities if self.use_fshape else f_try
if self.use_adasam and sigma_old is not None: # sigma_old must be available
eta_sigma = self.adasam(eta_sigma, mu, sigma, bmat, sigma_old, z_try)
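                # Monte-Carlo natural-gradient estimates in the local coordinates:
                # dj_delta for the mean, dj_mmat for the full covariance factor,
                # split into dj_sigma (the trace part, i.e. the step size) and
                # dj_bmat (the trace-free part, i.e. the shape of the search
                # distribution)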
dj_delta = dot(u_try, s_try)
dj_mmat = dot(s_try.T, s_try*u_try.reshape(npop,1)) - sum(u_try)*eyemat
dj_sigma = trace(dj_mmat)*(1.0/dim)
dj_bmat = dj_mmat - dj_sigma*eyemat
sigma_old = sigma
# update
mu += eta_mu * sigma * dot(bmat, dj_delta)
sigma *= exp(0.5 * eta_sigma * dj_sigma)
bmat = dot(bmat, expm(0.5 * eta_bmat * dj_bmat))
# logging
self.history['fitness'].append(fitness)
self.history['sigma'].append(sigma)
self.history['eta_sigma'].append(eta_sigma)
# keep last results
self.mu, self.sigma, self.bmat = mu, sigma, bmat
self.eta_sigma = eta_sigma
self.sigma_old = sigma_old

    def adasam(self, eta_sigma, mu, sigma, bmat, sigma_old, z_try):
        """Adaptation sampling: check, via importance reweighting and a weighted
        Mann-Whitney U test, whether the samples would have ranked better under
        the sigma produced by a 1.5x larger update; if so, increase eta_sigma,
        otherwise pull it back toward its initial value."""
eta_sigma_init = self.eta_sigma_init
dim = self.dim
c = .1
rho = 0.5 - 1./(3*(dim+1)) # empirical
bbmat = dot(bmat.T, bmat)
cov = sigma**2 * bbmat
        sigma_ = sigma * sqrt(sigma*(1./sigma_old))  # the sigma a 1.5x larger update step (from sigma_old) would have produced
cov_ = sigma_**2 * bbmat
p0 = multivariate_normal.logpdf(z_try, mean=mu, cov=cov)
p1 = multivariate_normal.logpdf(z_try, mean=mu, cov=cov_)
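        # importance weights of the current samples under the trial distribution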
w = exp(p1-p0)
        # weighted Mann-Whitney U test; assumes z_try is sorted by fitness in ascending order
n = self.npop
n_ = sum(w)
u_ = sum(w * (arange(n)+0.5))
u_mu = n*n_*0.5
u_sigma = sqrt(n*n_*(n+n_+1)/12.)
cum = norm.cdf(u_, loc=u_mu, scale=u_sigma)
if cum < rho:
return (1-c)*eta_sigma + c*eta_sigma_init
else:
return min(1, (1+c)*eta_sigma)


if __name__ == '__main__':
    ''' Usage example '''
import time
np.random.seed(42)
random.seed(42)
    def f(x):  # sin(x^2 + y^2) / (x^2 + y^2), which peaks at the origin
        r = sum(square(x))
        return np.sin(r)/r
    mu = array([9999., -9999.])  # a deliberately bad initial guess
amat = eye(2)
    # when using adasam, start from conservative learning rates
xnes = XNES(f, mu, amat, npop=50, use_adasam=True, eta_bmat=0.01, eta_sigma=.1, patience=9999)
t0 = time.time()
for i in range(20):
xnes.step(100)
        print("Current: ({},{})".format(*xnes.mu))
print("Exact solution is (0,0)")
print("Took {} secs".format(time.time()-t0))
import matplotlib.pyplot as plt
fig, axs = plt.subplots(3,1)
axs[0].plot(xnes.history['fitness'])
axs[1].plot(xnes.history['sigma'])
axs[2].plot(xnes.history['eta_sigma'])
axs[0].set_ylabel('fitness')
axs[1].set_ylabel(r'$\sigma$')
axs[2].set_ylabel(r'$\eta_{\sigma}$')
    plt.show()