Commit b8450e2
Merge pull request #91 from UT-Austin-RPL/retro
clipped sliding window attention w/ sink, clamped linear adv filter
2 parents c766f60 + 6a13c13

4 files changed: 83 additions & 3 deletions
amago/agent.py

Lines changed: 31 additions & 2 deletions
@@ -64,6 +64,37 @@ def binary_filter(adv: torch.Tensor, threshold: float = 0.0) -> torch.Tensor:
     return adv > threshold


+@gin.configurable
+def leaky_relu_filter(
+    adv: torch.Tensor,
+    beta: float = 2.0,
+    tau: float = 1e-2,
+    neg_slope: float = 0.05,
+    target_f0: float = 1e-2,
+    clip_weights_low: Optional[float] = 1e-7,
+    clip_weights_high: Optional[float] = 10.0,
+) -> torch.Tensor:
+    """Weights policy regression data using a leaky ReLU ramp with f(0) = target_f0.
+
+    Args:
+        adv: Tensor of advantages (Batch, Length, Gammas, 1).
+
+    Keyword Args:
+        beta: Positive scale controlling the slope of the main ramp.
+        tau: Advantage hinge location for switching from the leak to the main slope.
+        neg_slope: Slope for advantages below tau.
+        target_f0: Desired weight at adv=0 (before clipping).
+        clip_weights_low: If provided, clip output weights below this value. Defaults to 1e-7.
+        clip_weights_high: If provided, clip output weights above this value. Defaults to 10.0.
+    """
+    bias = target_f0 + neg_slope * tau / beta
+    x = (adv - tau) / beta
+    weights = bias + F.leaky_relu(x, negative_slope=neg_slope)
+    if clip_weights_low is not None or clip_weights_high is not None:
+        weights = torch.clamp(weights, min=clip_weights_low, max=clip_weights_high)
+    return weights
+
+
 @gin.configurable
 def exp_filter(
     adv: torch.Tensor,
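This is the "clamped linear adv filter" from the commit title: a leaky-ReLU ramp over advantages, with the bias solved so that the weight at adv=0 comes out to exactly target_f0, and the final weights clamped to a positive range. A minimal standalone sketch of that identity using the defaults above (the test advantages are made up):

import torch
import torch.nn.functional as F

# same defaults as the diff above
beta, tau, neg_slope, target_f0 = 2.0, 1e-2, 0.05, 1e-2

adv = torch.tensor([-1.0, 0.0, 1.0])  # made-up advantages
bias = target_f0 + neg_slope * tau / beta
weights = bias + F.leaky_relu((adv - tau) / beta, negative_slope=neg_slope)
weights = torch.clamp(weights, min=1e-7, max=10.0)

# at adv=0 the leak contributes -neg_slope * tau / beta, which the bias cancels
assert torch.isclose(weights[1], torch.tensor(target_f0))
print(weights)  # tensor([1.0000e-07, 1.0000e-02, 5.0525e-01])

Negative advantages keep a small positive weight (the clamp floor) rather than being zeroed out, while advantages above tau ramp up linearly.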
@@ -289,8 +320,6 @@ def __init__(
         self.critics = critic_type(**ac_kwargs, num_critics=num_critics)
         self.target_critics = critic_type(**ac_kwargs, num_critics=num_critics)
         self.maximized_critics = critic_type(**ac_kwargs, num_critics=num_critics)
-        if self.multibinary:
-            ac_kwargs["cont_dist_kind"] = "multibinary"
         self.actor = actor_type(**ac_kwargs)
         self.target_actor = actor_type(**ac_kwargs)
         # full weight copy to targets

amago/cli_utils.py

Lines changed: 1 addition & 1 deletion
@@ -464,7 +464,6 @@ def make_experiment_learn_only(experiment: amago.Experiment) -> amago.Experiment
     experiment.parallel_actors = 1
     experiment.always_save_latest = True
     experiment.always_load_latest = False
-    experiment.has_dset_edit_rights = True
     return experiment


@@ -488,6 +487,7 @@ def make_experiment_collect_only(experiment: amago.Experiment) -> amago.Experiment
     experiment.epochs = max(experiment.epochs, 1_000_000)
     # do not delete anything from the collection process
     experiment.has_dset_edit_rights = False
+    experiment.init_dsets()
     return experiment

amago/nets/traj_encoders.py

Lines changed: 1 addition & 0 deletions
@@ -532,6 +532,7 @@ def reset_hidden_state(
         hidden_state.reset(idxs=dones)
         return hidden_state

+    @torch.compile
     def forward(
         self,
         seq: torch.Tensor,
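The only change here decorates the trajectory encoder's forward with @torch.compile, which traces and compiles the method lazily on its first call. A self-contained sketch of the same decorator pattern (TinyEncoder is a made-up stand-in, not the AMAGO module):

import torch
import torch.nn as nn

class TinyEncoder(nn.Module):  # hypothetical stand-in for the trajectory encoder
    def __init__(self):
        super().__init__()
        self.proj = nn.Linear(8, 8)

    @torch.compile  # compiled on the first forward call
    def forward(self, seq: torch.Tensor) -> torch.Tensor:
        return torch.relu(self.proj(seq))

out = TinyEncoder()(torch.randn(2, 4, 8))  # first call triggers compilation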

amago/nets/transformer.py

Lines changed: 50 additions & 0 deletions
@@ -347,6 +347,56 @@ def sliding_window_mask_mod(b, h, q_idx, kv_idx):
         )


+@gin.configurable
+class ClippedSlidingSinkAttention(FlexAttention):
+    """
+    Sliding-window attention with optional attention sink and logit clipping.
+    """
+
+    def __init__(
+        self,
+        causal: bool,
+        dropout: float,
+        window_size: int = gin.REQUIRED,
+        logit_clip: float = 0.0,
+        sink_size: int = 0,
+        sink_bias: float = 0.0,
+    ):
+        assert window_size > 0, "window_size must be > 0"
+        self.window_size = int(window_size)
+        self.logit_clip = float(logit_clip) if logit_clip is not None else 0.0
+        self.sink_size = int(sink_size)
+        self.sink_bias = float(sink_bias)
+
+        has_sink = self.sink_size > 0
+        has_sink_bias = has_sink and (self.sink_bias != 0.0)
+        clip_active = self.logit_clip > 0.0
+
+        def sliding_window_with_sink_mask_mod(
+            b: int, h: int, q_idx: int, kv_idx: int
+        ) -> bool:
+            dq = q_idx - kv_idx
+            in_window = (dq >= 0) & (dq <= self.window_size)
+            in_sink = (kv_idx < self.sink_size) if has_sink else False
+            return in_window | in_sink
+
+        def score_with_sink_and_clip(
+            score: torch.Tensor, b: int, h: int, q_idx: int, kv_idx: int
+        ) -> torch.Tensor:
+            if has_sink_bias:
+                score = torch.where(kv_idx < self.sink_size, score + self.sink_bias, score)
+            if clip_active:
+                score = torch.clamp(score, -self.logit_clip, self.logit_clip)
+            return score
+
+        super().__init__(
+            score_mod=score_with_sink_and_clip,
+            mask_mod=sliding_window_with_sink_mask_mod,
+            causal=causal,
+            dropout=dropout,
+        )
+
+
 @gin.configurable
 class SigmaReparam(nn.Linear):
     """SigmaReparam nn.Linear alternative.
