"""Embedding lookup.
# Sources
* https://github.com/NVIDIA/Megatron-LM/blob/076972e37420b5325c5fe06e7131be7d96f05b53/megatron/core/tensor_parallel/layers.py#L171
"""
import multiprocessing
import os

import torch
import torch.distributed

def main(rank, world_size):
    if world_size > 1:
        os.environ["MASTER_ADDR"] = "127.0.0.1"
        os.environ["MASTER_PORT"] = "29500"
        torch.distributed.init_process_group(
            "gloo", rank=rank, world_size=world_size)

    batch_size = 8
    seq_len = 16
    vocab_size = 32
    embed_dim = 64

    # Every rank builds the same full embedding table and input ids from the
    # same seed, so the parallel result can be checked against a plain lookup.
    torch.manual_seed(0)
    weight = torch.randn(vocab_size, embed_dim)
    x = torch.randint(low=0, high=vocab_size, size=(batch_size, seq_len))

    # Shard the vocabulary dimension: this rank owns rows [start, end).
    vocab_size_per_rank = vocab_size // world_size
    start = rank * vocab_size_per_rank
    end = start + vocab_size_per_rank
    # Typically, we would initialize a weight matrix of this shape directly,
    # but slicing the full matrix makes it easier to test.
    weight = weight[start:end, :]

    # Tokens owned by other ranks are masked: their indices are set to 0 so
    # the local lookup stays in range, and their rows are zeroed so the
    # all-reduce sums exactly one non-zero contribution per token.
    mask = (x >= start) & (x < end)
    x -= start
    x *= mask
    embeddings = weight[x]
    embeddings[~mask] = 0
    if world_size > 1:
        torch.distributed.all_reduce(embeddings)

    # Reference: a plain, unsharded lookup with the same seed.
    torch.manual_seed(0)
    weight = torch.randn(vocab_size, embed_dim)
    x = torch.randint(low=0, high=vocab_size, size=(batch_size, seq_len))
    expected_embeddings = weight[x]
    assert torch.allclose(embeddings, expected_embeddings)

    if world_size > 1:
        torch.distributed.destroy_process_group()

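# What follows is a rough sketch of how the masked-lookup-plus-all-reduce idea
# above could be packaged as a layer, loosely following the Megatron-LM
# VocabParallelEmbedding linked in the docstring. The class name, constructor
# arguments, and initialization are illustrative assumptions rather than the
# actual Megatron-LM API, and only the forward pass is covered: Megatron-LM
# wraps the all-reduce in a custom autograd function so gradients flow back
# correctly, which this sketch does not do.
class VocabParallelEmbeddingSketch(torch.nn.Module):
    """Embedding layer whose vocabulary dimension is sharded across ranks."""

    def __init__(self, vocab_size, embed_dim, rank, world_size):
        super().__init__()
        assert vocab_size % world_size == 0
        vocab_size_per_rank = vocab_size // world_size
        self.start = rank * vocab_size_per_rank
        self.end = self.start + vocab_size_per_rank
        # Each rank materializes only its own slice of the embedding table.
        self.weight = torch.nn.Parameter(
            torch.randn(vocab_size_per_rank, embed_dim))

    def forward(self, x):
        # Same recipe as main(): set out-of-shard indices to 0, look up the
        # local slice, zero the masked rows, then all-reduce so every rank
        # ends up with the full embeddings.
        mask = (x >= self.start) & (x < self.end)
        local_x = (x - self.start) * mask
        embeddings = self.weight[local_x] * mask.unsqueeze(-1)
        torch.distributed.all_reduce(embeddings)
        return embeddings

# Example use (inside an initialized process group):
#     embed = VocabParallelEmbeddingSketch(vocab_size, embed_dim, rank, world_size)
#     out = embed(x)  # (batch_size, seq_len, embed_dim), identical on all ranks
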
if __name__ == "__main__":
    world_size = 2
    if world_size == 1:
        main(0, 1)
    else:
        # Launch one process per rank; every process runs the same main().
        processes = []
        for rank in range(world_size):
            p = multiprocessing.Process(target=main, args=(rank, world_size))
            p.start()
            processes.append(p)
        for p in processes:
            p.join()