@@ -444,3 +444,40 @@ def _get_attn_mask(attn_mask, n_head):
     attn_mask_np = np.expand_dims(attn_mask_np, axis=1)
     attn_mask_np = np.broadcast_to(attn_mask_np, (batch_size, n_head, seq_q_len, seq_k_len))
     return tensor.from_numpy(attn_mask_np)
+
+class ScaledDotProductAttention(layer.Layer):
+    def __init__(self, d_model=512, n_head=8):
+        super(ScaledDotProductAttention, self).__init__()
+        self.d_k = d_model // n_head
+        assert (
+            self.d_k * n_head == d_model
+        ), "d_model must be divisible by n_head"
+
+    def forward(self, query, key, value, attn_mask):
+        """
+        Args:
+            query: [batch_size, n_heads, len_q, d_k]
+            key: [batch_size, n_heads, len_k, d_k]
+            value: [batch_size, n_heads, len_v(=len_k), d_v]
+            attn_mask: [batch_size, n_heads, seq_len, seq_len]
+        Returns: context [batch_size, n_heads, len_q, d_v] and attn [batch_size, n_heads, len_q, len_k]
+        """
+        K_trans = autograd.transpose(key, [0, 1, 3, 2])
+
+        # scores: [batch_size, n_heads, len_q, len_k]
+        # query:  [batch_size, n_heads, len_q, d_k]
+        # k^T:    [batch_size, n_heads, d_k, len_k]
+        scores = matmul4d(query, K_trans)
+        d_k_sqrt = Tensor(shape=(1,), requires_grad=False, stores_grad=False)
+        d_k_sqrt.set_value(np.sqrt(self.d_k))
+        scores = autograd.div(scores, d_k_sqrt)
+
+        mask_fill = Tensor(shape=attn_mask.shape, data=np.full(attn_mask.shape, -1e6, dtype=np.float32), requires_grad=False, stores_grad=False)
+        attn_mask_np = tensor.to_numpy(attn_mask)
+        scores = autograd.where(mask_fill, scores, attn_mask_np)
+
+        attn = autograd.softmax(scores, axis=-1)
+        # context: [batch_size, n_heads, len_q, d_v]
+        # attn: [batch_size, n_heads, len_q, len_k]  value: [batch_size, n_heads, len_v(=len_k), d_v]
+        context = matmul4d(attn, value)
+        return context, attn
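
For review purposes, here is a minimal NumPy sketch of the same forward computation. It is not part of the patch, and sdpa_numpy is an illustrative name; it assumes a 0/1 attn_mask and mirrors the convention above, where nonzero mask entries are filled with -1e6 before the softmax.

import numpy as np

def sdpa_numpy(query, key, value, attn_mask, d_k):
    # scores: [batch_size, n_heads, len_q, len_k] = Q K^T / sqrt(d_k)
    scores = np.matmul(query, key.transpose(0, 1, 3, 2)) / np.sqrt(d_k)
    # fill masked (nonzero) positions with -1e6 so they vanish after softmax
    scores = np.where(attn_mask != 0, np.float32(-1e6), scores)
    # numerically stable softmax over the last axis (len_k)
    scores = scores - scores.max(axis=-1, keepdims=True)
    exp_scores = np.exp(scores)
    attn = exp_scores / exp_scores.sum(axis=-1, keepdims=True)
    # context: [batch_size, n_heads, len_q, d_v]
    return np.matmul(attn, value), attn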