File tree Expand file tree Collapse file tree
Expand file tree Collapse file tree Original file line number Diff line number Diff line change @@ -73,6 +73,7 @@ def leaky_relu_filter(
7373 target_f0 : float = 1e-2 ,
7474 clip_weights_low : Optional [float ] = 1e-7 ,
7575 clip_weights_high : Optional [float ] = 10.0 ,
76+ add_binary : bool = True ,
7677) -> torch .Tensor :
7778 """Weights policy regression data using a leaky relu ramp with f(0)=target_f0.
7879
@@ -92,6 +93,8 @@ def leaky_relu_filter(
9293 weights = bias + F .leaky_relu (x , negative_slope = neg_slope )
9394 if clip_weights_low is not None or clip_weights_high is not None :
9495 weights = torch .clamp (weights , min = clip_weights_low , max = clip_weights_high )
96+ if add_binary :
97+ weights += binary_filter (adv )
9598 return weights
9699
97100
@@ -630,6 +633,8 @@ def masked_avg(x_, dim=0):
630633 stats [f"Q(s, a) (global mean, rescaled) gamma={ gamma :.3f} " ] = masked_avg (
631634 q_s_a_g , i
632635 )
636+ print ("here" )
637+ stats ["Q Sequence" ] = q_s_a_g
633638 stats [f"Q(s,a) (global mean, raw scale) gamma={ gamma :.3f} " ] = masked_avg (
634639 raw_q_s_a_g , i
635640 )
You can’t perform that action at this time.
0 commit comments