
Commit 75268a4

Author: kip-cxj

add statelessprocessgroup to extend collective library

1 parent 009082d, commit 75268a4

6 files changed: 878 additions & 7 deletions


Lines changed: 28 additions & 0 deletions
@@ -0,0 +1,28 @@
from .base import (
    Distributed,
    DistributedProcessGroup,
    all_gather_object,
    all_reduce,
    barrier,
    broadcast,
    destroy_process_group,
    init_process_group,
    is_initialized,
    new_group,
    use_backend,
)

__all__ = [
    "Distributed",
    "DistributedProcessGroup",
    "all_gather_object",
    "all_reduce",
    "barrier",
    "broadcast",
    "destroy_process_group",
    "init_process_group",
    "is_initialized",
    "new_group",
    "use_backend",
]
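
Taken together, these re-exports give callers a torch.distributed-style facade over whichever backend is active. A minimal usage sketch, not part of this commit, assuming the package root is checkpoint_engine.distributed (as the importlib call in the base module below suggests) and a reachable TCPStore:

# Illustrative usage only, not part of this commit.
import os
from datetime import timedelta

import torch
import checkpoint_engine.distributed as dist

rank = int(os.environ.get("RANK", "0"))
world_size = int(os.environ.get("WORLD_SIZE", "1"))

# Rank 0 hosts the rendezvous store; the other ranks connect to it.
store = torch.distributed.TCPStore(
    "127.0.0.1", 29500, world_size=world_size, is_master=(rank == 0),
    timeout=timedelta(minutes=5),
)
dist.init_process_group(rank, world_size, store, backend="gloo")

t = torch.ones(4)
dist.all_reduce(t)  # op defaults to ReduceOp.SUM
dist.barrier()
dist.destroy_process_group()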
Lines changed: 290 additions & 0 deletions
@@ -0,0 +1,290 @@
import importlib
import io
import pickle
from abc import ABC, abstractmethod
from datetime import timedelta
from typing import Any, Protocol

import torch
import torch.distributed as torch_dist


class CommunicatorProtocol(Protocol):
    def all_gather(self, *args: Any, **kwargs: Any) -> torch.Tensor: ...


class CommGroup:
    def __init__(self, comm_handle: int, ranks: list[int]):
        self._comm = comm_handle
        self._ranks = ranks

    @property
    def handle(self) -> int:
        return self._comm

    @property
    def ranks(self) -> list[int]:
        return self._ranks


DistributedProcessGroup = torch_dist.ProcessGroup | CommGroup
class Distributed(ABC):
    @abstractmethod
    def init_process_group(
        self,
        host: str,
        port: int,
        rank: int,
        world_size: int,
        timeout: timedelta,
        **kwargs,
    ):
        raise NotImplementedError

    @abstractmethod
    def destroy_process_group(
        self,
        group: DistributedProcessGroup | None = None,
    ):
        raise NotImplementedError

    @abstractmethod
    def is_initialized(self) -> bool:
        raise NotImplementedError

    @abstractmethod
    def all_gather_object(
        self,
        object_list: list[Any],
        obj: Any,
        group: DistributedProcessGroup | None = None,
    ):
        raise NotImplementedError

    @abstractmethod
    def all_reduce(
        self,
        tensor: torch.Tensor,
        op: torch_dist.ReduceOp.RedOpType,
        group: DistributedProcessGroup | None = None,
        **kwargs,
    ):
        raise NotImplementedError

    @abstractmethod
    def broadcast(
        self,
        tensor: torch.Tensor,
        src: int,
        group: DistributedProcessGroup | None = None,
        **kwargs,
    ):
        raise NotImplementedError

    @abstractmethod
    def barrier(
        self,
        group: DistributedProcessGroup | None = None,
        **kwargs,
    ):
        raise NotImplementedError

    @abstractmethod
    def new_group(
        self,
        ranks: list[int],
        **kwargs,
    ):
        raise NotImplementedError
class TorchBackend(Distributed):
    def init_process_group(
        self,
        rank: int,
        world_size: int,
        store: torch.distributed.TCPStore,
        **kwargs,
    ):
        backend = kwargs.get("backend", "nccl")
        timeout = kwargs.get("timeout", timedelta(minutes=10))

        torch.distributed.init_process_group(
            backend=backend,
            world_size=world_size,
            rank=rank,
            timeout=timeout,
            store=store,
        )

    def destroy_process_group(self, group: DistributedProcessGroup | None = None):
        torch_dist.destroy_process_group(group)

    def is_initialized(self) -> bool:
        return torch_dist.is_initialized()

    def all_gather_object(
        self, object_list: list[Any], obj: Any, group: DistributedProcessGroup | None = None
    ):
        torch_dist.all_gather_object(object_list, obj, group)

    def all_reduce(
        self,
        tensor: torch.Tensor,
        op: torch_dist.ReduceOp.RedOpType = torch_dist.ReduceOp.SUM,
        group: DistributedProcessGroup | None = None,
        **kwargs,
    ):
        torch_dist.all_reduce(tensor, op, group, **kwargs)

    def broadcast(
        self,
        tensor: torch.Tensor,
        src: int = 0,
        group: DistributedProcessGroup | None = None,
        **kwargs,
    ):
        torch_dist.broadcast(tensor, src, group, **kwargs)

    def barrier(self, group: DistributedProcessGroup | None = None, **kwargs):
        torch_dist.barrier(group, **kwargs)

    def new_group(self, ranks: list[int], **kwargs) -> DistributedProcessGroup | None:
        return torch_dist.new_group(ranks, **kwargs)
# Module-global backend singleton; replaced by use_backend() below.
_BACKEND_INSTANCE: Distributed = TorchBackend()

_pickler = pickle.Pickler
_unpickler = pickle.Unpickler
def _object_to_tensor(obj: Any, device: torch.device) -> tuple[torch.Tensor, torch.Tensor]:
    # Pickle obj into a byte tensor on the target device, returning the
    # payload tensor and a one-element LongTensor holding its byte length.
    f = io.BytesIO()
    _pickler(f).dump(obj)
    byte_storage = torch.ByteStorage._from_buffer(f.getvalue())
    byte_tensor = torch.ByteTensor(byte_storage).to(device)
    local_size = torch.LongTensor([byte_tensor.numel()]).to(device)
    return byte_tensor, local_size


def _tensor_to_object(tensor: torch.Tensor, tensor_size: int) -> Any:
    # Inverse of _object_to_tensor: unpickle the first tensor_size bytes.
    tensor = tensor.cpu()
    buf = tensor.numpy().tobytes()[:tensor_size]
    return _unpickler(io.BytesIO(buf)).load()
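
These two helpers mirror the object serialization that torch.distributed.all_gather_object performs internally. A quick round-trip check, illustrative only:

# Illustrative only, not part of this commit.
payload = {"step": 42, "shards": [0, 1]}
tensor, size = _object_to_tensor(payload, torch.device("cpu"))
assert _tensor_to_object(tensor, int(size.item())) == payload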
def _flatten_for_scatter_gather(
    tensor_list: list[torch.Tensor], copy: bool = False
) -> torch.Tensor:
    if not tensor_list:
        raise RuntimeError("Received an empty list.")
    t = tensor_list[0]
    buffer_shape = [len(tensor_list)] + list(t.shape)

    buffer = torch.empty(tuple(buffer_shape), dtype=t.dtype, device=t.device)
    if copy:
        for i, tensor in enumerate(tensor_list):
            buffer[i].copy_(tensor)
    return buffer
def _common_all_gather_object(
    comm: CommunicatorProtocol,
    device: torch.device,
    world_size: int,
    object_list: list[Any],
    object: Any,
):
    # Two-phase exchange: first all-gather each rank's payload size, then
    # all-gather the pickled payloads, each padded to the maximum size.
    input_tensor, local_size = _object_to_tensor(object, device)
    object_sizes_tensor = torch.empty(world_size, dtype=torch.long, device=device)
    comm.all_gather(object_sizes_tensor, local_size)
    object_size_list = [object_sizes_tensor[i].unsqueeze(dim=0) for i in range(world_size)]
    max_object_size = int(max(object_size_list).item())
    input_tensor.resize_(max_object_size)
    coalesced_output_tensor = torch.empty(
        max_object_size * world_size, dtype=torch.uint8, device=device
    )

    comm.all_gather(coalesced_output_tensor, input_tensor)
    output_tensors = [
        coalesced_output_tensor[max_object_size * i : max_object_size * (i + 1)]
        for i in range(world_size)
    ]
    # Trim each padded slice back to its true size before unpickling.
    for i, tensor in enumerate(output_tensors):
        tensor = tensor.type(torch.uint8)
        tensor_size = object_size_list[i]
        object_list[i] = _tensor_to_object(tensor, tensor_size)
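
Since the helper only depends on CommunicatorProtocol, the two-phase exchange can be exercised without a real collective library. A single-process loopback sketch, illustrative only (with a world size of 1, gathering reduces to copying the input into the output buffer):

# Illustrative only, not part of this commit.
class _LoopbackComm:
    def all_gather(self, output: torch.Tensor, input: torch.Tensor) -> torch.Tensor:
        output.copy_(input.view(-1))  # world_size == 1: gather == copy
        return output

out = [None]
_common_all_gather_object(_LoopbackComm(), torch.device("cpu"), 1, out, {"ok": True})
assert out[0] == {"ok": True}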
def use_backend(backend: str | None):
    global _BACKEND_INSTANCE

    if not backend:
        return

    mapping = {
        "vllm_nccl": ".nccl.DistributedNccl",
        "vllm_hccl": ".hccl.DistributedHccl",
    }
    if backend not in mapping:
        raise ValueError(f"Unsupported custom backend: {backend}")

    module_path, class_name = mapping[backend].rsplit(".", 1)
    module = importlib.import_module(module_path, "checkpoint_engine.distributed")
    backend_class = getattr(module, class_name)
    _BACKEND_INSTANCE = backend_class()
# Module-level API: each call delegates to the currently selected backend.
def init_process_group(
    rank: int,
    world_size: int,
    store: torch.distributed.TCPStore,
    **kwargs,
):
    _BACKEND_INSTANCE.init_process_group(rank, world_size, store, **kwargs)


def destroy_process_group(group: DistributedProcessGroup | None = None):
    _BACKEND_INSTANCE.destroy_process_group(group)


def is_initialized() -> bool:
    return _BACKEND_INSTANCE.is_initialized()


def all_gather_object(
    object_list: list[Any],
    obj: Any,
    group: DistributedProcessGroup | None = None,
):
    _BACKEND_INSTANCE.all_gather_object(object_list, obj, group)


def all_reduce(
    tensor: torch.Tensor,
    op: torch_dist.ReduceOp.RedOpType = torch_dist.ReduceOp.SUM,
    group: DistributedProcessGroup | None = None,
    **kwargs,
):
    _BACKEND_INSTANCE.all_reduce(tensor, op, group, **kwargs)


def broadcast(
    tensor: torch.Tensor,
    src: int = 0,
    group: DistributedProcessGroup | None = None,
    **kwargs,
):
    _BACKEND_INSTANCE.broadcast(tensor, src, group, **kwargs)


def barrier(group: DistributedProcessGroup | None = None, **kwargs):
    _BACKEND_INSTANCE.barrier(group, **kwargs)


def new_group(ranks: list[int], **kwargs) -> DistributedProcessGroup | None:
    return _BACKEND_INSTANCE.new_group(ranks, **kwargs)
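
Backend selection is process-global: use_backend swaps the module-level singleton, so it should be called before init_process_group. Illustrative only; the vllm_nccl mapping imports checkpoint_engine.distributed.nccl.DistributedNccl, an optional module not shown in this excerpt:

# Illustrative only, not part of this commit.
import checkpoint_engine.distributed as dist

dist.use_backend("vllm_nccl")  # swaps the module-level singleton via importlib
# Subsequent calls (init_process_group, all_reduce, ...) now route through
# the NCCL-based backend instead of the default TorchBackend.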
