Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
120 changes: 58 additions & 62 deletions xtuner/v1/rl/trainer/update_weighter.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,8 +62,10 @@ def _init_update_weighter(self):
self.worker_server_urls_status: dict[str, bool] = {}

self._global_hf_keys_mapping_cache: dict[str, list[str]] = dict()
self._ipc_tensor_bytes: int = int(self.config.update_weight_bucket_size_in_gb * 1024**3)
self._update_params_ipc_tensor = None
self._default_ipc_tensor_bytes: int = int(self.config.update_weight_bucket_size_in_gb * 1024**3)
self._ipc_tensor_bytes_dict_by_dtype: dict[torch.dtype, int] = {}
self._update_params_ipc_tensor_dict_by_dtype: dict[torch.dtype, torch.Tensor] = {}
self._last_update_params_ipc_tensor_dtype: torch.dtype | None = None
self._update_params_ipc_event = None
self._sglang_disagg_group: dist.ProcessGroup | None = None
self._sglang_disagg_group_name: str | None = None
Expand Down Expand Up @@ -187,7 +189,8 @@ def _update_weights_colocated(self):
self._update_weights_hf_generator(submodule="multi_modal_projector", final_update=True)
else:
self._update_weights_hf_generator(final_update=True)
self._update_params_ipc_tensor = None
self._update_params_ipc_tensor_dict_by_dtype = {}
self._last_update_params_ipc_tensor_dtype = None
self._update_params_ipc_event = None

