def _build_lmdeploy_flattened_tensor_data(self, state_dict: dict, flattened_tensor_bucket_cls) -> dict:
    """Pack ``state_dict`` into a per-dtype IPC buffer and describe it for the rollout side.

    One reusable buffer is cached per dtype in
    ``self._update_params_ipc_tensor_dict_by_dtype``; its byte capacity is
    tracked in ``self._ipc_tensor_bytes_dict_by_dtype`` and defaults to
    ``self._default_ipc_tensor_bytes``. A buffer is (re)allocated when none
    exists for the dtype or when the payload no longer fits.

    Args:
        state_dict: Name -> tensor mapping for one bucket. Must be non-empty.
            Only the first tensor's dtype is inspected, so the bucket is
            assumed to be dtype-homogeneous (the LMDeploy flattened bucket
            is expected to require this -- not validated here).
        flattened_tensor_bucket_cls: Bucket class, called with
            ``named_tensors=`` and ``flattened_tensor=`` keyword arguments.

    Returns:
        A dict with ``"metadata"`` and ``"require_clone": False``. When the
        buffer was just allocated/resized, or the dtype differs from the
        previous call, it also carries ``"flattened_tensor"`` and
        ``"event_ipc_handle"`` so the receiver can (re)attach to the buffer.

    Raises:
        StopIteration: If ``state_dict`` is empty.
    """
    bucket_dtype = state_dict[next(iter(state_dict))].dtype
    ipc_buffer = self._update_params_ipc_tensor_dict_by_dtype.get(bucket_dtype)
    payload_bytes = self._compute_state_dict_bytes(state_dict)
    capacity_bytes = self._ipc_tensor_bytes_dict_by_dtype.get(bucket_dtype, self._default_ipc_tensor_bytes)

    previous_dtype = self._last_update_params_ipc_tensor_dtype
    switched_dtype = previous_dtype is not None and bucket_dtype != previous_dtype
    has_buffer = ipc_buffer is not None
    must_allocate = (not has_buffer) or payload_bytes > capacity_bytes
    # The buffer is resent whenever the receiver cannot be assumed to hold the
    # buffer for this dtype: fresh allocation, resize, or a dtype switch.
    must_send_buffer = switched_dtype or must_allocate

    if has_buffer:
        # Wait on the shared IPC event before touching an existing buffer
        # (presumably recorded by the consumer after its last read -- confirm
        # against the rollout side).
        self._update_params_ipc_event.wait()
        if must_allocate:
            # Existing buffer is about to be replaced by a larger one.
            torch.cuda.synchronize()

    if must_allocate:
        capacity_bytes = max(capacity_bytes, payload_bytes)
        self._ipc_tensor_bytes_dict_by_dtype[bucket_dtype] = capacity_bytes
        ipc_buffer = self._create_ipc_tensor(capacity_bytes, bucket_dtype)
        self._update_params_ipc_tensor_dict_by_dtype[bucket_dtype] = ipc_buffer

    bucket = flattened_tensor_bucket_cls(
        named_tensors=list(state_dict.items()),
        flattened_tensor=ipc_buffer,
    )
    payload = {"metadata": bucket.get_metadata(), "require_clone": False}
    self._update_params_ipc_event.record()
    self._last_update_params_ipc_tensor_dtype = bucket_dtype

    if not must_send_buffer:
        return payload

    payload.update(
        {
            "flattened_tensor": bucket.get_flattened_tensor(),
            "event_ipc_handle": self._update_params_ipc_event.ipc_handle(),
        }
    )
    return payload
and self.rollout_cfg_info["tp"] > 1: serialized_data = [None] * self.rollout_cfg_info["tp"] if use_flattened_tensor_bucket and state_dict: - state_dict_bytes = self._compute_state_dict_bytes(state_dict) - send_ipc_tensor = ( - state_dict_bytes > self._ipc_tensor_bytes or self._update_params_ipc_tensor is None + flattened_tensor_data = self._build_lmdeploy_flattened_tensor_data( + state_dict, + FlattenedTensorBucket, ) - if send_ipc_tensor: - self._ipc_tensor_bytes = max(self._ipc_tensor_bytes, state_dict_bytes) - if self._update_params_ipc_tensor is not None: - self._update_params_ipc_event.wait() - torch.cuda.synchronize() - self._update_params_ipc_tensor = self._create_ipc_tensor( - self._ipc_tensor_bytes, - state_dict[next(iter(state_dict))].dtype, - ) - else: - self._update_params_ipc_event.wait() - - flattened_tensor_bucket = FlattenedTensorBucket( - named_tensors=list(state_dict.items()), - flattened_tensor=self._update_params_ipc_tensor, - ) - metadata = flattened_tensor_bucket.get_metadata() - flattened_tensor_data = { - "metadata": metadata, - "require_clone": False, - } - self._update_params_ipc_event.record() - - if send_ipc_tensor: - flattened_tensor_data["flattened_tensor"] = flattened_tensor_bucket.get_flattened_tensor() - flattened_tensor_data["event_ipc_handle"] = self._update_params_ipc_event.ipc_handle() tp_serialized_data = serialize_state_dict(flattened_tensor_data) else: tp_serialized_data = serialize_state_dict(state_dict) @@ -785,37 +808,10 @@ def serialize_state_dict(state_dict: dict) -> str: ) elif self.rollout_cfg_info["backend"] == "pytorch": if use_flattened_tensor_bucket and state_dict: - state_dict_bytes = self._compute_state_dict_bytes(state_dict) - send_ipc_tensor = ( - state_dict_bytes > self._ipc_tensor_bytes or self._update_params_ipc_tensor is None - ) - if send_ipc_tensor: - self._ipc_tensor_bytes = max(self._ipc_tensor_bytes, state_dict_bytes) - if self._update_params_ipc_tensor is not None: - # wait previous ipc event recorded 
of lmdeploy - self._update_params_ipc_event.wait() - torch.cuda.synchronize() - self._update_params_ipc_tensor = self._create_ipc_tensor( - self._ipc_tensor_bytes, - state_dict[next(iter(state_dict))].dtype, - ) - else: - self._update_params_ipc_event.wait() - - flattened_tensor_bucket = FlattenedTensorBucket( - named_tensors=list(state_dict.items()), - flattened_tensor=self._update_params_ipc_tensor, + flattened_tensor_data = self._build_lmdeploy_flattened_tensor_data( + state_dict, + FlattenedTensorBucket, ) - metadata = flattened_tensor_bucket.get_metadata() - flattened_tensor_data = { - "metadata": metadata, - "require_clone": False, - } - self._update_params_ipc_event.record() - - if send_ipc_tensor: - flattened_tensor_data["flattened_tensor"] = flattened_tensor_bucket.get_flattened_tensor() - flattened_tensor_data["event_ipc_handle"] = self._update_params_ipc_event.ipc_handle() serialized_data = serialize_state_dict(flattened_tensor_data) else: serialized_data = serialize_state_dict(state_dict)