
Commit cedb991

Add Chapter 3 multi-armed bandits: env, algorithms, examples, tests
1 parent fed60e6 commit cedb991

7 files changed

Lines changed: 180 additions & 7 deletions

ch3_multi_armed_bandits/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -0,0 +1,2 @@
+from .bandit_env import MultiArmedBanditBernoulli
+from .algorithms import EpsilonGreedy, UCB1, ThompsonSamplingBeta, simulate
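
These re-exports let chapter code import everything from the package root. A minimal usage sketch (illustrative, not part of the commit; it assumes the package directory is named ch3_multi_armed_bandits, as the test imports suggest):

from ch3_multi_armed_bandits import MultiArmedBanditBernoulli, EpsilonGreedy, simulate

env = MultiArmedBanditBernoulli([0.2, 0.5, 0.35])
stats = simulate(env, EpsilonGreedy(n_arms=3, epsilon=0.1), steps=2000, seed=0)
print(stats["avg_reward"], stats["pulls"])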
ch3_multi_armed_bandits/algorithms.py

Lines changed: 87 additions & 2 deletions
@@ -1,2 +1,87 @@
-def epsilon_greedy(*args, **kwargs):
-    return 0
+from __future__ import annotations
+import numpy as np
+from dataclasses import dataclass
+from typing import Optional, Dict, Any
+
+
+@dataclass
+class EpsilonGreedy:
+    n_arms: int
+    epsilon: float = 0.1
+    init: float = 0.0    # optimistic initial value encourages early exploration
+
+    def __post_init__(self):
+        self.counts = np.zeros(self.n_arms, dtype=int)
+        self.values = np.full(self.n_arms, float(self.init), dtype=float)
+        self.rng = np.random.default_rng()
+
+    def select_arm(self) -> int:
+        # Explore uniformly with probability epsilon, otherwise exploit.
+        if self.rng.random() < self.epsilon:
+            return int(self.rng.integers(self.n_arms))
+        return int(np.argmax(self.values))
+
+    def update(self, arm: int, reward: float):
+        # Streaming sample mean: Q_n = Q_{n-1} + (r - Q_{n-1}) / n.
+        self.counts[arm] += 1
+        n = self.counts[arm]
+        self.values[arm] += (reward - self.values[arm]) / n
+
+
+@dataclass
+class UCB1:
+    n_arms: int
+    c: float = 2.0    # exploration strength
+
+    def __post_init__(self):
+        self.counts = np.zeros(self.n_arms, dtype=int)
+        self.values = np.zeros(self.n_arms, dtype=float)
+        self.total = 0
+        self.rng = np.random.default_rng()
+
+    def select_arm(self) -> int:
+        # Pull each arm once before applying the confidence bound.
+        for a in range(self.n_arms):
+            if self.counts[a] == 0:
+                return a
+        # Score = empirical mean + c * sqrt(ln t / n_a).
+        ucb = self.values + self.c * np.sqrt(np.log(self.total) / self.counts)
+        return int(np.argmax(ucb))
+
+    def update(self, arm: int, reward: float):
+        self.total += 1
+        self.counts[arm] += 1
+        n = self.counts[arm]
+        self.values[arm] += (reward - self.values[arm]) / n
+
+
+@dataclass
+class ThompsonSamplingBeta:
+    n_arms: int
+    a0: float = 1.0    # Beta prior; (1, 1) is the uniform prior
+    b0: float = 1.0
+
+    def __post_init__(self):
+        self.a = np.full(self.n_arms, self.a0, dtype=float)
+        self.b = np.full(self.n_arms, self.b0, dtype=float)
+        self.rng = np.random.default_rng()
+
+    def select_arm(self) -> int:
+        # Sample each arm's posterior once and play the argmax.
+        samples = self.rng.beta(self.a, self.b)
+        return int(np.argmax(samples))
+
+    def update(self, arm: int, reward: float):
+        # Conjugate Beta-Bernoulli update: successes -> a, failures -> b.
+        self.a[arm] += reward
+        self.b[arm] += 1 - reward
+
+
+def simulate(env, agent, steps: int, seed: Optional[int] = None) -> Dict[str, Any]:
+    """Run the interaction loop and return history stats."""
+    if seed is not None:
+        # Reseed agent and environment RNGs (offset so they differ) for
+        # reproducibility; with no seed, any injected RNGs are left untouched.
+        if hasattr(agent, "rng"):
+            agent.rng = np.random.default_rng(seed)
+        if hasattr(env, "rng"):
+            env.rng = np.random.default_rng(seed + 1)
+    rewards = np.zeros(steps, dtype=float)
+    pulls = np.zeros(env.k, dtype=int)
+    choices = np.zeros(steps, dtype=int)
+    for t in range(steps):
+        a = agent.select_arm()
+        r = env.pull(a)
+        agent.update(a, r)
+        rewards[t] = r
+        pulls[a] += 1
+        choices[t] = a
+    return {
+        "avg_reward": float(rewards.mean()),
+        "cum_reward": float(rewards.sum()),
+        "pulls": pulls,
+        "choices": choices,
+        "rewards": rewards,
+    }
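
Both EpsilonGreedy and UCB1 share the streaming update Q_n = Q_{n-1} + (r_n - Q_{n-1}) / n, which is algebraically identical to the arithmetic mean of the first n rewards. A quick standalone check (illustrative, not part of the commit):

import numpy as np

rewards = np.random.default_rng(1).random(500)
q = 0.0
for n, r in enumerate(rewards, start=1):
    q += (r - q) / n                   # streaming form of the sample mean
assert np.isclose(q, rewards.mean())   # matches the batch mean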
ch3_multi_armed_bandits/bandit_env.py

Lines changed: 24 additions & 2 deletions
@@ -1,2 +1,24 @@
-class MultiArmedBandit:
-    pass
+from __future__ import annotations
+import numpy as np
+from typing import Sequence, Optional
+
+
+class MultiArmedBanditBernoulli:
+    """K-armed bandit with Bernoulli rewards.
+
+    probs: list/array of per-arm success probabilities in [0, 1].
+    Rewards are 0/1. An RNG can be injected for reproducibility.
+    """
+
+    def __init__(self, probs: Sequence[float], rng: Optional[np.random.Generator] = None):
+        self.probs = np.asarray(probs, dtype=float)
+        assert np.all((0 <= self.probs) & (self.probs <= 1)), "probs must be in [0, 1]"
+        self.k = len(self.probs)
+        self.rng = rng if rng is not None else np.random.default_rng()
+
+    def pull(self, arm: int) -> int:
+        # Bernoulli draw: 1 with probability probs[arm], else 0.
+        p = self.probs[arm]
+        return int(self.rng.random() < p)
+
+    def best_arm(self) -> int:
+        return int(np.argmax(self.probs))
+
+    def optimal_mean(self) -> float:
+        return float(np.max(self.probs))
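
A small sanity check of the environment on its own (illustrative, not part of the commit): with an injected seeded Generator the pulls are reproducible, and the empirical mean of many pulls approaches the arm's probability.

import numpy as np
from ch3_multi_armed_bandits.bandit_env import MultiArmedBanditBernoulli

env = MultiArmedBanditBernoulli([0.1, 0.9], rng=np.random.default_rng(0))
draws = [env.pull(1) for _ in range(10_000)]
print(sum(draws) / len(draws))   # should be close to 0.9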
Lines changed: 19 additions & 1 deletion
@@ -1 +1,19 @@
-print('Bandit demo placeholder')
+from ch3_multi_armed_bandits.bandit_env import MultiArmedBanditBernoulli
+from ch3_multi_armed_bandits.algorithms import EpsilonGreedy, UCB1, ThompsonSamplingBeta, simulate
+
+
+def main():
+    probs = [0.1, 0.2, 0.5, 0.4]
+    env = MultiArmedBanditBernoulli(probs)
+    best = env.best_arm()
+    steps = 5000
+
+    # Run each agent on the same environment; simulate reseeds the RNGs per run.
+    for agent in [
+        EpsilonGreedy(n_arms=len(probs), epsilon=0.1),
+        UCB1(n_arms=len(probs)),
+        ThompsonSamplingBeta(n_arms=len(probs)),
+    ]:
+        stats = simulate(env, agent, steps, seed=123)
+        print(f"{agent.__class__.__name__}: avg_reward={stats['avg_reward']:.3f}, "
+              f"best_arm_pulled={stats['pulls'][best]} times")
+
+
+if __name__ == "__main__":
+    main()
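
Because the optimal arm here pays 0.5 in expectation, an avg_reward near 0.5 signals low regret. If per-step regret curves are wanted, a hypothetical helper (not in the commit) can be built from the stats that simulate already returns:

import numpy as np

def cumulative_regret(stats, env):
    # Expected regret per step: optimal mean minus the mean of the chosen arm.
    chosen_means = env.probs[stats["choices"]]
    return np.cumsum(env.optimal_mean() - chosen_means)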
requirements.txt

Lines changed: 1 addition & 0 deletions
@@ -1,2 +1,3 @@
 numpy>=1.24
 pytest>=7.0
+pytest-cov>=4.1
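
pytest-cov is presumably added for coverage reporting; the usual workflow would be pip install -r requirements.txt followed by pytest --cov=ch3_multi_armed_bandits from the repository root.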

ch3_multi_armed_bandits/tests/test_bandit.py

Lines changed: 0 additions & 2 deletions
This file was deleted.
Lines changed: 47 additions & 0 deletions
@@ -0,0 +1,47 @@
+import numpy as np
+from ch3_multi_armed_bandits.bandit_env import MultiArmedBanditBernoulli
+from ch3_multi_armed_bandits.algorithms import EpsilonGreedy, UCB1, ThompsonSamplingBeta, simulate
+
+
+PROBS = [0.1, 0.2, 0.5, 0.4]
+BEST = int(np.argmax(PROBS))
+STEPS = 4000
+
+
+def random_baseline_avg(probs, steps, seed=7):
+    # Average reward of a uniformly random policy, used as the bar to beat.
+    rng = np.random.default_rng(seed)
+    k = len(probs)
+    rewards = []
+    for _ in range(steps):
+        a = int(rng.integers(k))
+        r = int(rng.random() < probs[a])
+        rewards.append(r)
+    return float(np.mean(rewards))
+
+
+def run_and_check(agent, steps=STEPS):
+    env = MultiArmedBanditBernoulli(PROBS)
+    return simulate(env, agent, steps, seed=42)
+
+
+def test_algorithms_beat_random_baseline():
+    baseline = random_baseline_avg(PROBS, STEPS)
+    agents = [
+        EpsilonGreedy(n_arms=len(PROBS), epsilon=0.1),
+        UCB1(n_arms=len(PROBS)),
+        ThompsonSamplingBeta(n_arms=len(PROBS)),
+    ]
+    for agent in agents:
+        stats = run_and_check(agent)
+        # Each learner should beat the random baseline by a clear margin.
+        assert stats['avg_reward'] >= baseline + 0.05, (agent.__class__.__name__, stats['avg_reward'], baseline)
+
+
+def test_learn_best_arm_frequently():
+    agents = [
+        EpsilonGreedy(n_arms=len(PROBS), epsilon=0.1),
+        UCB1(n_arms=len(PROBS)),
+        ThompsonSamplingBeta(n_arms=len(PROBS)),
+    ]
+    for agent in agents:
+        stats = run_and_check(agent)
+        pulls = stats['pulls']
+        assert pulls[BEST] == pulls.max()    # best arm selected most often
+        # At least 50% of all pulls should go to the best arm after learning.
+        assert pulls[BEST] >= STEPS * 0.5
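
Both tests build the same agent list; a hedged refactoring sketch using pytest.mark.parametrize (not part of the commit) would keep a single list and report each agent as its own test case:

import pytest

@pytest.mark.parametrize(
    "agent",
    [
        EpsilonGreedy(n_arms=len(PROBS), epsilon=0.1),
        UCB1(n_arms=len(PROBS)),
        ThompsonSamplingBeta(n_arms=len(PROBS)),
    ],
    ids=lambda a: a.__class__.__name__,
)
def test_beats_baseline(agent):
    stats = run_and_check(agent)
    assert stats["avg_reward"] >= random_baseline_avg(PROBS, STEPS) + 0.05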
