88import numpy as np
99import torch
1010
11- from behavior_generation_lecture_python .mdp .policy import CategorialPolicy
11+ from behavior_generation_lecture_python .mdp .policy import CategoricalPolicy
1212
1313SIMPLE_MDP_DICT = {
1414 "states" : [1 , 2 ],
@@ -147,9 +147,9 @@ def get_transitions_with_probabilities(
147147 def sample_next_state (self , state , action ) -> Any :
148148 """Randomly sample the next state given the current state and taken action."""
149149 if self .is_terminal (state ):
150- return ValueError ("No next state for terminal states." )
150+ raise ValueError ("No next state for terminal states." )
151151 if action is None :
152- return ValueError ("Action must not be None." )
152+ raise ValueError ("Action must not be None." )
153153 prob_per_transition = self .get_transitions_with_probabilities (state , action )
154154 num_actions = len (prob_per_transition )
155155 choice = np .random .choice (
@@ -431,6 +431,7 @@ def q_learning(
431431 alpha : float ,
432432 epsilon : float ,
433433 iterations : int ,
434+ seed : Optional [int ] = None ,
434435 return_history : Optional [bool ] = False ,
435436) -> Union [QTable , List [QTable ]]:
436437 """Derive a value estimate for state-action pairs by means of Q learning.
@@ -441,22 +442,24 @@ def q_learning(
441442 epsilon: Exploration-exploitation threshold. A random action is taken with
442443 probability epsilon, the best action otherwise.
443444 iterations: Number of iterations.
445+ seed: Random seed for reproducibility (default: None).
444446 return_history: Whether to return the whole history of value estimates
445447 instead of just the final estimate.
446448
447449 Returns:
448450 The final value estimate, if return_history is false. The
449451 history of value estimates as list, if return_history is true.
450452 """
453+ if seed is not None :
454+ np .random .seed (seed )
455+
451456 q_table = {}
452457 for state in mdp .get_states ():
453458 for action in mdp .get_actions (state ):
454459 q_table [(state , action )] = 0.0
455460 q_table_history = [q_table .copy ()]
456461 state = mdp .initial_state
457462
458- np .random .seed (1337 )
459-
460463 for _ in range (iterations ):
461464 # available actions:
462465 avail_actions = mdp .get_actions (state )
@@ -528,14 +531,15 @@ def mean_episode_length(self) -> float:
528531def policy_gradient (
529532 * ,
530533 mdp : MDP ,
531- policy : CategorialPolicy ,
534+ policy : CategoricalPolicy ,
532535 lr : float = 1e-2 ,
533536 iterations : int = 50 ,
534537 batch_size : int = 5000 ,
538+ seed : Optional [int ] = None ,
535539 return_history : bool = False ,
536540 use_random_init_state : bool = False ,
537541 verbose : bool = True ,
538- ) -> Union [List [CategorialPolicy ], CategorialPolicy ]:
542+ ) -> Union [List [CategoricalPolicy ], CategoricalPolicy ]:
539543 """Train a paramterized policy using vanilla policy gradient.
540544
541545 Adapted from: https://github.com/openai/spinningup/blob/master/spinup/examples/pytorch/pg_math/1_simple_pg.py
@@ -556,6 +560,7 @@ def policy_gradient(
556560 lr: Learning rate.
557561 iterations: Number of iterations.
558562 batch_size: Number of samples generated for each policy update.
563+ seed: Random seed for reproducibility (default: None).
559564 return_history: Whether to return the whole history of value estimates
560565 instead of just the final estimate.
561566 use_random_init_state: bool, if the agent should be initialized randomly.
@@ -565,8 +570,9 @@ def policy_gradient(
565570 The final policy, if return_history is false. The
566571 history of policies as list, if return_history is true.
567572 """
568- np .random .seed (1337 )
569- torch .manual_seed (1337 )
573+ if seed is not None :
574+ np .random .seed (seed )
575+ torch .manual_seed (seed )
570576
571577 # add untrained model to model_checkpoints
572578 model_checkpoints = [deepcopy (policy )]
@@ -650,7 +656,7 @@ def policy_gradient(
650656 return policy
651657
652658
653- def derive_deterministic_policy (mdp : MDP , policy : CategorialPolicy ) -> Dict [Any , Any ]:
659+ def derive_deterministic_policy (mdp : MDP , policy : CategoricalPolicy ) -> Dict [Any , Any ]:
654660 """Compute the best policy for an MDP given the stochastic policy.
655661
656662 Args:
0 commit comments