forked from sisl/aa228-notebook
-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathbandits.jl
More file actions
77 lines (71 loc) · 2.57 KB
/
bandits.jl
File metadata and controls
77 lines (71 loc) · 2.57 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
# A multi-armed Bernoulli bandit.
type Bandit
    # θ[i] is the probability that a pull of arm i succeeds (see `pull`).
    θ::Vector{Float64}
end
# Build a bandit with `k` arms whose success probabilities are drawn with `rand(k)`.
function Bandit(k::Integer)
    payoffs = rand(k)
    return Bandit(payoffs)
end
# Pull arm `i` of bandit `b`; returns `true` with probability `b.θ[i]`.
function pull(b::Bandit, i::Integer)
    draw = rand()
    return draw < b.θ[i]
end
# Number of arms of the bandit, i.e. the number of entries in `b.θ`.
function numArms(b::Bandit)
    return length(b.θ)
end
# Interactive widget for playing the bandit `b` by hand.
#
# Shows one button per arm together with a running
# "wins out of tries" readout for that arm, followed by a Hide/Show toggle
# that reveals the true success probabilities `b.θ`.
#
# `button`, `signal`, `foldl` (over a signal), `lift`, `latex` and
# `togglebuttons` are not defined in this file; they are presumably supplied by
# the notebook's widget/reactive packages.
function banditTrial(b)
    # Use the bandit's own arm count rather than relying on an external
    # `arms` variable that may be undefined or out of sync with `b`.
    k = numArms(b)
    B = [button("Arm $i") for i = 1:k]
    # Every update of button i's signal folds one more pull of arm i into the
    # running win counter and increments the matching try counter.
    wins = [foldl((acc, value) -> acc + pull(b, i), 0, signal(B[i])) for i = 1:k]
    tries = [foldl((acc, value) -> acc + 1, 0, signal(B[i])) for i = 1:k]
    for i = 1:k
        display(B[i])
        display(lift((w, t) -> latex(@sprintf("%d wins out of %d tries (%d percent)", w, t, 100*w/t)), wins[i], tries[i]))
    end
    t = togglebuttons(["Hide", "Show"], value="Hide", label="True parameters")
    display(t)
    display(lift(v -> v == "Show" ? latex(string(b.θ)) : latex(""), t))
end
# Interactive widget like `banditTrial`, with an added plot of the Beta
# densities implied by the observed counts of arms 1 and 2.
#
# For each arm a button and a "wins out of tries" readout are displayed. The
# plot draws Beta(w + 1, t - w + 1) for the first two arms only, so `b` must
# have at least two arms (indexing `wins[2]` fails otherwise). A Hide/Show
# toggle reveals the true success probabilities `b.θ`.
#
# `button`, `signal`, `foldl` (over a signal), `lift`, `latex`,
# `togglebuttons`, `Axis`, `Plots.Linear`, `pdf` and `Beta` are not defined in
# this file; they are presumably supplied by packages loaded in the notebook.
function banditEstimation(b)
    # Use the bandit's own arm count rather than relying on an external
    # `arms` variable that may be undefined or out of sync with `b`.
    k = numArms(b)
    B = [button("Arm $i") for i = 1:k]
    # Every update of button i's signal folds one more pull of arm i into the
    # running win counter and increments the matching try counter.
    wins = [foldl((acc, value) -> acc + pull(b, i), 0, signal(B[i])) for i = 1:k]
    tries = [foldl((acc, value) -> acc + 1, 0, signal(B[i])) for i = 1:k]
    for i = 1:k
        display(B[i])
        display(lift((w, t) -> latex(@sprintf("%d wins out of %d tries (%d percent)", w, t, 100*w/t)), wins[i], tries[i]))
    end
    # Re-draw both densities whenever any of the four counters changes.
    display(lift((w1, t1, w2, t2) ->
        Axis([
            Plots.Linear(θ -> pdf(Beta(w1+1, t1-w1+1), θ), (0,1), legendentry="Beta($(w1+1), $(t1-w1+1))"),
            Plots.Linear(θ -> pdf(Beta(w2+1, t2-w2+1), θ), (0,1), legendentry="Beta($(w2+1), $(t2-w2+1))")
        ],
        xmin=0, xmax=1, ymin=0),
        wins[1], tries[1], wins[2], tries[2]
    ))
    t = togglebuttons(["Hide", "Show"], value="Hide", label="True parameters")
    display(t)
    display(lift(v -> v == "Show" ? latex(string(b.θ)) : latex(""), t))
end
# Per-arm tallies of observed outcomes.
type BanditStatistics
    numWins::Vector{Int}   # numWins[i]: successful pulls of arm i
    numTries::Vector{Int}  # numTries[i]: total pulls of arm i

    # Start with all-zero counts for `k` arms.
    function BanditStatistics(k::Int)
        return new(zeros(Int, k), zeros(Int, k))
    end
end
# Number of arms being tracked, i.e. the number of entries in `b.numWins`.
function numArms(b::BanditStatistics)
    return length(b.numWins)
end
# Record one pull of arm `i`: always counts a try, and counts a win when
# `success` is true.
function update!(b::BanditStatistics, i::Int, success::Bool)
    b.numTries[i] += 1
    # The ternary keeps the same value as an `if` without `else`:
    # the new win count on success, `nothing` otherwise.
    success ? (b.numWins[i] += 1) : nothing
end
# Estimated win probability of every arm assuming a uniform prior:
# (wins + 1) / (tries + 2), computed element-wise.
function winProbabilities(b::BanditStatistics)
    successes = b.numWins .+ 1
    trials = b.numTries .+ 2
    return successes ./ trials
end
# Supertype for arm-selection policies. `simulate` asks a policy for the next
# arm via `arm(policy, s)`, where `s` is the current `BanditStatistics`, so
# concrete subtypes presumably provide an `arm` method (none is defined in
# this part of the file).
abstract BanditPolicy
# Run `policy` against bandit `b` for `steps` pulls.
#
# At each step the policy picks an arm via `arm(policy, stats)`, the arm is
# pulled, and the statistics are updated with the outcome. Returns a
# `Vector{Float64}` of length `steps` whose t-th entry is the cumulative number
# of wins after t pulls.
function simulate(b::Bandit, policy::BanditPolicy; steps = 10)
    cumulative = zeros(steps)
    stats = BanditStatistics(numArms(b))
    total = 0.0
    for t in 1:steps
        chosen = arm(policy, stats)
        won = pull(b, chosen)
        update!(stats, chosen, won)
        if won
            total += 1
        end
        cumulative[t] = total
    end
    return cumulative
end
# Average the cumulative-win curves of `iterations` independent runs of
# `simulate(b, policy, steps=steps)`.
#
# Returns a `Vector{Float64}` of length `steps`; entry t is the mean number of
# wins accumulated after t pulls.
function simulateAverage(b::Bandit, policy::BanditPolicy; steps = 10, iterations = 10)
    ret = zeros(steps)
    for iter = 1:iterations
        # `steps` is forwarded exactly once; repeating a keyword in a call is
        # redundant and rejected by later Julia versions.
        ret += simulate(b, policy, steps=steps)
    end
    return ret ./ iterations
end