Skip to content

Commit 5336436

Browse files
authored
fix python api compat issues (#9)
1 parent dedaff1 commit 5336436

2 files changed

Lines changed: 17 additions & 4 deletions

File tree

flashinfer/comm/cuda_ipc.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -203,7 +203,7 @@ def create_shared_buffer(
203203
pointer = cudart.cudaMalloc(size_in_bytes)
204204
handle = cudart.cudaIpcGetMemHandle(pointer)
205205
if group is None:
206-
group = dist.group.WORLD
206+
group = dist.get_group()
207207
# world_size = dist.get_world_size(group=group)
208208
rank = dist.get_rank(group=group)
209209
# handles = [None] * world_size

flashinfer/comm/trtllm_ar.py

Lines changed: 16 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -767,9 +767,22 @@ def trtllm_custom_all_reduce(
767767
def _should_use_oneshot(
    token_num: int, hidden_dim: int, dtype: torch.dtype, world_size: int
) -> bool:
    """Heuristic: should the one-shot all-reduce kernel be used?

    Estimates the communication volume in MiB for this all-reduce
    (token_num * hidden_dim elements, doubled, across world_size ranks)
    and compares it against the per-world-size threshold table.

    Args:
        token_num: Number of tokens in the batch.
        hidden_dim: Hidden dimension of each token.
        dtype: Element dtype of the tensor being reduced.
        world_size: Number of participating ranks.

    Returns:
        True when the estimated communication size is at or below the
        threshold in ``_use_oneshot_heuristics`` for this world size.

    Raises:
        KeyError: If ``world_size`` has no entry in ``_use_oneshot_heuristics``.
    """
    # torch.dtype.itemsize is unavailable on older PyTorch versions, so the
    # common dtypes are mapped to their byte sizes explicitly for compatibility.
    DTYPE_SIZE_MAP = {
        torch.float16: 2,
        torch.bfloat16: 2,
        torch.float32: 4,
        torch.float64: 8,
        torch.int8: 1,
        torch.int16: 2,
        torch.int32: 4,
        torch.int64: 8,
        torch.uint8: 1,
        torch.bool: 1,
        torch.complex64: 8,
        torch.complex128: 16,
    }
    itemsize = DTYPE_SIZE_MAP.get(dtype)
    if itemsize is None:
        # Dtype not in the compat map (e.g. float8 variants, newer uint types):
        # measure the element size directly instead of raising KeyError.
        # element_size() is available on all PyTorch versions.
        itemsize = torch.tensor([], dtype=dtype).element_size()
    comm_size_mb = token_num * hidden_dim * 2 * world_size * itemsize / 1024 / 1024
    return comm_size_mb <= _use_oneshot_heuristics[world_size]
774787

775788

0 commit comments

Comments (0)