33"""
44
55import itertools
6- from typing import Type , Optional , Tuple , Any , List
6+ from typing import Type , Optional , Tuple , Any , List , Iterable
77
88import torch
99from torch import nn
@@ -222,6 +222,7 @@ def __init__(
222222 use_multigamma : bool = True ,
223223 actor_type : Type [actor_critic .BaseActorHead ] = actor_critic .Actor ,
224224 critic_type : Type [actor_critic .BaseCriticHead ] = actor_critic .NCritics ,
225+ pass_obs_keys_to_actor : Optional [Iterable [str ]] = None ,
225226 ):
226227 super ().__init__ ()
227228 self .obs_space = obs_space
@@ -293,6 +294,7 @@ def __init__(
293294 self .target_actor = actor_type (** ac_kwargs )
294295 # full weight copy to targets
295296 self .hard_sync_targets ()
297+ self .pass_obs_keys_to_actor = pass_obs_keys_to_actor or []
296298
297299 @property
298300 def trainable_params (self ):
@@ -364,7 +366,10 @@ def get_actions(
364366 tstep_emb , time_idxs = time_idxs , hidden_state = hidden_state
365367 )
366368 # generate action distribution [batch, length, len(self.gammas), d_action]
367- action_dists = self .actor (traj_emb_t )
369+ action_dists = self .actor (
370+ traj_emb_t ,
371+ straight_from_obs = {k : obs [k ] for k in self .pass_obs_keys_to_actor },
372+ )
368373 if sample :
369374 actions = action_dists .sample ()
370375 else :
@@ -473,7 +478,7 @@ def forward(
473478 ## a ~ \pi(s) ##
474479 ################
475480 critic_loss = None
476- a_dist = self .actor (s_rep , log_dict = active_log_dict )
481+ a_dist = self .actor (s_rep , log_dict = active_log_dict , straight_from_obs = { k : batch . obs [ k ] for k in self . pass_obs_keys_to_actor } )
477482 a_agent = self ._sample_k_actions (a_dist , k = K_a )
478483 assert a_agent .shape == (K_a , B , L , G , D_action )
479484 if log_step :
@@ -742,6 +747,7 @@ def __init__(
742747 use_multigamma : bool = True ,
743748 actor_type : Type [actor_critic .BaseActorHead ] = actor_critic .Actor ,
744749 critic_type : Type [actor_critic .BaseCriticHead ] = actor_critic .NCriticsTwoHot ,
750+ pass_obs_keys_to_actor : Optional [Iterable [str ]] = None ,
745751 ):
746752 super ().__init__ (
747753 obs_space = obs_space ,
@@ -766,6 +772,7 @@ def __init__(
766772 popart = popart ,
767773 actor_type = actor_type ,
768774 critic_type = critic_type ,
775+ pass_obs_keys_to_actor = pass_obs_keys_to_actor ,
769776 )
770777
771778 def _sample_k_actions (self , dist , k : int ):
@@ -815,7 +822,7 @@ def forward(self, batch: Batch, log_step: bool):
815822 ################
816823 ## a ~ \pi(s) ##
817824 ################
818- a_dist = self .actor (s_rep , log_dict = active_log_dict )
825+ a_dist = self .actor (s_rep , log_dict = active_log_dict , straight_from_obs = { k : batch . obs [ k ] for k in self . pass_obs_keys_to_actor } )
819826 if self .discrete :
820827 a_dist = DiscreteLikeContinuous (a_dist )
821828 if log_step :
0 commit comments