1 change: 1 addition & 0 deletions .gitignore
@@ -194,3 +194,4 @@ pufferlib/resources/drive/output*.mp4

# Local TODO tracking
TODO.md
*.mp4
33 changes: 17 additions & 16 deletions pufferlib/config/ocean/adaptive.ini
@@ -2,18 +2,20 @@
package = ocean
env_name = puffer_adaptive_drive
policy_name = Drive
transformer_name = Transformer
; Changed from rnn_name
rnn_name = Recurrent

[vec]
num_workers = 16
num_envs = 16
batch_size = 2
batch_size = 1
; backend = Serial

[policy]
input_size = 128
; Increased from 64 for richer representations
input_size = 64
hidden_size = 256

[rnn]
input_size = 256
hidden_size = 256

[transformer]
@@ -23,14 +25,13 @@ num_layers = 2
; Number of transformer layers
num_heads = 4
; Number of attention heads (must divide hidden_size)
; context_length = 182
; k_scenarios (2) * scenario_length (91) = maximum attention span
; Transformer uses `horizon` from [train] section for attention span
dropout = 0.0
; Dropout (keep at 0 for RL stability initially)

[env]
num_agents = 1512
num_ego_agents = 756
num_agents = 1024
num_ego_agents = 512
; Options: discrete, continuous
action_type = discrete
; Options: classic, jerk
@@ -46,7 +47,7 @@ goal_radius = 2.0
; Max target speed in m/s for the agent to maintain towards the goal
goal_speed = 100.0
; What to do when the goal is reached. Options: 0:"respawn", 1:"generate_new_goals", 2:"stop"
goal_behavior = 0
goal_behavior = 2
; Determines the target distance to the new goal in the case of goal_behavior = generate_new_goals.
; Large numbers will select a goal point further away from the agent's current position.
goal_target_distance = 30.0
@@ -114,24 +115,22 @@ discount_weight_ub = 0.98
seed=42
total_timesteps = 2_000_000_000
anneal_lr = True
; Needs to be: num_agents * num_workers * context_window
; Needs to be: num_agents * num_workers * horizon
batch_size = auto
minibatch_size = 36400
; 400 * 91
max_minibatch_size = 36400
minibatch_multiplier = 400
policy_architecture = Transformer
; Matches scenario_length for buffer organization
bptt_horizon = 32
; Keep for backward compatibility
; Sequence length - overridden to k_scenarios * scenario_length for adaptive
horizon = 91
adam_beta1 = 0.9
adam_beta2 = 0.999
adam_eps = 1e-8
clip_coef = 0.2
ent_coef = 0.005
gae_lambda = 0.95
gamma = 0.98
learning_rate = 0.0003
learning_rate = 0.003
; Reduced from 0.003 (transformers often need lower LR)
max_grad_norm = 1.0
prio_alpha = 0.85
@@ -193,6 +192,8 @@ human_replay_num_agents = 32
human_replay_num_rollouts = 100
; Number of maps to use for human replay evaluation
human_replay_num_maps = 100
; Number of maps to render for human replay (subset of eval maps)
human_replay_render_num_maps = 2

[sweep.train.learning_rate]
distribution = log_normal
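Note on the sizing comments in the [train] section above: assuming `batch_size = auto` resolves to the `num_agents * num_workers * horizon` formula from the config comment (the exact `auto` behavior lives in the trainer and is not part of this diff), the values in adaptive.ini work out as in this small sanity-check sketch:

```python
# Sizing sanity check for adaptive.ini (assumption: batch_size = auto follows
# the "num_agents * num_workers * horizon" comment in [train]).
num_agents = 1024           # [env] num_agents
num_workers = 16            # [vec] num_workers
horizon = 91                # [train] horizon (matches scenario_length)
minibatch_multiplier = 400  # [train] minibatch_multiplier

batch_size = num_agents * num_workers * horizon  # 1024 * 16 * 91 = 1,490,944 steps
minibatch_size = minibatch_multiplier * horizon  # 400 * 91 = 36,400, the old hardcoded value
print(batch_size, minibatch_size)
```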
55 changes: 29 additions & 26 deletions pufferlib/config/ocean/drive.ini
@@ -2,7 +2,7 @@
package = ocean
env_name = puffer_drive
policy_name = Drive
rnn_name = Transformer
rnn_name = Recurrent

[vec]
num_workers = 16
@@ -14,9 +14,9 @@ batch_size = 2
input_size = 64
hidden_size = 256

; [rnn]
; input_size = 256
; hidden_size = 256
[rnn]
input_size = 256
hidden_size = 256

[transformer]
input_size = 256
@@ -25,14 +25,13 @@ num_layers = 2
; Number of transformer layers
num_heads = 4
; Number of attention heads (must divide hidden_size)
context_window = 32
; k_scenarios (2) * scenario_length (91) = maximum attention span
; Transformer uses `horizon` from [train] section for attention span
dropout = 0.0
; Dropout (keep at 0 for RL stability initially)

[env]
num_agents = 512
num_ego_agents = 512
num_agents = 1024
num_ego_agents = 1024
; Options: discrete, continuous
action_type = discrete
; Options: classic, jerk
@@ -47,7 +46,7 @@ goal_radius = 2.0
; Max target speed in m/s for the agent to maintain towards the goal
goal_speed = 100.0
; What to do when the goal is reached. Options: 0:"respawn", 1:"generate_new_goals", 2:"stop"
goal_behavior = 1
goal_behavior = 2
; Determines the target distance to the new goal in the case of goal_behavior = generate_new_goals.
; Large numbers will select a goal point further away from the agent's current position.
goal_target_distance = 30.0
@@ -112,16 +111,14 @@ discount_weight_ub = 0.80
[train]
seed=42
total_timesteps = 2_000_000_000
# learning_rate = 0.02
# gamma = 0.985
anneal_lr = True
; Needs to be: num_agents * num_workers * BPTT horizon
; Needs to be: num_agents * num_workers * horizon
batch_size = auto
minibatch_size = 32768
max_minibatch_size = 32768
; minibatch_size = 256
; max_minibatch_size = 256
bptt_horizon = 32
minibatch_multiplier = 400
; Sequence length for training - matches scenario_length for full episode context
horizon = 91
adam_beta1 = 0.9
adam_beta2 = 0.999
adam_eps = 1e-8
Expand All @@ -130,17 +127,15 @@ ent_coef = 0.005
gae_lambda = 0.95
gamma = 0.98
learning_rate = 0.003
max_grad_norm = 1
prio_alpha = 0.8499999999999999
prio_beta0 = 0.8499999999999999
max_grad_norm = 1.0
prio_alpha = 0.85
prio_beta0 = 0.85
update_epochs = 1
vf_clip_coef = 0.1999999999999999
vf_coef = 2
vf_clip_coef = 0.2
vf_coef = 2.0
vtrace_c_clip = 1
vtrace_rho_clip = 1
checkpoint_interval = 100
use_transformer = True
context_window = 32
# Rendering options
render = True
render_interval = 100
@@ -158,7 +153,7 @@ zoom_in = True
render_map = none

[eval]
eval_interval = 1000
eval_interval = 100
; Path to dataset used for evaluation
map_dir = "resources/drive/binaries/training"
; Evaluation will run on the first num_maps maps in the map_dir directory
@@ -183,9 +178,17 @@ wosac_sanity_check = False
; Only return aggregate results across all scenes
wosac_aggregate_results = True
; If True, enable human replay evaluation (pair policy-controlled agent with human replays)
human_replay_eval = False
; Control only the self-driving car
human_replay_control_mode = "control_sdc_only"
human_replay_eval = True
; Control mode for human replay (control_vehicles with max_controlled_agents=1 controls one agent)
human_replay_control_mode = "control_vehicles"
; Number of agents in human replay evaluation environment
human_replay_num_agents = 32
; Number of rollouts for human replay evaluation
human_replay_num_rollouts = 100
; Number of maps to use for human replay evaluation
human_replay_num_maps = 100
; Number of maps to render for human replay (subset of eval maps)
human_replay_render_num_maps = 2

[sweep.train.learning_rate]
distribution = log_normal
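Note: with `rnn_name = Recurrent`, drive.ini goes through the LSTM path rather than the transformer. The per-agent recurrent state that the evaluator builds (see the evaluator.py hunk later in this diff) is shaped by the [rnn] hidden size and the agent count of the environment being evaluated; a minimal sketch, with the device and agent count as placeholders:

```python
import torch

# Minimal sketch of the recurrent eval state implied by this config
# (mirrors the state dict built in pufferlib/ocean/benchmark/evaluator.py below).
num_agents = 1024    # agent count of the evaluated environment, e.g. [env] num_agents
hidden_size = 256    # [rnn] hidden_size
device = "cpu"       # placeholder; the evaluator uses the configured device

state = dict(
    lstm_h=torch.zeros(num_agents, hidden_size, device=device),
    lstm_c=torch.zeros(num_agents, hidden_size, device=device),
)
```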
26 changes: 13 additions & 13 deletions pufferlib/models.py
@@ -208,7 +208,7 @@ def __init__(
hidden_size=128,
num_layers=4,
num_heads=8,
context_length=512,
horizon=512,
dropout=0.0,
):
"""Wraps your policy with a Transformer for temporal modeling.
@@ -220,15 +220,15 @@ def __init__(
hidden_size: Transformer hidden dimension
num_layers: Number of transformer layers
num_heads: Number of attention heads
context_length: Maximum sequence length to attend over
horizon: Maximum sequence length to attend over
dropout: Dropout probability
"""
super().__init__()
self.obs_shape = env.single_observation_space.shape
self.policy = policy
self.input_size = input_size
self.hidden_size = hidden_size
self.context_length = context_length
self.horizon = horizon
self.num_layers = num_layers
self.num_heads = num_heads
self.is_continuous = self.policy.is_continuous
@@ -240,7 +240,7 @@ def __init__(
self.input_projection = nn.Identity()

# Learnable positional embeddings
self.positional_embedding = nn.Parameter(torch.zeros(1, context_length, hidden_size))
self.positional_embedding = nn.Parameter(torch.zeros(1, horizon, hidden_size))
nn.init.normal_(self.positional_embedding, std=0.02)

# Transformer encoder
@@ -307,7 +307,7 @@ def forward_eval(self, observations, state):
hidden = self.input_projection(hidden)

if "transformer_context" not in state or state["transformer_context"] is None:
context = torch.zeros(B, self.context_length, self.hidden_size, device=device)
context = torch.zeros(B, self.horizon, self.hidden_size, device=device)
pos = torch.zeros(1, dtype=torch.long, device=device)
else:
context = state["transformer_context"]
@@ -316,24 +316,24 @@
if (
context.shape[-1] != self.hidden_size
or context.shape[0] != B
or context.shape[1] != self.context_length
or context.shape[1] != self.horizon
):
context = torch.zeros(B, self.context_length, self.hidden_size, device=device)
context = torch.zeros(B, self.horizon, self.hidden_size, device=device)
pos = torch.zeros(1, dtype=torch.long, device=device)

write_idx = (pos % self.context_length).long()
write_idx = (pos % self.horizon).long()
context[:, write_idx, :] = hidden.unsqueeze(1)
pos = pos + 1

pos_embed = self.positional_embedding[:, : self.context_length]
pos_embed = self.positional_embedding[:, : self.horizon]
context_with_pos = context + pos_embed

causal_mask = self.get_causal_mask(self.context_length, device)
causal_mask = self.get_causal_mask(self.horizon, device)

output = self.transformer(context_with_pos, mask=causal_mask, is_causal=True)
output = self.output_norm(output)

read_idx = ((pos - 1) % self.context_length).long()
read_idx = ((pos - 1) % self.horizon).long()
hidden_out = output[:, read_idx, :].squeeze(1)

state["transformer_context"] = context
@@ -361,7 +361,7 @@ def forward(self, observations, state):
hidden = self.input_projection(hidden)

# Remove dynamic truncation - use clamp instead of if
T_actual = min(T, self.context_length) # Python int, fine
T_actual = min(T, self.horizon) # Python int, fine
if T_actual < T:
hidden = hidden[:, -T_actual:]
T = T_actual
@@ -390,7 +390,7 @@ def forward(self, observations, state):
values = values.view(B, T)

# Use Python int for context_len - no sync
context_len = min(T, self.context_length)
context_len = min(T, self.horizon)
state["hidden"] = hidden
state["transformer_context"] = hidden[:, -context_len:].detach()
state["transformer_position"] = torch.full((B,), context_len - 1, dtype=torch.long, device=device)
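Note: the `forward_eval` hunk above keeps a per-agent ring buffer indexed by `pos % horizon`. A standalone illustration of that indexing (not code from this PR):

```python
import torch

# Standalone illustration of the ring-buffer indexing in forward_eval above:
# each step writes the newest hidden vector at pos % horizon, then reads it
# back from (pos - 1) % horizon after the transformer pass. Once pos reaches
# horizon, the oldest slot is overwritten.
B, horizon, hidden_size = 2, 4, 8
context = torch.zeros(B, horizon, hidden_size)
pos = torch.zeros(1, dtype=torch.long)

for step in range(6):
    hidden = torch.randn(B, hidden_size)       # stand-in for the projected encoder output
    write_idx = (pos % horizon).long()
    context[:, write_idx, :] = hidden.unsqueeze(1)
    pos = pos + 1
    read_idx = ((pos - 1) % horizon).long()    # slot that was just written
    assert torch.allclose(context[:, read_idx, :].squeeze(1), hidden)
```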
18 changes: 11 additions & 7 deletions pufferlib/ocean/benchmark/evaluator.py
@@ -638,7 +638,7 @@ def rollout(self, args, puffer_env, policy):
the policy is with (static) human partners.

Args:
args: Config dict with train settings (device, use_rnn, policy_architecture, etc.)
args: Config dict with train settings (device, rnn_name, etc.)
puffer_env: PufferLib environment wrapper
policy: Trained policy to evaluate

@@ -654,18 +654,22 @@ def rollout(self, args, puffer_env, policy):

obs, info = puffer_env.reset()

policy_architecture = args["train"].get("policy_architecture", "Recurrent")
k_scenarios = args["env"].get("k_scenarios", 1)

if policy_architecture == "Recurrent":
# Detect architecture from policy object
is_transformer = hasattr(policy, 'horizon') and hasattr(policy, 'transformer')
is_recurrent = hasattr(policy, 'lstm')

if is_recurrent:
state = dict(
lstm_h=torch.zeros(num_agents, policy.hidden_size, device=device),
lstm_c=torch.zeros(num_agents, policy.hidden_size, device=device),
)
elif policy_architecture == "Transformer":
context_length = args["train"].get("context_window", 182)
elif is_transformer:
# Get horizon from the policy (TransformerWrapper stores it)
horizon = policy.horizon
state = dict(
transformer_context=torch.zeros(num_agents, context_length, policy.hidden_size, device=device),
transformer_context=torch.zeros(num_agents, horizon, policy.hidden_size, device=device),
transformer_position=torch.zeros(1, dtype=torch.long, device=device),
)
else:
@@ -690,7 +694,7 @@ def rollout(self, args, puffer_env, policy):
obs, rewards, dones, truncs, info_list = puffer_env.step(action_np)

# Reset transformer context on mid-scenario terminations (not at scenario boundaries)
if policy_architecture == "Transformer":
if is_transformer:
is_last_step = time_idx == self.sim_steps - 1
if not is_last_step:
done_mask = dones | truncs
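Note: the body of the mid-scenario reset falls outside the hunk above, so it is not shown here. Purely as a hypothetical sketch of what clearing transformer state for terminated agents could look like (this is not the PR's actual code):

```python
import torch

# Hypothetical sketch only; the real reset logic is not part of this hunk.
# One plausible approach is to zero the context rows of agents whose episodes
# ended mid-scenario so stale history does not leak into the next episode.
num_agents, horizon, hidden_size = 4, 91, 256
state = {"transformer_context": torch.randn(num_agents, horizon, hidden_size)}

dones = torch.tensor([False, True, False, False])
truncs = torch.tensor([False, False, True, False])
done_mask = dones | truncs
if done_mask.any():
    state["transformer_context"][done_mask] = 0.0   # drop history for finished agents
```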
20 changes: 9 additions & 11 deletions pufferlib/ocean/drive/drive.c
@@ -34,20 +34,18 @@ void test_drivenet() {
void demo() {

// Note: The settings below are hardcoded for demo purposes. Since the policy was
// trained with these exact settings, that changing them may lead to
// weird behavior.
// trained with these exact settings, changing them may lead to weird behavior.
Drive env = {
.human_agent_idx = 0,
.dynamics_model = conf.dynamics_model,
.reward_vehicle_collision = conf.reward_vehicle_collision,
.reward_offroad_collision = conf.reward_offroad_collision,
.reward_ade = conf.reward_ade,
.goal_radius = conf.goal_radius,
.dt = conf.dt,
.dynamics_model = CLASSIC,
.reward_vehicle_collision = -1.0f,
.reward_offroad_collision = -1.0f,
.goal_radius = 2.0f,
.dt = 0.1f,
.map_name = "resources/drive/binaries/training/map_000.bin",
.init_steps = conf.init_steps,
.collision_behavior = conf.collision_behavior,
.offroad_behavior = conf.offroad_behavior,
.init_steps = 0,
.collision_behavior = 0,
.offroad_behavior = 0,
};
allocate(&env);
c_reset(&env);