
Commit 25a2797

Merge pull request #1356 from calmdown539/dev-postgresql
Add the implementation of ScaledDotProductAttention
2 parents 34e85a7 + f9852e2 commit 25a2797

1 file changed

Lines changed: 37 additions & 0 deletions

File tree

  • examples/singa_peft/examples/model

examples/singa_peft/examples/model/trans.py

@@ -444,3 +444,40 @@ def _get_attn_mask(attn_mask, n_head):
    attn_mask_np = np.expand_dims(attn_mask_np, axis=1)
    attn_mask_np = np.broadcast_to(attn_mask_np, (batch_size, n_head, seq_q_len, seq_k_len))
    return tensor.from_numpy(attn_mask_np)


class ScaledDotProductAttention(layer.Layer):

    def __init__(self, d_model=512, n_head=8):
        super(ScaledDotProductAttention, self).__init__()
        self.d_k = d_model // n_head
        assert (
            self.d_k * n_head == d_model
        ), "embed_dim must be divisible by num_heads"

    def forward(self, query, key, value, attn_mask):
        """
        Args:
            query: [batch_size, n_heads, len_q, d_k]
            key: [batch_size, n_heads, len_k, d_k]
            value: [batch_size, n_heads, len_v(=len_k), d_v]
            attn_mask: [batch_size, n_heads, seq_len, seq_len]
        Returns:
            context: [batch_size, n_heads, len_q, d_v]
            attn: [batch_size, n_heads, len_q, len_k]
        """
        K_trans = autograd.transpose(key, [0, 1, 3, 2])

        # scores: [batch_size, n_heads, len_q, len_k]
        # query:  [batch_size, n_heads, len_q, d_k]
        # K^T:    [batch_size, n_heads, d_k, len_k]
        scores = matmul4d(query, K_trans)

        # scale the scores by sqrt(d_k)
        d_k_sqrt = Tensor(shape=(1,), requires_grad=False, stores_grad=False)
        d_k_sqrt.set_value(np.sqrt(self.d_k))
        scores = autograd.div(scores, d_k_sqrt)

        # push masked positions to -1e6 so softmax assigns them ~0 weight
        mask_fill = Tensor(shape=attn_mask.shape,
                           data=np.full(attn_mask.shape, -1e6, dtype=np.float32),
                           requires_grad=False,
                           stores_grad=False)
        attn_mask_np = tensor.to_numpy(attn_mask)
        scores = autograd.where(mask_fill, scores, attn_mask_np)

        attn = autograd.softmax(scores, axis=-1)
        # context: [batch_size, n_heads, len_q, d_v]
        # attn: [batch_size, n_heads, len_q, len_k], value: [batch_size, n_heads, len_v(=len_k), d_v]
        context = matmul4d(attn, value)
        return context, attn
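For reference, the new layer computes standard scaled dot-product attention, softmax(Q K^T / sqrt(d_k)) V, with masked positions pushed to -1e6 before the softmax. Below is a minimal NumPy sketch of that computation, independent of SINGA and purely illustrative: the function name scaled_dot_product_attention_ref is hypothetical, and the convention that nonzero mask entries mark blocked positions is an assumption about how attn_mask is meant to be read in the code above.

import numpy as np

def scaled_dot_product_attention_ref(query, key, value, attn_mask):
    # query: [batch, n_heads, len_q, d_k], key: [batch, n_heads, len_k, d_k]
    # value: [batch, n_heads, len_k, d_v], attn_mask: [batch, n_heads, len_q, len_k]
    d_k = query.shape[-1]
    scores = np.matmul(query, key.transpose(0, 1, 3, 2)) / np.sqrt(d_k)
    # assumption: nonzero mask entries mark positions that must not be attended to
    scores = np.where(attn_mask != 0, np.float32(-1e6), scores)
    scores = scores - scores.max(axis=-1, keepdims=True)  # numerically stable softmax
    attn = np.exp(scores) / np.exp(scores).sum(axis=-1, keepdims=True)
    context = np.matmul(attn, value)  # [batch, n_heads, len_q, d_v]
    return context, attn

# Shape check with the dimensions in the docstring (d_model=512, n_head=8 -> d_k=64).
q = np.random.randn(2, 8, 5, 64).astype(np.float32)
k = np.random.randn(2, 8, 5, 64).astype(np.float32)
v = np.random.randn(2, 8, 5, 64).astype(np.float32)
mask = np.zeros((2, 8, 5, 5), dtype=np.float32)
ctx, attn = scaled_dot_product_attention_ref(q, k, v, mask)
print(ctx.shape, attn.shape)  # (2, 8, 5, 64) (2, 8, 5, 5)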