DEVICE_MODULE.empty_cache()
Expand Down Expand Up @@ -508,6 +511,52 @@ def _init_external_process_group(
def _create_ipc_tensor(size_in_bytes: int, dtype: torch.dtype):
    """Allocate an uninitialized buffer on ``DEVICE`` and view it as ``dtype``.

    The storage is allocated as raw ``uint8`` bytes and then reinterpreted.
    ``Tensor.view(dtype)`` rejects a byte count that is not a multiple of the
    target element size, so the requested size is rounded up to the next
    multiple. The returned buffer therefore never holds fewer than
    ``size_in_bytes`` bytes, and may hold up to ``element_size - 1`` more.

    Args:
        size_in_bytes: Minimum capacity of the buffer, in bytes.
        dtype: Element type the returned tensor is viewed as.

    Returns:
        A 1-D tensor of ``dtype`` located on ``DEVICE``; its contents are
        uninitialized.
    """
    element_size = torch.empty((), dtype=dtype).element_size()
    # Ceil-divide so that an odd byte budget (e.g. one derived from a
    # fractional GB setting) still maps onto a whole number of elements.
    num_elements = -(-size_in_bytes // element_size)
    return torch.empty(num_elements * element_size, dtype=torch.uint8, device=DEVICE).view(dtype)

def _build_lmdeploy_flattened_tensor_data(self, state_dict: dict, flattened_tensor_bucket_cls) -> dict:
    """Pack ``state_dict`` into a per-dtype reusable buffer and build the payload dict.

    One buffer is kept per dtype in ``self._update_params_ipc_tensor_dict_by_dtype``
    and its capacity is remembered in ``self._ipc_tensor_bytes_dict_by_dtype``.
    The buffer for the dtype of the first tensor in ``state_dict`` is reused
    unless it does not exist yet or is smaller than ``state_dict``, in which
    case a new one is allocated.

    Args:
        state_dict: Mapping of parameter name to tensor. Must be non-empty:
            ``next(iter(state_dict))`` raises ``StopIteration`` otherwise.
            Only the first tensor's dtype is inspected here.
        flattened_tensor_bucket_cls: Class called with ``named_tensors=`` and
            ``flattened_tensor=`` keyword arguments; the resulting object must
            provide ``get_metadata()`` and ``get_flattened_tensor()``.

    Returns:
        A dict that always contains ``"metadata"`` and ``"require_clone"``
        (``False``). ``"flattened_tensor"`` and ``"event_ipc_handle"`` are added
        only when the buffer is new, was resized, or the dtype differs from the
        one used by the previous call.

    Side effects:
        Updates the per-dtype buffer and capacity dicts, records
        ``self._update_params_ipc_event`` and stores the dtype in
        ``self._last_update_params_ipc_tensor_dtype``.
    """
    # LMDeploy flattened buckets require all tensors in one bucket to share a dtype.
    state_dict_dtype = state_dict[next(iter(state_dict))].dtype
    update_params_ipc_tensor = self._update_params_ipc_tensor_dict_by_dtype.get(state_dict_dtype, None)
    state_dict_bytes = self._compute_state_dict_bytes(state_dict)
    # Capacity already recorded for this dtype, or the configured default if none yet.
    ipc_tensor_bytes = self._ipc_tensor_bytes_dict_by_dtype.get(
        state_dict_dtype,
        self._default_ipc_tensor_bytes,
    )
    # True only when a previous call exists and it used a different dtype.
    dtype_changed = (
        self._last_update_params_ipc_tensor_dtype is not None
        and state_dict_dtype != self._last_update_params_ipc_tensor_dtype
    )
    need_resize = state_dict_bytes > ipc_tensor_bytes
    # The buffer (and the event handle) is attached to the payload whenever the
    # buffer used by this call is not the one used by the previous call.
    send_ipc_tensor = dtype_changed or need_resize or update_params_ipc_tensor is None

    if update_params_ipc_tensor is not None:
        # NOTE(review): presumably blocks until the consumer of the previous
        # payload has finished with the buffer - confirm against the event type.
        self._update_params_ipc_event.wait()
        if need_resize:
            # The existing buffer is about to be replaced below.
            torch.cuda.synchronize()

    if update_params_ipc_tensor is None or need_resize:
        # Capacity never shrinks: keep the larger of the recorded and required size.
        ipc_tensor_bytes = max(ipc_tensor_bytes, state_dict_bytes)
        self._ipc_tensor_bytes_dict_by_dtype[state_dict_dtype] = ipc_tensor_bytes
        update_params_ipc_tensor = self._create_ipc_tensor(
            ipc_tensor_bytes,
            state_dict_dtype,
        )
        self._update_params_ipc_tensor_dict_by_dtype[state_dict_dtype] = update_params_ipc_tensor

    # NOTE(review): presumably the bucket copies the named tensors into the
    # supplied buffer - verify against the bucket class implementation.
    flattened_tensor_bucket = flattened_tensor_bucket_cls(
        named_tensors=list(state_dict.items()),
        flattened_tensor=update_params_ipc_tensor,
    )
    flattened_tensor_data = {
        "metadata": flattened_tensor_bucket.get_metadata(),
        "require_clone": False,
    }
    self._update_params_ipc_event.record()
    self._last_update_params_ipc_tensor_dtype = state_dict_dtype

    if send_ipc_tensor:
        flattened_tensor_data["flattened_tensor"] = flattened_tensor_bucket.get_flattened_tensor()
        flattened_tensor_data["event_ipc_handle"] = self._update_params_ipc_event.ipc_handle()
    return flattened_tensor_data

def _get_sglang_disagg_engine_info(self) -> RolloutEngineInfo:
engine_info: RolloutEngineInfo = []
seen_urls: set[str] = set()
Expand Down Expand Up @@ -744,36 +793,10 @@ def serialize_state_dict(state_dict: dict) -> str:
if self.rollout_cfg_info["backend"] == "pytorch" and self.rollout_cfg_info["tp"] > 1:
serialized_data = [None] * self.rollout_cfg_info["tp"]
if use_flattened_tensor_bucket and state_dict:
state_dict_bytes = self._compute_state_dict_bytes(state_dict)
send_ipc_tensor = (
state_dict_bytes > self._ipc_tensor_bytes or self._update_params_ipc_tensor is None
flattened_tensor_data = self._build_lmdeploy_flattened_tensor_data(
state_dict,
FlattenedTensorBucket,
)
if send_ipc_tensor:
self._ipc_tensor_bytes = max(self._ipc_tensor_bytes, state_dict_bytes)
if self._update_params_ipc_tensor is not None:
self._update_params_ipc_event.wait()
torch.cuda.synchronize()
self._update_params_ipc_tensor = self._create_ipc_tensor(
self._ipc_tensor_bytes,
state_dict[next(iter(state_dict))].dtype,
)
else:
self._update_params_ipc_event.wait()

flattened_tensor_bucket = FlattenedTensorBucket(
named_tensors=list(state_dict.items()),
flattened_tensor=self._update_params_ipc_tensor,
)
metadata = flattened_tensor_bucket.get_metadata()
flattened_tensor_data = {
"metadata": metadata,
"require_clone": False,
}
self._update_params_ipc_event.record()

if send_ipc_tensor:
flattened_tensor_data["flattened_tensor"] = flattened_tensor_bucket.get_flattened_tensor()
flattened_tensor_data["event_ipc_handle"] = self._update_params_ipc_event.ipc_handle()
tp_serialized_data = serialize_state_dict(flattened_tensor_data)
else:
tp_serialized_data = serialize_state_dict(state_dict)
Expand All @@ -785,37 +808,10 @@ def serialize_state_dict(state_dict: dict) -> str:
)
elif self.rollout_cfg_info["backend"] == "pytorch":
if use_flattened_tensor_bucket and state_dict:
state_dict_bytes = self._compute_state_dict_bytes(state_dict)
send_ipc_tensor = (
state_dict_bytes > self._ipc_tensor_bytes or self._update_params_ipc_tensor is None
)
if send_ipc_tensor:
self._ipc_tensor_bytes = max(self._ipc_tensor_bytes, state_dict_bytes)
if self._update_params_ipc_tensor is not None:
# wait previous ipc event recorded of lmdeploy
self._update_params_ipc_event.wait()
torch.cuda.synchronize()
self._update_params_ipc_tensor = self._create_ipc_tensor(
self._ipc_tensor_bytes,
state_dict[next(iter(state_dict))].dtype,
)
else:
self._update_params_ipc_event.wait()

flattened_tensor_bucket = FlattenedTensorBucket(
named_tensors=list(state_dict.items()),
flattened_tensor=self._update_params_ipc_tensor,
flattened_tensor_data = self._build_lmdeploy_flattened_tensor_data(
state_dict,
FlattenedTensorBucket,
)
metadata = flattened_tensor_bucket.get_metadata()
flattened_tensor_data = {
"metadata": metadata,
"require_clone": False,
}
self._update_params_ipc_event.record()

if send_ipc_tensor:
flattened_tensor_data["flattened_tensor"] = flattened_tensor_bucket.get_flattened_tensor()
flattened_tensor_data["event_ipc_handle"] = self._update_params_ipc_event.ipc_handle()
serialized_data = serialize_state_dict(flattened_tensor_data)
else:
serialized_data = serialize_state_dict(state_dict)
Expand Down
Loading