# base_layer_weight.py
import threading

import torch

from lightllm.common.basemodel.layer_weights.meta_weights import BaseWeight
from lightllm.utils.dist_utils import get_current_device_id, get_current_rank_in_dp, get_dp_world_size


class BaseLayerWeight:
    def __init__(self):
        self.tp_rank_ = get_current_rank_in_dp()
        self.tp_world_size_ = get_dp_world_size()
        self.lock = threading.Lock()

    def load_hf_weights(self, weights):
        """
        Load weights from a HuggingFace-style state dict, delegating to every
        BaseWeight attribute found on this object.
        """
        for attr_name in dir(self):
            attr = getattr(self, attr_name, None)
            if isinstance(attr, BaseWeight):
                attr.load_hf_weights(weights)
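
    # Usage sketch (assumption: the key name below is illustrative, following
    # common HuggingFace checkpoint naming, not taken from this repository):
    #   layer_weight.load_hf_weights({"model.layers.0.input_layernorm.weight": t, ...})
    # Each BaseWeight attribute picks out only the entries it recognizes.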

    def init_static_params(self):
        """
        Hook for initializing static parameters; most models do not need this.
        """
        pass

    def verify_load(self):
        """
        Verify that every BaseWeight attribute was loaded successfully.
        """
        for attr_name in dir(self):
            attr = getattr(self, attr_name)
            if isinstance(attr, BaseWeight):
                # layer_num_ is only set on per-layer subclasses; fall back to None.
                layer_num = getattr(self, "layer_num_", None)
                assert attr.verify_load(), f"Loading {attr_name} of layer {layer_num} failed."

    def _cuda(self, cpu_tensor):
        # data_type_ is expected to be set by subclasses before any weight is loaded.
        return cpu_tensor.contiguous().to(self.data_type_).cuda(get_current_device_id())
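

# --- Usage sketch (not part of the original file) ---
# A minimal hypothetical subclass showing the intended contract: a concrete
# layer-weight class sets data_type_ (consumed by _cuda) and layer_num_
# (reported by verify_load), then declares BaseWeight attributes that
# load_hf_weights() and verify_load() discover via dir(). The class and
# attribute names below are illustrative assumptions, not lightllm API.
class DemoLayerWeight(BaseLayerWeight):
    def __init__(self, layer_num, data_type=torch.float16):
        super().__init__()
        self.layer_num_ = layer_num
        self.data_type_ = data_type
        # e.g. self.att_norm_weight_ = SomeBaseWeightSubclass(...)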