Skip to content

Commit e50cfa9

Browse files
Housekeeping MDP code (#30)
Co-authored-by: Annika Bätz <annika.baetz@kit.edu>
1 parent 5c4b816 commit e50cfa9

4 files changed

Lines changed: 80 additions & 28 deletions

File tree

notebooks/mdp_policy_gradient.ipynb

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
"source": [
99
"import os\n",
1010
"\n",
11-
"from behavior_generation_lecture_python.mdp.policy import CategorialPolicy\n",
11+
"from behavior_generation_lecture_python.mdp.policy import CategoricalPolicy\n",
1212
"from behavior_generation_lecture_python.utils.grid_plotting import (\n",
1313
" make_plot_policy_step_function,\n",
1414
")\n",
@@ -47,7 +47,7 @@
4747
"metadata": {},
4848
"outputs": [],
4949
"source": [
50-
"policy = CategorialPolicy(\n",
50+
"policy = CategoricalPolicy(\n",
5151
" sizes=[len(grid_mdp.initial_state), 32, len(grid_mdp.actions)],\n",
5252
" actions=list(grid_mdp.actions),\n",
5353
")"
@@ -146,7 +146,7 @@
146146
"metadata": {},
147147
"outputs": [],
148148
"source": [
149-
"policy = CategorialPolicy(\n",
149+
"policy = CategoricalPolicy(\n",
150150
" sizes=[len(highway_mdp.initial_state), 32, len(highway_mdp.actions)],\n",
151151
" actions=list(highway_mdp.actions),\n",
152152
")"

src/behavior_generation_lecture_python/mdp/mdp.py

Lines changed: 16 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
import numpy as np
99
import torch
1010

11-
from behavior_generation_lecture_python.mdp.policy import CategorialPolicy
11+
from behavior_generation_lecture_python.mdp.policy import CategoricalPolicy
1212

1313
SIMPLE_MDP_DICT = {
1414
"states": [1, 2],
@@ -147,9 +147,9 @@ def get_transitions_with_probabilities(
147147
def sample_next_state(self, state, action) -> Any:
148148
"""Randomly sample the next state given the current state and taken action."""
149149
if self.is_terminal(state):
150-
return ValueError("No next state for terminal states.")
150+
raise ValueError("No next state for terminal states.")
151151
if action is None:
152-
return ValueError("Action must not be None.")
152+
raise ValueError("Action must not be None.")
153153
prob_per_transition = self.get_transitions_with_probabilities(state, action)
154154
num_actions = len(prob_per_transition)
155155
choice = np.random.choice(
@@ -431,6 +431,7 @@ def q_learning(
431431
alpha: float,
432432
epsilon: float,
433433
iterations: int,
434+
seed: Optional[int] = None,
434435
return_history: Optional[bool] = False,
435436
) -> Union[QTable, List[QTable]]:
436437
"""Derive a value estimate for state-action pairs by means of Q learning.
@@ -441,22 +442,24 @@ def q_learning(
441442
epsilon: Exploration-exploitation threshold. A random action is taken with
442443
probability epsilon, the best action otherwise.
443444
iterations: Number of iterations.
445+
seed: Random seed for reproducibility (default: None).
444446
return_history: Whether to return the whole history of value estimates
445447
instead of just the final estimate.
446448
447449
Returns:
448450
The final value estimate, if return_history is false. The
449451
history of value estimates as list, if return_history is true.
450452
"""
453+
if seed is not None:
454+
np.random.seed(seed)
455+
451456
q_table = {}
452457
for state in mdp.get_states():
453458
for action in mdp.get_actions(state):
454459
q_table[(state, action)] = 0.0
455460
q_table_history = [q_table.copy()]
456461
state = mdp.initial_state
457462

458-
np.random.seed(1337)
459-
460463
for _ in range(iterations):
461464
# available actions:
462465
avail_actions = mdp.get_actions(state)
@@ -528,14 +531,15 @@ def mean_episode_length(self) -> float:
528531
def policy_gradient(
529532
*,
530533
mdp: MDP,
531-
policy: CategorialPolicy,
534+
policy: CategoricalPolicy,
532535
lr: float = 1e-2,
533536
iterations: int = 50,
534537
batch_size: int = 5000,
538+
seed: Optional[int] = None,
535539
return_history: bool = False,
536540
use_random_init_state: bool = False,
537541
verbose: bool = True,
538-
) -> Union[List[CategorialPolicy], CategorialPolicy]:
542+
) -> Union[List[CategoricalPolicy], CategoricalPolicy]:
539543
"""Train a parameterized policy using vanilla policy gradient.
540544
541545
Adapted from: https://github.com/openai/spinningup/blob/master/spinup/examples/pytorch/pg_math/1_simple_pg.py
@@ -556,6 +560,7 @@ def policy_gradient(
556560
lr: Learning rate.
557561
iterations: Number of iterations.
558562
batch_size: Number of samples generated for each policy update.
563+
seed: Random seed for reproducibility (default: None).
559564
return_history: Whether to return the whole history of value estimates
560565
instead of just the final estimate.
561566
use_random_init_state: bool, if the agent should be initialized randomly.
@@ -565,8 +570,9 @@ def policy_gradient(
565570
The final policy, if return_history is false. The
566571
history of policies as list, if return_history is true.
567572
"""
568-
np.random.seed(1337)
569-
torch.manual_seed(1337)
573+
if seed is not None:
574+
np.random.seed(seed)
575+
torch.manual_seed(seed)
570576

571577
# add untrained model to model_checkpoints
572578
model_checkpoints = [deepcopy(policy)]
@@ -650,7 +656,7 @@ def policy_gradient(
650656
return policy
651657

652658

653-
def derive_deterministic_policy(mdp: MDP, policy: CategorialPolicy) -> Dict[Any, Any]:
659+
def derive_deterministic_policy(mdp: MDP, policy: CategoricalPolicy) -> Dict[Any, Any]:
654660
"""Compute the best policy for an MDP given the stochastic policy.
655661
656662
Args:

src/behavior_generation_lecture_python/mdp/policy.py

Lines changed: 54 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
"""This module contains the CategoricalPolicy implementation."""
22

3-
from typing import List, Type
3+
from typing import Any, List, Optional, Type
44

55
import torch
66
from torch import nn
@@ -23,36 +23,78 @@ def multi_layer_perceptron(
2323
return mlp
2424

2525

26-
class CategorialPolicy:
27-
def __init__(self, sizes: List[int], actions: List):
26+
class CategoricalPolicy:
27+
"""A categorical policy parameterized by a neural network."""
28+
29+
def __init__(
30+
self, sizes: List[int], actions: List[Any], seed: Optional[int] = None
31+
) -> None:
32+
"""Initialize the categorical policy.
33+
34+
Args:
35+
sizes: List of layer sizes for the MLP.
36+
actions: List of available actions.
37+
seed: Random seed for reproducibility (default: None).
38+
"""
2839
assert sizes[-1] == len(actions)
29-
torch.manual_seed(1337)
40+
if seed is not None:
41+
torch.manual_seed(seed)
3042
self.net = multi_layer_perceptron(sizes=sizes)
3143
self.actions = actions
3244
self._actions_tensor = torch.tensor(actions, dtype=torch.long).view(
3345
len(actions), -1
3446
)
3547

36-
def _get_distribution(self, state: torch.Tensor):
37-
"""Calls the model and returns a categorial distribution over the actions."""
48+
def _get_distribution(self, state: torch.Tensor) -> Categorical:
49+
"""Calls the model and returns a categorical distribution over the actions.
50+
51+
Args:
52+
state: The current state tensor.
53+
54+
Returns:
55+
A categorical distribution over actions.
56+
"""
3857
logits = self.net(state)
3958
return Categorical(logits=logits)
4059

41-
def get_action(self, state: torch.Tensor, deterministic: bool = False):
42-
"""Returns an action sample for the given state"""
60+
def get_action(self, state: torch.Tensor, deterministic: bool = False) -> Any:
61+
"""Returns an action sample for the given state.
62+
63+
Args:
64+
state: The current state tensor.
65+
deterministic: If True, return the most likely action.
66+
67+
Returns:
68+
The selected action.
69+
"""
4370
policy = self._get_distribution(state)
4471
if deterministic:
4572
return self.actions[policy.mode.item()]
4673
return self.actions[policy.sample().item()]
4774

48-
def get_log_prob(self, states: torch.Tensor, actions: torch.Tensor):
49-
"""Returns the log-probability for taking the action, when being the given state"""
75+
def get_log_prob(self, states: torch.Tensor, actions: torch.Tensor) -> torch.Tensor:
76+
"""Returns the log-probability for taking the action, when being in the given state.
77+
78+
Args:
79+
states: Batch of state tensors.
80+
actions: Batch of action tensors.
81+
82+
Returns:
83+
Log-probabilities of the actions.
84+
"""
5085
return self._get_distribution(states).log_prob(
5186
self._get_action_id_from_action(actions)
5287
)
5388

54-
def _get_action_id_from_action(self, actions: torch.Tensor):
55-
"""Returns the indices of the passed actions in self.actions"""
89+
def _get_action_id_from_action(self, actions: torch.Tensor) -> torch.Tensor:
90+
"""Returns the indices of the passed actions in self.actions.
91+
92+
Args:
93+
actions: Batch of action tensors.
94+
95+
Returns:
96+
Tensor of action indices.
97+
"""
5698
reshaped_actions = actions.unsqueeze(1).expand(
5799
-1, self._actions_tensor.size(0), -1
58100
)

tests/test_mdp.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
random_action,
1515
value_iteration,
1616
)
17-
from behavior_generation_lecture_python.mdp.policy import CategorialPolicy
17+
from behavior_generation_lecture_python.mdp.policy import CategoricalPolicy
1818

1919

2020
def test_init_mdp():
@@ -151,22 +151,26 @@ def test_q_learning(return_history):
151151
alpha=0.1,
152152
epsilon=0.1,
153153
iterations=10000,
154+
seed=1337,
154155
return_history=return_history,
155156
)
156157

157158

158159
@pytest.mark.parametrize("return_history", (True, False))
159160
def test_policy_gradient(return_history):
160161
mdp = GridMDP(**GRID_MDP_DICT)
161-
pol = CategorialPolicy(
162-
sizes=[len(mdp.initial_state), 32, len(mdp.actions)], actions=list(mdp.actions)
162+
pol = CategoricalPolicy(
163+
sizes=[len(mdp.initial_state), 32, len(mdp.actions)],
164+
actions=list(mdp.actions),
165+
seed=1337,
163166
)
164167
assert policy_gradient(
165168
mdp=mdp,
166169
policy=pol,
167170
lr=1e2,
168171
iterations=5,
169172
batch_size=5000,
173+
seed=1337,
170174
return_history=return_history,
171175
verbose=False,
172176
)

0 commit comments

Comments
 (0)