
Commit 907611f

SW publisher authored and Jenkins committed
deepspeed-fork content for 1.18.0
Signed-off-by: SW publisher <sw_publisher@habana-labs.com>
1 parent d254d75 commit 907611f

156 files changed: 12,028 additions & 768 deletions


.pre-commit-config.yaml

Lines changed: 0 additions & 89 deletions
This file was deleted.

accelerator/hpu_accelerator.py

Lines changed: 35 additions & 11 deletions
@@ -3,6 +3,7 @@

 # DeepSpeed Team

+import functools
 import os
 import pkgutil
 import importlib
@@ -17,6 +18,7 @@ def __init__(self):
         self._name = 'hpu'
         self._communication_backend_name = 'hccl'
         self._compile_backend = "hpu_backend"
+        self.apply_hpu_workarounds()
         try:
             import habana_frameworks.torch.hpu as hpu
             hpu.setDeterministic(True)
@@ -27,6 +29,15 @@ def __init__(self):

         self.fp16_supported = None

+    def apply_hpu_workarounds(self):
+
+        def update_wa_env_var(key, value):
+            if key not in os.environ.keys():
+                os.environ[key] = value
+
+        update_wa_env_var("PT_HPU_LAZY_ACC_PAR_MODE", "0")
+        update_wa_env_var("PT_HPU_ENABLE_REFINE_DYNAMIC_SHAPES", "0")
+
     # Device APIs
     def is_synchronized_device(self):
         return False
@@ -41,9 +52,8 @@ def handles_memory_backpressure(self):
         return True

     def device_name(self, device_index=None):
-        if device_index is None:
-            return 'hpu'
-        return 'hpu:{}'.format(device_index)
+        # ignoring device_index.
+        return 'hpu'

     def device(self, device_index=None):
         return torch.device(self.device_name(device_index))
@@ -194,33 +204,34 @@ def replay_graph(self, graph):
         return

     # Tensor operations
+    # TODO(SW-192865): Remove WA for tensor wrappers.
     @property
     def BFloat16Tensor(self):
-        return self.hpu.BFloat16Tensor
+        return functools.partial(torch.tensor, dtype=torch.bfloat16, device='hpu')

     @property
     def ByteTensor(self):
-        return self.hpu.ByteTensor
+        return functools.partial(torch.tensor, dtype=torch.uint8, device='hpu')

     @property
     def DoubleTensor(self):
-        return self.hpu.DoubleTensor
+        return functools.partial(torch.tensor, dtype=torch.double, device='hpu')

     @property
     def FloatTensor(self):
-        return self.hpu.FloatTensor
+        return functools.partial(torch.tensor, dtype=torch.float, device='hpu')

     @property
     def HalfTensor(self):
-        return self.hpu.HalfTensor
+        return functools.partial(torch.tensor, dtype=torch.half, device='hpu')

     @property
     def IntTensor(self):
-        return self.hpu.IntTensor
+        return functools.partial(torch.tensor, dtype=torch.int, device='hpu')

     @property
     def LongTensor(self):
-        return self.hpu.LongTensor
+        return functools.partial(torch.tensor, dtype=torch.long, device='hpu')

     def pin_memory(self, tensor, align_bytes=1):
         return tensor.pin_memory(self.device())
@@ -289,6 +300,14 @@ def get_op_builder(self, class_name):
         else:
             return self.class_dict['NotImplementedBuilder'] if 'NotImplementedBuilder' in self.class_dict else None

+    # shall be removed once moving to torch.compile
+    def wrap_in_hpu_graph(self, module):
+        if self.hpu.is_lazy():
+            module = self.hpu.wrap_in_hpu_graph(module)
+        else:
+            print("Warning: hpu graphs in eager mode is not supported, ignoring")
+        return module
+
     def build_extension(self):
         from torch.utils.cpp_extension import BuildExtension
         return BuildExtension
@@ -297,7 +316,12 @@ def export_envs(self):
         return []

     def visible_devices_envs(self):
-        return ['HABANA_VISIBLE_MODULES']
+        # TODO SW-195658: remove WA to not return HABANA_VISIBLE_MODULES once SW-195657 is resolved
+        # Current way deepspeed set this env var is not applicable with all HPU instances
+        # User has to follow instructions in:
+        # https://docs.habana.ai/en/latest/PyTorch/Reference/PT_Multiple_Tenants_on_HPU/Multiple_Workloads_Single_Docker.html
+        # keeping CUDA_VISIBLE_DEVICES
+        return ['CUDA_VISIBLE_DEVICES']  # ['HABANA_VISIBLE_MODULES']

     def set_visible_devices_envs(self, current_env, local_accelerator_ids):
         for env in self.visible_devices_envs():
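
The tensor-type properties above replace the legacy self.hpu.BFloat16Tensor-style constructors with functools.partial bindings of torch.tensor. A minimal sketch of the pattern, with device='cpu' substituted so it runs without Habana hardware (that substitution is the only assumption here):

    import functools
    import torch

    # Bind dtype/device once; the result is a constructor-like callable,
    # mirroring the legacy BFloat16Tensor API that the WA replaces.
    BFloat16Tensor = functools.partial(torch.tensor, dtype=torch.bfloat16, device='cpu')

    t = BFloat16Tensor([1.0, 2.0, 3.0])
    assert t.dtype == torch.bfloat16 and t.shape == (3,)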

build.txt

Lines changed: 1 addition & 0 deletions
@@ -0,0 +1 @@
+hpu.synapse.v1.18.0

csrc/transformer/inference/csrc/pt_binding.cpp

Lines changed: 5 additions & 4 deletions
@@ -452,15 +452,16 @@ std::vector<at::Tensor> ds_softmax_context(at::Tensor& query_key_value,
                                            unsigned layer_id,
                                            unsigned num_layers,
                                            at::Tensor& alibi,
-                                           float rope_theta)
+                                           float rope_theta,
+                                           bool is_prompt,
+                                           std::optional<at::Tensor> token_idx,
+                                           std::optional<at::Tensor> position_ids)
 {
     unsigned bsz = query_key_value.size(0);
     unsigned seq_len = query_key_value.size(1);
     int k = query_key_value.size(2) / (heads + 2 * (num_kv > 0 ? num_kv : heads));
     unsigned hidden_dim = heads * k;

-    bool is_prompt = (seq_len > 1);
-
     if (is_prompt) InferenceContext::Instance().reset_tokens(seq_len);
     unsigned soft_len = InferenceContext::Instance().current_tokens();

@@ -2028,7 +2029,7 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
         "DeepSpeed memory allocation for GPT inference with " #_name " (CUDA)"); \
     m.def("dequantize_" #_name, \
           &ds_dequantize<_dtype>, \
-          "DeepSpeed dequantize with " #_name " (CUDA)")
+          "DeepSpeed dequantize with " #_name " (CUDA)");

     DEF_OPS(fp32, float);
     DEF_OPS(fp16, __half);
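
Note the semantic change in ds_softmax_context: is_prompt is no longer inferred from seq_len inside the binding; the caller now supplies it, along with optional token_idx and position_ids. A hedged Python-side sketch of where that logic moves (the wrapper name and call shape are illustrative, not the actual DeepSpeed call site):

    # Illustrative only: the old `is_prompt = (seq_len > 1)` check now lives
    # with the caller, which may instead pass decode-phase metadata directly.
    def call_softmax_context(op, query_key_value, token_idx=None, position_ids=None, **kwargs):
        seq_len = query_key_value.size(1)
        is_prompt = seq_len > 1  # previously computed inside the C++ binding
        return op(query_key_value, **kwargs,
                  is_prompt=is_prompt, token_idx=token_idx, position_ids=position_ids)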

deepspeed/checkpoint/zero_checkpoint.py

Lines changed: 1 addition & 1 deletion
@@ -105,7 +105,7 @@ def _strip_tensor_paddings(self, sd):
             if group_paddings[key] == 0:
                 continue
             for state_name, state_value in group_state.items():
-                if state_name != "step" and torch.is_tensor(state_value):
+                if state_name != "step" and torch.is_tensor(state_value) and state_value.dim():
                     raw_length = state_value.numel() - group_paddings[key]
                     group_state[state_name] = torch.narrow(state_value, 0, 0, raw_length).clone()
                 else:
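
The added state_value.dim() guard skips 0-dim (scalar) optimizer state, which torch.narrow cannot slice. A small sketch of the failure mode the guard avoids:

    import torch

    vector = torch.arange(10)   # 1-D: vector.dim() == 1, narrow works
    scalar = torch.tensor(3.)   # 0-D: scalar.dim() == 0

    trimmed = torch.narrow(vector, 0, 0, 8)  # keep the first 8 elements
    assert trimmed.numel() == 8

    # torch.narrow(scalar, 0, 0, 1) raises "narrow() cannot be applied to
    # a 0-dim tensor", hence the state_value.dim() check above.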

deepspeed/comm/torch.py

Lines changed: 0 additions & 21 deletions
@@ -146,18 +146,15 @@ def init_process_group(self, backend, timeout, init_method, rank, world_size):
                                                  world_size=world_size)
         self.using_mpi = torch.distributed.get_backend() == 'mpi'

-    @compiler.disable
     def all_reduce(self, tensor, op=torch.distributed.ReduceOp.SUM, group=None, async_op=False):
         op = self._reduce_op(op)
         return torch.distributed.all_reduce(tensor=tensor, op=op, group=group, async_op=async_op)

-    @compiler.disable
     def inference_all_reduce(self, tensor, op, group=None):
         if self.shm_comm_op == None or self.shm_comm_op.inference_all_reduce(tensor, op) == -1:
             op = self._reduce_op(op)
             return torch.distributed.all_reduce(tensor=tensor, op=op, group=group, async_op=False)

-    @compiler.disable
     def all_reduce_coalesced(self, tensors, op=torch.distributed.ReduceOp.SUM, group=None, async_op=False):
         """ proxy func to torch.distributed.all_reduce_coalesced,
         which is included in PyTorch 1.13 and above
@@ -168,15 +165,13 @@ def all_reduce_coalesced(self, tensors, op=torch.distributed.ReduceOp.SUM, group
         op = self._reduce_op(op)
         return torch.distributed.all_reduce_coalesced(tensors=tensors, op=op, group=group, async_op=async_op)

-    @compiler.disable
     def reduce(self, tensor, dst, op=ReduceOp.SUM, group=None, async_op=False):
         if DS_COMM_REDUCE_OFF:
             if int(os.getenv('RANK', '0')) == 0:
                 utils.logger.warning("REDUCE is OFF")
             return Noop()
         return torch.distributed.reduce(tensor=tensor, dst=dst, op=self._reduce_op(op), group=group, async_op=async_op)

-    @compiler.disable
     def reduce_scatter(self, output, input_list, op=ReduceOp.SUM, group=None, async_op=False):
         if DS_COMM_REDUCE_SCATTER_OFF:
             if int(os.getenv('RANK', '0')) == 0:
@@ -189,7 +184,6 @@ def reduce_scatter(self, output, input_list, op=ReduceOp.SUM, group=None, async_
                                                     group=group,
                                                     async_op=async_op)

-    @compiler.disable
     def broadcast(self, tensor, src, group=None, async_op=False):
         if DS_COMM_BROADCAST_OFF:
             if int(os.getenv('RANK', '0')) == 0:
@@ -198,7 +192,6 @@ def broadcast(self, tensor, src, group=None, async_op=False):
         else:
             return torch.distributed.broadcast(tensor=tensor, src=src, group=group, async_op=async_op)

-    @compiler.disable
     def all_gather(self, tensor_list, tensor, group=None, async_op=False):
         if DS_COMM_ALL_GATHER_OFF:
             if int(os.getenv('RANK', '0')) == 0:
@@ -207,15 +200,13 @@ def all_gather(self, tensor_list, tensor, group=None, async_op=False):
         else:
             return torch.distributed.all_gather(tensor_list=tensor_list, tensor=tensor, group=group, async_op=async_op)

-    @compiler.disable
     def all_gather_into_tensor(self, output_tensor, input_tensor, group=None, async_op=False):
         if self.has_all_gather_into_tensor():
             return self.all_gather_function(output_tensor=output_tensor,
                                             input_tensor=input_tensor,
                                             group=group,
                                             async_op=async_op)

-    @compiler.disable
     def all_gather_base(self, output_tensor, input_tensor, group=None, async_op=False):
         if DS_COMM_ALL_GATHER_OFF:
             if int(os.getenv('RANK', '0')) == 0:
@@ -233,7 +224,6 @@ def all_gather_base(self, output_tensor, input_tensor, group=None, async_op=Fals
                     "please consider upgrading your pytorch installation.")
             pass

-    @compiler.disable
     def all_gather_coalesced(self, output_tensors, input_tensors, group=None, async_op=False):
         """"""
         assert len(output_tensors) == len(input_tensors), ""
@@ -257,7 +247,6 @@ def all_gather_coalesced(self, output_tensors, input_tensors, group=None, async_
         else:
             reqs[-1].wait()

-    @compiler.disable
     def reduce_scatter_tensor(self, output_tensor, input_tensor, op=ReduceOp.SUM, group=None, async_op=False):
         if self.has_reduce_scatter_tensor():
             return self.reduce_scatter_function(output_tensor,
@@ -271,7 +260,6 @@ def reduce_scatter_tensor(self, output_tensor, input_tensor, op=ReduceOp.SUM, gr
                     "please consider upgrading your pytorch installation.")
             pass

-    @compiler.disable
     def all_to_all_single(self,
                           output,
                           input,
@@ -286,49 +274,40 @@ def all_to_all_single(self,
                                                 group=group,
                                                 async_op=async_op)

-    @compiler.disable
     def all_to_all(self, output_tensor_list, input_tensor_list, group=None, async_op=False):
         return torch.distributed.all_to_all(output_tensor_list, input_tensor_list, group=group, async_op=async_op)

-    @compiler.disable
     def send(self, tensor, dst, group=None, tag=0):
         return torch.distributed.send(tensor=tensor, dst=dst, group=group, tag=tag)

-    @compiler.disable
     def recv(self, tensor, src=None, group=None, tag=0):
         return torch.distributed.recv(tensor=tensor, src=src, group=group, tag=tag)

-    @compiler.disable
     def isend(self, tensor, dst, group=None, tag=0):
         return torch.distributed.isend(tensor=tensor, dst=dst, group=group, tag=tag)

-    @compiler.disable
     def irecv(self, tensor, src=None, group=None, tag=0):
         return torch.distributed.irecv(tensor=tensor, src=src, group=group, tag=tag)

-    @compiler.disable
     def gather(self, tensor, gather_list=None, dst=0, group=None, async_op=False):
         return torch.distributed.gather(tensor=tensor,
                                         gather_list=gather_list,
                                         dst=dst,
                                         group=group,
                                         async_op=async_op)

-    @compiler.disable
     def scatter(self, tensor, scatter_list=None, src=0, group=None, async_op=False):
         return torch.distributed.scatter(tensor=tensor,
                                          scatter_list=scatter_list,
                                          src=src,
                                          group=group,
                                          async_op=async_op)

-    @compiler.disable
     def barrier(self, group=torch.distributed.GroupMember.WORLD, async_op=False, device_ids=None):
         if group is None:
             group = torch.distributed.GroupMember.WORLD
         return torch.distributed.barrier(group=group, async_op=async_op, device_ids=device_ids)

-    @compiler.disable
     def monitored_barrier(self, group=torch.distributed.GroupMember.WORLD, timeout=None, wait_all_ranks=False):
         if group is None:
             group = torch.distributed.GroupMember.WORLD
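
All of the removed decorators come from DeepSpeed's compiler shim, which (when the running PyTorch supports it) applies torch.compiler.disable so the wrapped method is excluded from torch.compile graph capture. Dropping them lets the compile backend trace into these collectives instead of graph-breaking at every call. A minimal sketch of the decorator's effect, assuming PyTorch >= 2.1 for the torch.compiler namespace:

    import torch

    @torch.compiler.disable
    def eager_only(x):
        return x + 1  # always runs eagerly, outside any captured graph

    @torch.compile
    def fn(x):
        return eager_only(x) * 2  # torch.compile graph-breaks around eager_only

    print(fn(torch.ones(2)))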
