Skip to content

Commit 1e3a836

Browse files
Add AutoSP to DeepSpeed
--------- Signed-off-by: Neel Dani <neeldani98@gmail.com> Co-authored-by: Ahan Gupta <ahangupta.96@gmail.com>
1 parent f88d0f8 commit 1e3a836

13 files changed

Lines changed: 850 additions & 44 deletions

File tree

deepspeed/compile/config.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,11 @@
33

44
# DeepSpeed Team
55

6+
from typing import List, Optional, Literal
67
from deepspeed.runtime.config_utils import DeepSpeedConfigModel
78

9+
# Allowed entries for the CompileConfig.passes list below.
PassName = Literal["z1", "z3", "autosp"]
10+
811

912
class CompileConfig(DeepSpeedConfigModel):
1013
""" Configure compile settings """
@@ -53,3 +56,6 @@ class CompileConfig(DeepSpeedConfigModel):
5356

5457
keep_all_input_tensors: bool = False
5558
""" Keep real values for all input tensors in InputStorage instead of using dummy values """
59+
60+
passes: Optional[List[PassName]] = None
61+
""" Composes different optimizations. """

deepspeed/compile/constants.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

#########################################
# AUTOSP
#########################################
# Tag strings for marking the input, label and position-id tensors so the
# AutoSP graph passes can locate them — presumably matched against the
# 'tensor_dict' tag metadata on FX nodes; confirm against the pass code.
AUTOSP_INPUT_ID_KEY = "input_id"
AUTOSP_LABEL_ID_KEY = "label_id"
AUTOSP_POSITION_ID_KEY = "position_id"
Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
# Copyright (c) Microsoft Corporation.
2+
# SPDX-License-Identifier: Apache-2.0
3+
4+
# DeepSpeed Team
5+
6+
from .all_to_all import all_to_all
7+
from . import sp_dp_registry
8+
9+
__all__ = ["all_to_all", "sp_dp_registry"]
Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,75 @@
1+
# Copyright (c) Microsoft Corporation.
2+
# SPDX-License-Identifier: Apache-2.0
3+
4+
# DeepSpeed Team
5+
6+
import torch
7+
import deepspeed.comm as dist
8+
from .sp_dp_registry import get_group, is_setup, sp_size
9+
10+
11+
@torch.library.custom_op("autosp::all_to_all", mutates_args=())
def all_to_all(
    input: torch.Tensor,
    scatter_idx: int,
    gather_idx: int,
    name: str,
) -> torch.Tensor:
    """
    All-to-all collective for SDPA tensors [B, N, S, H].

    For QKV (scatter_idx=1, gather_idx=2):
        [B, N, S/P, H] -> [B, N/P, S, H]
    For O (scatter_idx=2, gather_idx=1):
        [B, N/P, S, H] -> [B, N, S/P, H]
    """
    # NOTE(review): only scatter_idx selects a branch; gather_idx is part of
    # the op schema and is swapped with scatter_idx in the backward setup.
    assert is_setup(), 'Incorrect initialization of SP/DP mesh.'
    B, dim1, dim2, H = input.shape
    # SP groups cover contiguous rank blocks (see populate_registry), so
    # integer division by the SP degree recovers this rank's group id.
    gid = dist.get_rank() // sp_size()
    group = get_group(gid)

    if scatter_idx == 1:
        # QKV direction: scatter the head dim, gather the sequence dim.
        N, local_S = dim1, dim2
        # Split heads into (P, N/P) and move the P axis to the front so the
        # flat all_to_all_single exchanges one head-shard per peer.
        input_t = input.reshape(B, sp_size(), N // sp_size(), local_S, H)
        input_t = input_t.permute(1, 0, 2, 3, 4).contiguous()

        output = torch.empty_like(input_t)
        dist.all_to_all_single(output, input_t, group=group)

        # Axis 0 of the result indexes the sequence shard received from each
        # peer; move it next to local_S and merge into the full sequence.
        output = output.permute(1, 2, 0, 3, 4).contiguous()
        output = output.reshape(B, N // sp_size(), sp_size() * local_S, H)
    else:
        # O direction: inverse exchange — scatter sequence, gather heads.
        local_N, S = dim1, dim2
        input_t = input.reshape(B, local_N, sp_size(), S // sp_size(), H)
        input_t = input_t.permute(2, 0, 1, 3, 4).contiguous()

        output = torch.empty_like(input_t)
        dist.all_to_all_single(output, input_t, group=group)

        # Axis 0 now indexes the head shard received from each peer; merge it
        # back into a full head dimension.
        output = output.permute(1, 0, 2, 3, 4).contiguous()
        output = output.reshape(B, sp_size() * local_N, S // sp_size(), H)

    return output
53+
54+
55+
@torch.library.register_fake("autosp::all_to_all")
def all_to_all_fake(input: torch.Tensor, scatter_idx: int, gather_idx: int, name: str):
    """Meta/fake kernel: empty tensor with the post-collective shape.

    scatter_idx == 1 shrinks dim 1 (heads) and grows dim 2 (sequence) by the
    SP degree; any other scatter_idx does the inverse.
    """
    B, dim1, dim2, H = input.shape
    degree = sp_size()
    if scatter_idx == 1:
        out_shape = (B, dim1 // degree, dim2 * degree, H)
    else:
        out_shape = (B, dim1 * degree, dim2 // degree, H)
    return input.new_empty(*out_shape)
62+
63+
64+
def _all_to_all_backward_setup(ctx, inputs, output):
    """Stash the inverse collective's arguments on ctx for backward.

    The gradient of an all-to-all is the all-to-all with scatter and gather
    axes exchanged, so the two indices are stored swapped.
    """
    scatter_idx, gather_idx, name = inputs[1], inputs[2], inputs[3]
    ctx.scatter_idx, ctx.gather_idx = gather_idx, scatter_idx
    ctx.name = f"{name}_grad"
69+
70+
71+
def _all_to_all_backward(ctx, grad):
    # Backward is the inverse all-to-all: setup_context stored the swapped
    # scatter/gather indices on ctx. The trailing Nones are the gradients for
    # the non-differentiable int/str args (scatter_idx, gather_idx, name).
    return (all_to_all(grad, ctx.scatter_idx, ctx.gather_idx, ctx.name), None, None, None)


# Attach the backward to the custom op so autograd flows through it.
torch.library.register_autograd("autosp::all_to_all", _all_to_all_backward, setup_context=_all_to_all_backward_setup)
Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,67 @@
1+
# Copyright (c) Microsoft Corporation.
2+
# SPDX-License-Identifier: Apache-2.0
3+
4+
# DeepSpeed Team
5+
6+
import deepspeed.comm as dist
7+
8+
# int group-id -> dist.ProcessGroup, plus the bookkeeping keys
# 'SP_SIZE', 'DP_SIZE' and 'is_reg' written by populate_registry.
GROUP_REGISTRY = {}


def register_groups(groups):
    """Create one process group per rank list (List[List[int]], e.g. [[0,1],[2,3]]).

    Idempotent per group id: ids already present in GROUP_REGISTRY are skipped.
    """
    for gid, ranks in enumerate(groups):
        if gid in GROUP_REGISTRY:
            continue  # group already created for this id
        GROUP_REGISTRY[gid] = dist.new_group(ranks)
16+
17+
18+
def get_group(gid: int):
    """Look up the process group for *gid*; a None gid maps to the world group."""
    if gid is None:
        return dist.get_world_group()
    return GROUP_REGISTRY[gid]
20+
21+
22+
def get_registry():
    """Return the module-level GROUP_REGISTRY dict (shared, mutable)."""
    return GROUP_REGISTRY
24+
25+
26+
def is_setup():
    """Whether the SP/DP mesh has been registered (populate_registry completed).

    Uses dict.get with a default instead of the original membership test plus
    subscript — one lookup, same result.
    """
    return GROUP_REGISTRY.get('is_reg', False)
28+
29+
30+
def extract_mesh_size(param_dict):
    """Derive (sp_size, dp_size) from the config dict; world size = sp * dp.

    'sequence_parallel_size' defaults to 1 (pure data parallelism).
    """
    sp = param_dict.get('sequence_parallel_size', 1)
    world = dist.get_world_size()
    assert world % sp == 0, 'World mesh-size should be divisible by SP_SIZE'
    return sp, world // sp
36+
37+
38+
def sp_size():
    """Return the sequence-parallel degree recorded by populate_registry."""
    assert 'SP_SIZE' in GROUP_REGISTRY, 'SP_SIZE not init properly.'

    return GROUP_REGISTRY['SP_SIZE']
42+
43+
44+
def dp_size():
    """Return the data-parallel degree recorded by populate_registry."""
    assert 'DP_SIZE' in GROUP_REGISTRY, 'DP_SIZE not init properly'

    return GROUP_REGISTRY['DP_SIZE']
48+
49+
50+
def populate_registry(SP_SIZE, DP_SIZE):
    """ Populate rank to SP/DP mesh index. """

    # Idempotent: nothing to do once the mesh has been registered.
    if GROUP_REGISTRY.get('is_reg', False):
        return

    # Contiguous rank blocks: DP group g owns ranks [g*SP_SIZE, (g+1)*SP_SIZE).
    group_listing = [list(range(g * SP_SIZE, (g + 1) * SP_SIZE)) for g in range(DP_SIZE)]
    register_groups(group_listing)

    # Bookkeeping required for proper instantiation elsewhere in this module.
    GROUP_REGISTRY['SP_SIZE'] = SP_SIZE
    GROUP_REGISTRY['DP_SIZE'] = DP_SIZE
    GROUP_REGISTRY['is_reg'] = True

deepspeed/compile/fx.py

Lines changed: 31 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,11 +3,11 @@
33

44
# DeepSpeed Team
55

6-
from typing import Callable, Any, List, Dict
6+
from typing import Callable, Any, List, Dict, Optional
77
from collections import defaultdict
88

99
import torch
10-
from torch.fx import Node, Graph
10+
from torch.fx import Node, Graph, GraphModule
1111

1212
from .util import get_last_uses
1313

@@ -138,3 +138,32 @@ def free_tensors(tensors: List[torch.Tensor]):
138138

139139
# Python version for debugging
140140
# graph.create_node('call_function', free_tensors, args, {}, name=node_name)
141+
142+
143+
def find_node_by_name(gm: GraphModule, name: str) -> Optional[Node]:
    """Return the first node in *gm* whose name equals *name*, or None."""
    return next((candidate for candidate in gm.graph.nodes if candidate.name == name), None)
148+
149+
150+
def get_node_shape_meta(node: Node) -> Optional[torch.Tensor]:
    """Return the tensor metadata for *node*: meta['val'], falling back to
    meta['example_value'], or None if neither is present.

    Fix: the original `meta.get("val") or meta.get("example_value")` evaluated
    bool() on the value — for a (fake) tensor with more than one element that
    raises RuntimeError, and a falsy scalar (e.g. tensor(0)) was wrongly
    skipped. Explicit None checks avoid tensor truthiness entirely.
    """
    val = node.meta.get("val")
    if val is not None:
        return val
    return node.meta.get("example_value")
152+
153+
154+
def find_node_by_tag(gm: GraphModule, tag: str) -> Optional[Node]:
    """Return the first node whose Dynamo 'tensor_dict' metadata carries *tag*, else None."""
    for node in gm.graph.nodes:
        # Dynamo records user tags on node.meta['tensor_dict']; see
        # https://github.com/pytorch/pytorch/blob/085b71eab05cbc7d474a173884269c62d2778f77/torch/_dynamo/utils.py#L5048
        tensor_meta = node.meta.get('tensor_dict')
        if tensor_meta and tensor_meta.get('tag') == tag:
            return node
    return None
163+
164+
165+
def replace_node_users(node: Node, replacement: Node, exclude: Optional[List[Node]] = None):
    """Rewire every user of *node* to consume *replacement* instead.

    Users listed in *exclude* are left untouched. The user list is snapshotted
    before mutation because replace_input_with updates node.users as it goes.
    """
    excluded = [] if exclude is None else exclude
    pending = [user for user in node.users if user not in excluded]
    for user in pending:
        user.replace_input_with(node, replacement)

deepspeed/compile/init_sp.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
# Copyright (c) Microsoft Corporation.
2+
# SPDX-License-Identifier: Apache-2.0
3+
4+
# DeepSpeed Team
5+
6+
import torch
7+
from torch.fx import GraphModule
8+
from .passes.sp_compile import apply_autosp
9+
from .passes.long_context_checkpointing import register_long_context_checkpointing
10+
from .custom_ops.sp_dp_registry import extract_mesh_size
11+
12+
13+
def init_autosp(config):
    """Set up AutoSP and return a compile backend function that applies it.

    Reads the SP/DP mesh sizes from ``config._param_dict`` and installs the
    sequence-aware activation-checkpointing patch once, up front. The
    returned callable takes (GraphModule, example inputs) — the shape of a
    torch.compile backend.
    """
    sp_size, dp_size = extract_mesh_size(config._param_dict)
    register_long_context_checkpointing()

    def backend_fn(gm: GraphModule, real_inputs):
        # apply_autosp's return value is ignored — presumably it rewrites gm
        # in place (TODO confirm) — then the graph is lowered with Inductor.
        apply_autosp(gm, real_inputs, debug=False, sp_size=sp_size, dp_size=dp_size)
        return torch._inductor.compile(gm, real_inputs)

    return backend_fn
Lines changed: 109 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,109 @@
1+
# Copyright (c) Microsoft Corporation.
2+
# SPDX-License-Identifier: Apache-2.0
3+
4+
# DeepSpeed Team
5+
6+
import inspect
7+
import textwrap
8+
import torch._functorch.partitioners as _partitioners
9+
10+
# The custom should_ban_recomputation body to splice into solve_min_cut.
# All names it references (aten, operator, config, op_types, min_cut_options,
# is_materialized_backwards, get_aten_target, _size_of, fx, torch,
# CheckpointPolicy) are either module-level in torch._functorch.partitioners
# or local variables already in scope when this function executes inside
# solve_min_cut.
# NOTE: this string is exec'd as code — it must remain valid Python at a
# flush-left indent (register_long_context_checkpointing re-indents it).
_CUSTOM_SHOULD_BAN = """\
def should_ban_recomputation(node):
    \"\"\"Sequence-aware recomputation banning logic\"\"\"
    if node.op != "call_function":
        return False
    if node.target == operator.getitem:
        return False
    if node.meta.get("recompute", None) == CheckpointPolicy.MUST_SAVE:
        return True
    if config.recompute_views and op_types.is_view(node):
        return False
    if node.target in [aten.lift_fresh_copy.default, aten.lift_fresh.default]:
        return False

    must_save_set = [
        aten.convolution,
        aten.convolution_backward,
        aten._scaled_dot_product_flash_attention,
        aten._scaled_dot_product_efficient_attention,
        aten._flash_attention_forward,
        aten._efficient_attention_forward,
        aten.upsample_bilinear2d,
        aten.native_dropout,
        aten.rand_like,
        aten.randn_like,
    ]

    if get_aten_target(node) in must_save_set:
        return True

    def heuristic(node):
        if "val" in node.meta:
            if isinstance(node.meta["val"], torch.Tensor) and node.meta["val"].dim() >= 2:
                return node.meta["val"].shape[1] >= 4096
        return False

    if min_cut_options.ban_if_not_in_allowlist:
        if not op_types.is_recomputable(node):
            return False

    if min_cut_options.ban_if_materialized_backward and is_materialized_backwards(node):
        if heuristic(node):
            return False
        return True

    if node.dist_from_bw < 1000 and node.dist_from_bw > config.max_dist_from_bw:
        return False

    if min_cut_options.ban_if_reduction:
        input_tensors_size = sum(
            _size_of(i) for i in node.args if isinstance(i, fx.Node)
        )
        output_size = _size_of(node)
        return output_size * 4 < input_tensors_size
    return False
"""
72+
73+
74+
def register_long_context_checkpointing():
    """Splice the custom should_ban_recomputation into solve_min_cut.

    Uses inspect.getsource to extract solve_min_cut's source, replaces the
    original should_ban_recomputation with _CUSTOM_SHOULD_BAN, then execs the
    result directly in _partitioners.__dict__.

    The exec'd function's __globals__ is the real partitioners module dict, so
    every other nested function (is_fusible, is_materialized_backwards,
    can_fuse_into_*, etc.) and every local/closure variable (op_types,
    min_cut_options, node_info, config, …) is exactly as in the original —
    nothing else changes.

    Backward compatible: if solve_min_cut gains new heuristics in a future
    PyTorch version the exec automatically picks them up; only
    _CUSTOM_SHOULD_BAN needs to stay in sync with any changes to the
    original should_ban_recomputation signature/contract.
    """
    src = inspect.getsource(_partitioners.solve_min_cut)
    lines = src.split('\n')

    # Locate the original should_ban_recomputation and the function after it.
    # Nested defs inside solve_min_cut sit at a 4-space indent in the source.
    start = next(
        i for i, l in enumerate(lines)
        if l.startswith('    def should_ban_recomputation(')
    )
    end = next(
        i for i, l in enumerate(lines)
        if i > start and l.startswith('    def ')
    )

    # Indent the replacement to the nesting level inside solve_min_cut (4 spaces).
    replacement = textwrap.indent(_CUSTOM_SHOULD_BAN, '    ')

    # _CUSTOM_SHOULD_BAN ends with a newline, so joining lines[end:] directly
    # after it keeps the following nested def on its own line.
    new_src = '\n'.join(lines[:start]) + '\n' + replacement + '\n'.join(lines[end:])
    exec(new_src, _partitioners.__dict__)  # redefines _partitioners.solve_min_cut

0 commit comments

Comments
 (0)