Skip to content

Commit 274278d

Browse files
authored
Merge pull request #94 from UT-Austin-RPL/discret_dist_type
leaky relu filter change
2 parents 1e3bfed + 2dd4666 commit 274278d

1 file changed

Lines changed: 5 additions & 0 deletions

File tree

amago/agent.py

Lines changed: 5 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -73,6 +73,7 @@ def leaky_relu_filter(
73 73
target_f0: float = 1e-2,
74 74
clip_weights_low: Optional[float] = 1e-7,
75 75
clip_weights_high: Optional[float] = 10.0,
76+
add_binary: bool = True,
76 77
) -> torch.Tensor:
77 78
"""Weights policy regression data using a leaky relu ramp with f(0)=target_f0.
78 79
@@ -92,6 +93,8 @@ def leaky_relu_filter(
92 93
weights = bias + F.leaky_relu(x, negative_slope=neg_slope)
93 94
if clip_weights_low is not None or clip_weights_high is not None:
94 95
weights = torch.clamp(weights, min=clip_weights_low, max=clip_weights_high)
96+
if add_binary:
97+
weights += binary_filter(adv)
95 98
return weights
96 99

97 100

@@ -630,6 +633,8 @@ def masked_avg(x_, dim=0):
630 633
stats[f"Q(s, a) (global mean, rescaled) gamma={gamma:.3f}"] = masked_avg(
631 634
q_s_a_g, i
632 635
)
636+
print("here")
637+
stats["Q Sequence"] = q_s_a_g
633 638
stats[f"Q(s,a) (global mean, raw scale) gamma={gamma:.3f}"] = masked_avg(
634 639
raw_q_s_a_g, i
635 640
)

0 commit comments

Comments
 (0)