Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion xtuner/v1/data_proto/rl_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ class SampleParams(BaseModel):
include_stop_str_in_output: bool = True
no_stop_trim: bool = True
spaces_between_special_tokens: bool = False
return_routed_experts: bool = False
enable_return_routed_experts: bool = True


class Status(Enum):
Expand Down
4 changes: 2 additions & 2 deletions xtuner/v1/rl/gateway/adapters/openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ class ChatCompletionRequest(BaseModel):
no_stop_trim: bool | None = None
seed: int | None = None
user: str | None = None
return_routed_experts: bool | None = None
enable_return_routed_experts: bool | None = None
chat_template_kwargs: dict[str, Any] | None = None


Expand Down Expand Up @@ -224,7 +224,7 @@ def request_to_canonical_request(self, request: ChatCompletionRequest) -> Canoni
"spaces_between_special_tokens": chat_template_kwargs.get("spaces_between_special_tokens"),
"sampling_seed": request.seed,
"user": request.user,
"return_routed_experts": request.return_routed_experts,
"enable_return_routed_experts": request.enable_return_routed_experts,
}.items()
if value is not None
},
Expand Down
2 changes: 1 addition & 1 deletion xtuner/v1/rl/gateway/backend/local_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -318,7 +318,7 @@ def _build_sample_params(
"no_stop_trim": canonical_request.metadata.get("no_stop_trim"),
"spaces_between_special_tokens": canonical_request.metadata.get("spaces_between_special_tokens"),
"sampling_seed": canonical_request.metadata.get("sampling_seed"),
"return_routed_experts": canonical_request.metadata.get("return_routed_experts"),
"enable_return_routed_experts": canonical_request.metadata.get("enable_return_routed_experts"),
}.items()
if value is not None
},
Expand Down
17 changes: 7 additions & 10 deletions xtuner/v1/rl/rollout/lmdeploy.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,15 +118,7 @@ def _get_request_payload(self, rollout_state: RolloutState) -> dict:
text_prompt = self.tokenizer.apply_chat_template(message, tokenize=False, add_generation_prompt=True)
prompt_token_ids = self.tokenizer(text_prompt, add_special_tokens=False)["input_ids"]
payload["input_ids"] = prompt_token_ids
lmdeploy_sample_params = self._transform_sample_params(
sample_params.model_copy(
update={
"return_routed_experts": (
self.enable_return_routed_experts and sample_params.return_routed_experts
)
}
)
)
lmdeploy_sample_params = self._transform_sample_params(sample_params)
payload.update(lmdeploy_sample_params)
else:
payload = {
Expand Down Expand Up @@ -383,4 +375,9 @@ def _transform_rollout_config_to_server_configs(self) -> Namespace:
)

def _transform_sample_params(self, sample_params: SampleParams) -> dict:
return sample_params.model_dump(exclude_none=True)
sample_params_dict = sample_params.model_dump(exclude_none=True)
sample_params_dict.pop("enable_return_routed_experts", None)
sample_params_dict["return_routed_experts"] = (
self.enable_return_routed_experts and sample_params.enable_return_routed_experts
)
return sample_params_dict
8 changes: 4 additions & 4 deletions xtuner/v1/rl/rollout/sglang.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,10 +56,10 @@ def _get_request_payload(self, rollout_state: RolloutState) -> dict:

if (
self.enable_return_routed_experts
and sample_params.return_routed_experts
and sample_params.enable_return_routed_experts
and not rollout_state.extra_fields.get("disable_routed_experts", False)
):
payload["return_routed_experts"] = True
payload["enable_return_routed_experts"] = True

if sample_params.return_token_ids:
if "image_data" in rollout_state.extra_fields:
Expand Down Expand Up @@ -106,10 +106,10 @@ async def _create_request(
sglang_extra_params = self._transform_extra_params(extra_params)
if (
self.enable_return_routed_experts
and sample_params.get("return_routed_experts", False)
and sample_params.get("enable_return_routed_experts", False)
and not extra_params.get("disable_routed_experts", False)
):
sglang_extra_params["return_routed_experts"] = True
sglang_extra_params["enable_return_routed_experts"] = True

payload.update(sglang_extra_params)

Expand Down
4 changes: 2 additions & 2 deletions xtuner/v1/rl/rollout/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from ray import ObjectRef as RayObjectRef

from xtuner.v1.data_proto.rl_data import RolloutState, Status
from xtuner.v1.rl.utils import asyncio_run, free_object_refs
from xtuner.v1.rl.utils import free_object_refs
from xtuner.v1.utils import get_logger


Expand Down Expand Up @@ -179,7 +179,7 @@ def run_once(self) -> None:
async def _run_checks() -> list[bool]:
return await asyncio.gather(*tasks)

check_results = asyncio_run(_run_checks())
check_results = asyncio.run(_run_checks())
inactive_workers = []
for (rank, _, _, _), is_healthy in zip(workers_to_check, check_results):
if not is_healthy:
Expand Down
6 changes: 3 additions & 3 deletions xtuner/v1/rl/rollout/vllm.py
Original file line number Diff line number Diff line change
Expand Up @@ -212,8 +212,8 @@ async def _create_request(
payload["input_ids"] = extra_info["train_prompt_ids"]

vllm_sample_params = self._transform_sample_params(sample_params, extra_params)
vllm_sample_params["return_routed_experts"] = self.enable_return_routed_experts and sample_params.get(
"return_routed_experts", False
vllm_sample_params["enable_return_routed_experts"] = self.enable_return_routed_experts and sample_params.get(
"enable_return_routed_experts", False
)
payload.update(vllm_sample_params)

Expand Down Expand Up @@ -379,7 +379,7 @@ async def _handle_non_stream_response(self, rollout_state: RolloutState, respons
last_token_ids: list[int] = []
last_logprobs: list[float] = []
routed_experts = None
should_return_routed_experts = self.enable_return_routed_experts and sample_params.return_routed_experts
should_return_routed_experts = self.enable_return_routed_experts and sample_params.enable_return_routed_experts

response_json = response.json()
response_choice = response_json["choices"][0]
Expand Down
4 changes: 3 additions & 1 deletion xtuner/v1/rl/rollout/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -955,7 +955,9 @@ async def _safe_handle_response(self, rollout_state: RolloutState, http_response
logprobs: list[float] = []
routed_experts = None
returned_response = ""
should_return_routed_experts = self.enable_return_routed_experts and sample_params.return_routed_experts
should_return_routed_experts = (
self.enable_return_routed_experts and sample_params.enable_return_routed_experts
)
try:
meta_info = response.get("meta_info") or {}
finish_reason_info = meta_info.get("finish_reason") or {}
Expand Down
Loading