-def epsilon_greedy(*args, **kwargs):
-    return 0
+from __future__ import annotations
+import numpy as np
+from dataclasses import dataclass
+from typing import Optional, Dict, Any
+
+@dataclass
+class EpsilonGreedy:
+    """Explore a uniformly random arm with probability epsilon, else exploit."""
+    n_arms: int
+    epsilon: float = 0.1
+    init: float = 0.0  # optimistic values (> 0) encourage extra early exploration
+
+    def __post_init__(self):
+        self.counts = np.zeros(self.n_arms, dtype=int)
+        self.values = np.full(self.n_arms, float(self.init), dtype=float)
+        self.rng = np.random.default_rng()
+
+    def select_arm(self) -> int:
+        if self.rng.random() < self.epsilon:
+            return int(self.rng.integers(self.n_arms))  # explore
+        return int(np.argmax(self.values))  # exploit; ties break to the lowest index
+
+    def update(self, arm: int, reward: float) -> None:
+        self.counts[arm] += 1
+        n = self.counts[arm]
+        # incremental sample mean: Q_n = Q_{n-1} + (r - Q_{n-1}) / n
+        self.values[arm] += (reward - self.values[arm]) / n
+
+@dataclass
+class UCB1:
+    """Upper-confidence-bound selection; c = sqrt(2) recovers classic UCB1."""
+    n_arms: int
+    c: float = 2.0
+
+    def __post_init__(self):
+        self.counts = np.zeros(self.n_arms, dtype=int)
+        self.values = np.zeros(self.n_arms, dtype=float)
+        self.total = 0
+        # unused by the deterministic rule; kept so simulate() can seed any agent
+        self.rng = np.random.default_rng()
+
+    def select_arm(self) -> int:
+        # pull each arm once before applying the UCB formula
+        for a in range(self.n_arms):
+            if self.counts[a] == 0:
+                return a
+        ucb = self.values + self.c * np.sqrt(np.log(self.total) / self.counts)
+        return int(np.argmax(ucb))
+
+    def update(self, arm: int, reward: float) -> None:
+        self.total += 1
+        self.counts[arm] += 1
+        n = self.counts[arm]
+        self.values[arm] += (reward - self.values[arm]) / n
+
+@dataclass
+class ThompsonSamplingBeta:
+    """Beta-Bernoulli Thompson sampling; assumes rewards lie in [0, 1]."""
+    n_arms: int
+    a0: float = 1.0  # prior pseudo-successes
+    b0: float = 1.0  # prior pseudo-failures
+
+    def __post_init__(self):
+        self.a = np.full(self.n_arms, self.a0, dtype=float)
+        self.b = np.full(self.n_arms, self.b0, dtype=float)
+        self.rng = np.random.default_rng()
+
+    def select_arm(self) -> int:
+        # sample a plausible mean for each arm, then act greedily on the samples
+        samples = self.rng.beta(self.a, self.b)
+        return int(np.argmax(samples))
+
+    def update(self, arm: int, reward: float) -> None:
+        # conjugate posterior update (exact for Bernoulli rewards in {0, 1})
+        self.a[arm] += reward
+        self.b[arm] += 1 - reward
+
+def simulate(env, agent, steps: int, seed: Optional[int] = None) -> Dict[str, Any]:
+    """Run the interaction loop and return history stats.
+
+    `env` must expose `k` (number of arms) and `pull(arm) -> float`; if
+    `seed` is given, the agent's and (when present) env's RNGs are reseeded.
+    """
+    if seed is not None:
+        if hasattr(agent, "rng"):
+            agent.rng = np.random.default_rng(seed)
+        if hasattr(env, "rng"):
+            env.rng = np.random.default_rng(seed + 1)
+    rewards = np.zeros(steps, dtype=float)
+    pulls = np.zeros(env.k, dtype=int)
+    choices = np.zeros(steps, dtype=int)
+    for t in range(steps):
+        a = agent.select_arm()
+        r = env.pull(a)
+        agent.update(a, r)
+        rewards[t] = r
+        pulls[a] += 1
+        choices[t] = a
+    return {
+        "avg_reward": float(rewards.mean()),
+        "cum_reward": float(rewards.sum()),
+        "pulls": pulls,
+        "choices": choices,
+        "rewards": rewards,
+    }
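A minimal way to exercise the new simulate() helper (not part of the diff above): it only assumes an environment exposing an arm count `k`, an optionally reseedable `rng`, and `pull(arm) -> float`. The BernoulliBandit below is a hypothetical stand-in that satisfies this interface; the class name and reward probabilities are illustrative.

import numpy as np

class BernoulliBandit:
    """Hypothetical k-armed Bernoulli environment (illustration only)."""
    def __init__(self, probs, seed=None):
        self.probs = np.asarray(probs, dtype=float)
        self.k = len(self.probs)
        self.rng = np.random.default_rng(seed)

    def pull(self, arm: int) -> float:
        # reward is 1 with probability probs[arm], else 0
        return float(self.rng.random() < self.probs[arm])

env = BernoulliBandit([0.2, 0.5, 0.7])
agent = ThompsonSamplingBeta(n_arms=env.k)
stats = simulate(env, agent, steps=5_000, seed=0)
print(stats["avg_reward"], stats["pulls"])  # pulls should concentrate on the 0.7 arm

Passing a seed makes runs reproducible, since simulate() reseeds both the agent's and the environment's generators.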