|
1 | | -from __future__ import annotations |
2 | | -import os |
3 | | -import numpy as np |
4 | | -import matplotlib.pyplot as plt |
5 | | -from .bandits import BernoulliBandit |
6 | | -from .epsilon_greedy import EpsilonGreedy |
7 | | -from .ucb import UCB1 |
8 | | -from .thompson import ThompsonSamplingBernoulli |
| 1 | +import argparse, os, numpy as np, matplotlib.pyplot as plt |
| 2 | +from .epsilon_greedy import run as run_eps |
| 3 | +from .ucb import run as run_ucb |
| 4 | +from .thompson import run as run_ts |
9 | 5 |
|
10 | | -def run_algorithm(env, algo, T: int, seed: int) -> dict: |
11 | | - rng = np.random.default_rng(seed) |
12 | | - rewards = np.zeros(T, dtype=float) |
13 | | - regret = np.zeros(T, dtype=float) |
14 | | - for t in range(T): |
15 | | - a = algo.select_arm() |
16 | | - r = env.pull(a, rng) |
17 | | - algo.update(a, r) |
18 | | - rewards[t] = r |
19 | | - regret[t] = env.pseudo_regret(a) |
20 | | - return { |
21 | | - "rewards": rewards, |
22 | | - "cum_rewards": np.cumsum(rewards), |
23 | | - "regret": regret, |
24 | | - "cum_regret": np.cumsum(regret), |
25 | | - } |
def parse_args():
    """Parse command-line options for the bandit regret experiment.

    Returns:
        argparse.Namespace with fields K, T, trials, eps, c, seed, outdir.
    """
    p = argparse.ArgumentParser(
        description="Compare cumulative regret of epsilon-greedy, UCB1, and "
                    "Thompson sampling on a Bernoulli bandit."
    )
    # Defaults are unchanged; only help text is added so --help is useful.
    p.add_argument("--K", type=int, default=10, help="number of arms")
    p.add_argument("--T", type=int, default=5000, help="time horizon per trial")
    p.add_argument("--trials", type=int, default=50,
                   help="independent trials to average over")
    p.add_argument("--eps", type=float, default=0.1,
                   help="exploration rate for epsilon-greedy")
    p.add_argument("--c", type=float, default=1.0,
                   help="exploration coefficient for UCB1")
    p.add_argument("--seed", type=int, default=123, help="master RNG seed")
    p.add_argument("--outdir", type=str, default="ch3_multi_armed_bandits/plots",
                   help="directory for the output plot")
    return p.parse_args()
| 17 | +def make_true_means(K, rng): return rng.uniform(0.1, 0.9, size=K) |
26 | 18 |
|
27 | | -def average_over_runs(env, algo_ctor, T: int, n_runs: int, base_seed: int = 0) -> dict: |
28 | | - cum_regrets = [] |
29 | | - for run in range(n_runs): |
30 | | - algo = algo_ctor() |
31 | | - result = run_algorithm(env, algo, T, seed=base_seed + run) |
32 | | - cum_regrets.append(result["cum_regret"]) |
33 | | - cum_regrets = np.array(cum_regrets) |
34 | | - mean = cum_regrets.mean(axis=0) |
35 | | - se = cum_regrets.std(axis=0, ddof=1) / np.sqrt(n_runs) |
36 | | - return {"mean": mean, "se": se} |
def run_all(true_means, T, trials, eps, c, seed):
    """Average cumulative regret of each algorithm over independent trials.

    A master RNG (seeded with `seed`) draws one sub-seed per trial; all three
    algorithms share that sub-seed so each trial's randomness is comparable
    across methods. Returns a dict mapping {"eps", "ucb", "ts"} to length-T
    arrays of trial-averaged cumulative regret.
    """
    master = np.random.default_rng(seed)
    totals = {name: np.zeros(T) for name in ("eps", "ucb", "ts")}
    for _ in range(trials):
        # One sub-seed per trial, drawn in the same order as before.
        sub_seed = int(master.integers(0, 2**31 - 1))
        totals["eps"] += run_eps(true_means, eps, T, sub_seed)["cum_regret"]
        totals["ucb"] += run_ucb(true_means, c, T, sub_seed)["cum_regret"]
        totals["ts"] += run_ts(true_means, T, sub_seed)["cum_regret"]
    return {name: total / trials for name, total in totals.items()}
37 | 29 |
|
38 | | -def plot_regret(curves: dict, title: str, fname: str | None): |
39 | | - fig, ax = plt.subplots() |
40 | | - for label, stats in curves.items(): |
41 | | - ax.plot(stats["mean"], label=label) |
42 | | - ax.set_xlabel("Time") |
43 | | - ax.set_ylabel("Average cumulative pseudo-regret") |
44 | | - ax.set_title(title) |
45 | | - ax.legend() |
46 | | - if fname: |
47 | | - out_dir = os.path.dirname(fname) |
48 | | - if out_dir and not os.path.exists(out_dir): |
49 | | - os.makedirs(out_dir, exist_ok=True) |
50 | | - fig.savefig(fname, bbox_inches="tight") |
51 | | - else: |
52 | | - plt.show() |
def plot(xs, series, ylabel, title, outpath):
    """Plot one labelled line per (label, values) pair and save to outpath.

    Args:
        xs: shared x-axis values.
        series: iterable of (label, y-values) pairs.
        ylabel: y-axis label.
        title: figure title.
        outpath: destination file path; parent directories are created.
    """
    plt.figure()
    for label, ys in series:
        plt.plot(xs, ys, label=label)
    plt.xlabel("Time")
    plt.ylabel(ylabel)
    plt.title(title)
    plt.legend()
    # Guard: dirname is "" for a bare filename, and os.makedirs("") raises
    # FileNotFoundError — only create the directory when there is one.
    out_dir = os.path.dirname(outpath)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)
    plt.savefig(outpath, dpi=300)
    plt.close()
53 | 36 |
|
def main():
    """Entry point: sample a bandit instance, run all algorithms, plot regret."""
    args = parse_args()
    # The instance means come from a dedicated generator seeded with the
    # master seed, exactly as before; run_all reseeds internally.
    means_rng = np.random.default_rng(args.seed)
    true_means = make_true_means(args.K, means_rng)
    timesteps = np.arange(1, args.T + 1)
    regret = run_all(true_means, args.T, args.trials, args.eps, args.c, args.seed)
    curves = [
        ("ε-Greedy", regret["eps"]),
        ("UCB1", regret["ucb"]),
        ("Thompson", regret["ts"]),
    ]
    plot(timesteps, curves, "Cumulative Regret", "Regret vs Time",
         os.path.join(args.outdir, "regret.png"))


if __name__ == "__main__":
    main()
0 commit comments