@@ -347,6 +347,58 @@ def sliding_window_mask_mod(b, h, q_idx, kv_idx):
         )


+@gin.configurable
+class ClippedSlidingSinkAttention(FlexAttention):
+    """
+    Sliding-window attention with optional attention sink and logit clipping.
+    """
+
+    def __init__(
+        self,
+        causal: bool,
+        dropout: float,
+        window_size: int = gin.REQUIRED,
+        logit_clip: float = 0.0,
+        sink_size: int = 0,
+        sink_bias: float = 0.0,
+    ):
+        assert window_size > 0, "window_size must be > 0"
+        self.window_size = int(window_size)
+        self.logit_clip = float(logit_clip) if logit_clip is not None else 0.0
+        self.sink_size = int(sink_size)
+        self.sink_bias = float(sink_bias)
+
+        has_sink = self.sink_size > 0
+        has_sink_bias = has_sink and (self.sink_bias != 0.0)
+        clip_active = self.logit_clip > 0.0
+
+        def sliding_window_with_sink_mask_mod(
+            b: int, h: int, q_idx: int, kv_idx: int
+        ) -> bool:
+            dq = q_idx - kv_idx
+            in_window = (dq >= 0) & (dq <= self.window_size)
+            in_sink = (kv_idx < self.sink_size) if has_sink else False
+            return in_window | in_sink
+
+        def score_with_sink_and_clip(
+            score: torch.Tensor, b: int, h: int, q_idx: int, kv_idx: int
+        ) -> torch.Tensor:
+            # kv_idx may arrive as a tensor under flex_attention, so select
+            # with torch.where rather than branching on its value in Python.
+            if has_sink_bias:
+                score = torch.where(kv_idx < self.sink_size, score + self.sink_bias, score)
+            if clip_active:
+                score = torch.clamp(score, -self.logit_clip, self.logit_clip)
+            return score
+
+        super().__init__(
+            score_mod=score_with_sink_and_clip,
+            mask_mod=sliding_window_with_sink_mask_mod,
+            causal=causal,
+            dropout=dropout,
+        )
+
+
 @gin.configurable
 class SigmaReparam(nn.Linear):
     """SigmaReparam nn.Linear alternative.
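For reviewers, a minimal usage sketch (not part of the diff). It assumes the class is used like the other FlexAttention subclasses in this file: window_size is normally bound through gin (it defaults to gin.REQUIRED), while the remaining arguments can be passed directly. All names of bindings and all values below are illustrative only.

# gin binding for the gin.REQUIRED parameter (example values):
#   ClippedSlidingSinkAttention.window_size = 512
#   ClippedSlidingSinkAttention.logit_clip = 30.0
#   ClippedSlidingSinkAttention.sink_size = 4

# Direct construction, bypassing gin:
attn = ClippedSlidingSinkAttention(
    causal=True,
    dropout=0.0,
    window_size=512,   # each query attends to the previous 512 positions
    logit_clip=30.0,   # clamp attention logits to [-30, 30]; 0 disables clipping
    sink_size=4,       # the first 4 KV positions stay visible to every query
    sink_bias=0.0,     # optional additive bias applied to sink logits
)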