Skip to content

Commit 7ed412f

Browse files
authored
Merge pull request #89 from UT-Austin-RPL/bugs
Add an option to pass selected observation keys directly to the actor network
2 parents e59b517 + 4e395b3 commit 7ed412f

3 files changed

Lines changed: 31 additions & 10 deletions

File tree

amago/agent.py

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
"""
44

55
import itertools
6-
from typing import Type, Optional, Tuple, Any, List
6+
from typing import Type, Optional, Tuple, Any, List, Iterable
77

88
import torch
99
from torch import nn
@@ -222,6 +222,7 @@ def __init__(
222222
use_multigamma: bool = True,
223223
actor_type: Type[actor_critic.BaseActorHead] = actor_critic.Actor,
224224
critic_type: Type[actor_critic.BaseCriticHead] = actor_critic.NCritics,
225+
pass_obs_keys_to_actor: Optional[Iterable[str]] = None,
225226
):
226227
super().__init__()
227228
self.obs_space = obs_space
@@ -293,6 +294,7 @@ def __init__(
293294
self.target_actor = actor_type(**ac_kwargs)
294295
# full weight copy to targets
295296
self.hard_sync_targets()
297+
self.pass_obs_keys_to_actor = pass_obs_keys_to_actor or []
296298

297299
@property
298300
def trainable_params(self):
@@ -364,7 +366,10 @@ def get_actions(
364366
tstep_emb, time_idxs=time_idxs, hidden_state=hidden_state
365367
)
366368
# generate action distribution [batch, length, len(self.gammas), d_action]
367-
action_dists = self.actor(traj_emb_t)
369+
action_dists = self.actor(
370+
traj_emb_t,
371+
straight_from_obs={k: obs[k] for k in self.pass_obs_keys_to_actor},
372+
)
368373
if sample:
369374
actions = action_dists.sample()
370375
else:
@@ -473,7 +478,7 @@ def forward(
473478
## a ~ \pi(s) ##
474479
################
475480
critic_loss = None
476-
a_dist = self.actor(s_rep, log_dict=active_log_dict)
481+
a_dist = self.actor(s_rep, log_dict=active_log_dict, straight_from_obs={k : batch.obs[k] for k in self.pass_obs_keys_to_actor})
477482
a_agent = self._sample_k_actions(a_dist, k=K_a)
478483
assert a_agent.shape == (K_a, B, L, G, D_action)
479484
if log_step:
@@ -742,6 +747,7 @@ def __init__(
742747
use_multigamma: bool = True,
743748
actor_type: Type[actor_critic.BaseActorHead] = actor_critic.Actor,
744749
critic_type: Type[actor_critic.BaseCriticHead] = actor_critic.NCriticsTwoHot,
750+
pass_obs_keys_to_actor: Optional[Iterable[str]] = None,
745751
):
746752
super().__init__(
747753
obs_space=obs_space,
@@ -766,6 +772,7 @@ def __init__(
766772
popart=popart,
767773
actor_type=actor_type,
768774
critic_type=critic_type,
775+
pass_obs_keys_to_actor=pass_obs_keys_to_actor,
769776
)
770777

771778
def _sample_k_actions(self, dist, k: int):
@@ -815,7 +822,7 @@ def forward(self, batch: Batch, log_step: bool):
815822
################
816823
## a ~ \pi(s) ##
817824
################
818-
a_dist = self.actor(s_rep, log_dict=active_log_dict)
825+
a_dist = self.actor(s_rep, log_dict=active_log_dict, straight_from_obs={k : batch.obs[k] for k in self.pass_obs_keys_to_actor})
819826
if self.discrete:
820827
a_dist = DiscreteLikeContinuous(a_dist)
821828
if log_step:

amago/nets/actor_critic.py

Lines changed: 19 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,10 @@ def __init__(
4949
self.actions_differentiable = self.policy_dist.actions_differentiable
5050

5151
def forward(
52-
self, state: torch.Tensor, log_dict: Optional[dict] = None
52+
self,
53+
state: torch.Tensor,
54+
log_dict: Optional[dict] = None,
55+
straight_from_obs: Optional[dict[str, torch.Tensor]] = None,
5356
) -> pyd.Distribution:
5457
"""Compute an action distribution from a state representation.
5558
@@ -61,7 +64,9 @@ def forward(
6164
(e.g. `Discrete` or `TanhGaussian`). Always a pytorch distribution (e.g., `Categorical`)
6265
where sampled actions would have shape (Batch, Length, Gammas, action_dim).
6366
"""
64-
dist_params = self.actor_network_forward(state=state, log_dict=log_dict)
67+
dist_params = self.actor_network_forward(
68+
state=state, log_dict=log_dict, straight_from_obs=straight_from_obs
69+
)
6570
assert dist_params.ndim == 4
6671
assert dist_params.shape[-2:] == (
6772
self.num_gammas,
@@ -71,7 +76,10 @@ def forward(
7176

7277
@abstractmethod
7378
def actor_network_forward(
74-
self, state: torch.Tensor, log_dict: Optional[dict] = None
79+
self,
80+
state: torch.Tensor,
81+
log_dict: Optional[dict] = None,
82+
straight_from_obs: Optional[dict[str, torch.Tensor]] = None,
7583
) -> torch.Tensor:
7684
raise NotImplementedError
7785

@@ -127,7 +135,10 @@ def __init__(
127135
)
128136

129137
def actor_network_forward(
130-
self, state: torch.Tensor, log_dict: Optional[dict] = None
138+
self,
139+
state: torch.Tensor,
140+
log_dict: Optional[dict] = None,
141+
straight_from_obs: Optional[dict[str, torch.Tensor]] = None,
131142
) -> torch.Tensor:
132143
dist_params = self.base(state)
133144
dist_params = rearrange(
@@ -216,7 +227,10 @@ def forward(self, x):
216227

217228
@torch.compile
218229
def actor_network_forward(
219-
self, state: torch.Tensor, log_dict: Optional[dict] = None
230+
self,
231+
state: torch.Tensor,
232+
log_dict: Optional[dict] = None,
233+
straight_from_obs: Optional[dict[str, torch.Tensor]] = None,
220234
) -> torch.Tensor:
221235
B, L, D = state.shape
222236
x = self.inp(state)

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
setup(
44
name="amago",
5-
version="3.1.0",
5+
version="3.1.1",
66
author="Jake Grigsby",
77
author_email="grigsby@cs.utexas.edu",
88
license="MIT",

0 commit comments

Comments
 (0)