-
Notifications
You must be signed in to change notification settings - Fork 422
support localhost agent #1842
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: agentic_branch
Are you sure you want to change the base?
support localhost agent #1842
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,34 @@ | ||
| """Public surface for the localhost_agent_loop runner.""" | ||
|
|
||
| from xtuner.v1.rl.agent_loop.localhost_agent_loop.agent_in_localhost_loop import ( | ||
| AgentInLocalhostLoop, | ||
| AgentInLocalhostLoopConfig, | ||
| ) | ||
| from xtuner.v1.rl.agent_loop.localhost_agent_loop.compose import LocalhostComposeStage | ||
| from xtuner.v1.rl.agent_loop.localhost_agent_loop.judger import LocalhostJudgerStage | ||
| from xtuner.v1.rl.agent_loop.localhost_agent_loop.runner import LocalhostRunner | ||
| from xtuner.v1.rl.agent_loop.localhost_agent_loop.schemas import LocalhostAgentSpec | ||
| from xtuner.v1.rl.agent_loop.localhost_agent_loop.stage import LocalhostStage | ||
| from xtuner.v1.rl.agent_loop.sandbox_agent_loop.schemas import ( | ||
| AgentRolloutItem, | ||
| RolloutError, | ||
| RolloutStatus, | ||
| StageRecord, | ||
| StageStatus, | ||
| ) | ||
|
|
||
|
|
||
| __all__ = [ | ||
| "AgentInLocalhostLoop", | ||
| "AgentInLocalhostLoopConfig", | ||
| "AgentRolloutItem", | ||
| "LocalhostAgentSpec", | ||
| "LocalhostComposeStage", | ||
| "LocalhostJudgerStage", | ||
| "LocalhostRunner", | ||
| "LocalhostStage", | ||
| "RolloutError", | ||
| "RolloutStatus", | ||
| "StageRecord", | ||
| "StageStatus", | ||
| ] |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,156 @@ | ||
| from __future__ import annotations | ||
|
|
||
| import asyncio | ||
| import copy | ||
| import importlib | ||
| import traceback | ||
| import uuid | ||
| from typing import Any | ||
|
|
||
| from lagent.utils import create_object | ||
|
|
||
| from xtuner.v1.data_proto.rl_data import RolloutState, SampleParams, Status | ||
| from xtuner.v1.rl.agent_loop.sandbox_agent_loop.schemas import ( | ||
| AgentRolloutItem, | ||
| RolloutStatus, | ||
| ) | ||
| from xtuner.v1.rl.judger import Judger | ||
| from xtuner.v1.rl.rollout import RolloutController | ||
| from xtuner.v1.rl.rollout.trace_store import get_store | ||
| from xtuner.v1.rl.utils import create_task | ||
|
|
||
| from ..agent_loop import AgentLoop, AgentLoopConfig | ||
|
|
||
|
|
||
| def _import_from_path(path: str) -> Any: | ||
| module_name, _, attr = path.rpartition(".") | ||
| if not module_name or not attr: | ||
| raise ValueError(f"Invalid import path: {path!r}. Expected 'module.attr'.") | ||
| module = importlib.import_module(module_name) | ||
| return getattr(module, attr) | ||
|
|
||
|
|
||
| def _resolve_runner(pipeline: Any) -> Any: | ||
| if isinstance(pipeline, str): | ||
| pipeline = _import_from_path(pipeline) | ||
| if isinstance(pipeline, dict): | ||
| return create_object(copy.deepcopy(pipeline)) | ||
| return pipeline | ||
|
|
||
|
|
||
class AgentInLocalhostLoopConfig(AgentLoopConfig):
    """Configuration for :class:`AgentInLocalhostLoop`.

    The loop it builds runs a localhost agent runner taken from
    ``RolloutState.extra_fields``.
    """

    # Cap on samples generated at the same time; a falsy value (``None`` or 0)
    # leaves concurrency unbounded in the loop this config builds.
    max_concurrent_samples: int | None = None

    def build_local(
        self,
        rollout_controller: RolloutController | None = None,
        judger: Judger | None = None,
        logger=None,
    ) -> AgentInLocalhostLoop:
        """Create the loop in the current process.

        Args:
            rollout_controller: Forwarded to the loop as ``rollout_ctl``.
            judger: Forwarded to the loop unchanged.
            logger: Forwarded to the loop unchanged.

        Returns:
            An ``AgentInLocalhostLoop`` that also receives this config's
            ``sample_params``, ``hf_checkpoint`` and ``max_concurrent_samples``.
        """
        loop_kwargs = {
            "rollout_ctl": rollout_controller,
            "sample_params": self.sample_params,
            "hf_checkpoint": self.hf_checkpoint,
            "judger": judger,
            "logger": logger,
            "max_concurrent_samples": self.max_concurrent_samples,
        }
        return AgentInLocalhostLoop(**loop_kwargs)
|
|
||
|
|
||
class AgentInLocalhostLoop(AgentLoop):
    """AgentLoop adapter for localhost_agent_loop runners.

    Each ``RolloutState`` is expected to carry an ``AgentRolloutItem`` (or data
    that validates into one) under ``extra_fields["rollout_item"]``. The item's
    ``pipeline`` is resolved to a runner, the runner is awaited, and the
    resulting artifacts are written back onto the ``RolloutState`` as training
    fields.
    """

    def __init__(
        self,
        rollout_ctl: RolloutController | None = None,
        sample_params: SampleParams | None = None,
        hf_checkpoint: str | None = None,
        judger: Judger | None = None,
        logger=None,
        max_concurrent_samples: int | None = None,
    ):
        """Initialise the base loop and the optional concurrency limiter.

        Args:
            rollout_ctl: Passed through to ``AgentLoop``.
            sample_params: Passed through to ``AgentLoop``.
            hf_checkpoint: Passed through to ``AgentLoop``.
            judger: Passed through to ``AgentLoop``.
            logger: Passed through to ``AgentLoop``.
            max_concurrent_samples: Maximum number of samples processed at
                once by ``generate_group``; a falsy value disables the limit.
        """
        super().__init__(rollout_ctl, sample_params, hf_checkpoint, judger, logger)
        self.max_concurrent_samples = max_concurrent_samples
        # A falsy limit (``None`` or 0) means no semaphore, i.e. no cap.
        self._sample_semaphore = asyncio.Semaphore(max_concurrent_samples) if max_concurrent_samples else None

    async def generate_group(self, rollout_state: list[RolloutState], **kwargs) -> list[RolloutState]:
        """Generate all samples of a group concurrently.

        Every state gets this loop's ``sample_params`` assigned before its
        task is scheduled. Results are returned in input order.
        """

        async def generate_one(state: RolloutState) -> RolloutState:
            if self._sample_semaphore is None:
                return await self.generate_sample(state, **kwargs)
            async with self._sample_semaphore:
                return await self.generate_sample(state, **kwargs)

        tasks = []
        for state in rollout_state:
            state.sample_params = self.sample_params
            tasks.append(create_task(generate_one(state)))
        return await asyncio.gather(*tasks)

    async def generate_sample(self, rollout_state: RolloutState, **kwargs) -> RolloutState:
        """Run the agent for one state and fill it with the outcome.

        Any ``Exception`` is caught at this boundary: the state is marked
        ``Status.FAILED`` with ``finish_reason="error"``, the exception text is
        stored in ``error_msg``, and the traceback is logged. The (possibly
        partially filled) state is always returned.
        """
        try:
            item = self._rollout_item(rollout_state)
            if rollout_state.uid is None:
                rollout_state.uid = uuid.uuid4().int
            item.uid = rollout_state.uid
            item.group_id = rollout_state.message_uid
            result = await self._run_item(item)
            await self._fill_rollout_state(rollout_state, result)
            return rollout_state
        except Exception as exc:
            rollout_state.status = Status.FAILED
            rollout_state.finish_reason = "error"
            rollout_state.error_msg = f"{type(exc).__name__}: {exc}"
            self.logger.error(f"[AgentInLocalhostLoop] failed: {exc}\n{traceback.format_exc()}")
            return rollout_state

    def _rollout_item(self, rollout_state: RolloutState) -> AgentRolloutItem:
        """Return a private deep copy of the state's rollout item.

        Raises:
            KeyError: If ``extra_fields`` has no ``"rollout_item"`` entry.
        """
        raw_item = rollout_state.extra_fields["rollout_item"]
        # NOTE(review): a reviewer asked whether the value can ever be something
        # other than an AgentRolloutItem; if it never is, the validation branch
        # below is unnecessary. Kept because the caller's contract is not
        # visible here.
        item = raw_item if isinstance(raw_item, AgentRolloutItem) else AgentRolloutItem.model_validate(raw_item)
        return item.model_copy(deep=True)

    async def _run_item(self, item: AgentRolloutItem) -> AgentRolloutItem:
        """Resolve ``item.pipeline`` to a runner and await ``runner.run(item)``.

        Raises:
            ValueError: If the pipeline resolves to ``None``.
        """
        runner = _resolve_runner(item.pipeline)
        if runner is None:
            raise ValueError("AgentRolloutItem.pipeline is required.")
        return await runner.run(item)

    @staticmethod
    def _format_item_error(error: Any) -> str:
        """Render a rollout error as ``stage/category: message``."""
        return f"{error.stage}/{error.category}: {error.message}"

    async def _fill_rollout_state(self, rollout_state: RolloutState, item: AgentRolloutItem) -> None:
        """Copy the runner's result onto ``rollout_state``.

        The last segment of ``item.artifacts["messages"]`` is rendered with the
        tokenizer's chat template, and the rendered text is used to export the
        training trace for ``rollout_state.uid`` from the trace store.

        Raises:
            ValueError: If there is no non-empty messages trace, or the last
                segment lacks ``messages``/``tools``. When the runner recorded
                an error on the item, that error is appended to the
                missing-trace message so the real cause is not lost.
            TypeError: If the segment's ``messages`` is not a list.
        """
        trace = item.artifacts.get("messages")
        if not isinstance(trace, list) or not trace:
            problem = "Agent artifacts must contain at least one trainable messages trace."
            if item.error is not None:
                # The caller only records this exception's text, so carry the
                # runner's own failure along instead of masking it.
                problem = f"{problem} Runner error: {self._format_item_error(item.error)}"
            raise ValueError(problem)
        segment = trace[-1]
        if not isinstance(segment, dict) or "messages" not in segment or "tools" not in segment:
            raise ValueError("Agent messages trace segment must contain messages and tools.")
        messages = segment["messages"]
        if not isinstance(messages, list):
            raise TypeError("Agent messages trace segment.messages must be a list.")
        text = self.tokenizer.apply_chat_template(
            messages,
            tools=segment["tools"],
            tokenize=False,
            add_generation_prompt=False,
        )
        # Drop exactly one trailing newline from the rendered template, if any.
        prompt_text = text[:-1] if text.endswith("\n") else text
        data = await get_store().export_training_trace.remote(str(rollout_state.uid), prompt_text)

        rollout_state.input_ids = data["input_ids"]
        rollout_state.labels = data["labels"]
        # Response tokens are those whose label is not -100; position 0 is
        # skipped on both sequences.
        rollout_state.response_ids = [
            token_id for token_id, label in zip(data["input_ids"][1:], data["labels"][1:]) if label != -100
        ]
        rollout_state.logprobs = data["logprobs"]
        rollout_state.routed_experts = data["routed_experts"]
        rollout_state.response = str(item.artifacts.get("response") or "")
        completed = item.status == RolloutStatus.COMPLETED
        rollout_state.finish_reason = "stop" if completed else "error"
        rollout_state.status = Status.COMPLETED if completed else Status.FAILED
        rollout_state.reward = {"score": item.reward}
        rollout_state.extra_fields["raw_prompt"] = prompt_text
        rollout_state.extra_fields["agent_artifacts"] = item.artifacts
        if item.error is not None:
            rollout_state.error_msg = self._format_item_error(item.error)
|
|
||
|
|
||
| __all__ = ["AgentInLocalhostLoop", "AgentInLocalhostLoopConfig"] | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,72 @@ | ||
| """Composable localhost stages.""" | ||
|
|
||
| from __future__ import annotations | ||
|
|
||
| import time | ||
| from typing import Any | ||
|
|
||
| from lagent.utils import create_object | ||
|
|
||
| from xtuner.v1.rl.agent_loop.sandbox_agent_loop.schemas import ( | ||
| AgentRolloutItem, | ||
| RolloutError, | ||
| StageRecord, | ||
| StageStatus, | ||
| ) | ||
|
|
||
|
|
||
class LocalhostComposeStage:
    """Compose multiple local validation stages behind ``run(item, record) -> float``.

    Child stages run sequentially. The composite score is the weighted mean of
    the child scores, using each child's ``weight`` attribute (default 1.0,
    negative weights clamped to 0).
    """

    def __init__(
        self,
        stages: list[Any],
        *,
        name: str = "validate",
        weight: float = 1.0,
    ):
        """Build the child stages.

        Args:
            stages: Non-empty list of stage specifications; each is passed to
                ``create_object``.
            name: Stage name used when this stage reports its own error.
            weight: Weight of this composite stage; stored, not used here.

        Raises:
            ValueError: If ``stages`` is empty.
        """
        if not stages:
            raise ValueError("LocalhostComposeStage.stages is empty")
        self.name = name
        self.stages = [create_object(stage) for stage in stages]
        self.weight = weight

    async def run(self, item: AgentRolloutItem, record: StageRecord) -> float:
        """Run every child stage and store the weighted mean score on ``record``.

        Each child gets its own record in ``item.judgers``, keyed by the
        child's ``name`` attribute (falling back to its class name).

        Returns:
            The weighted mean score, or ``0.0`` when the total weight is zero.

        Raises:
            Exception: Whatever a child stage raises is re-raised after
                ``record`` has been marked failed and given an error.
        """
        record.status = StageStatus.RUNNING
        record.started_at = record.started_at or time.monotonic()
        # Record of the child currently executing, used to attribute a failure.
        active_record = None
        try:
            weighted_score = 0.0
            total_weight = 0.0
            for stage in self.stages:
                name = getattr(stage, "name", stage.__class__.__name__)
                active_record = item.judgers.setdefault(name, StageRecord())
                score = float(await stage.run(item, active_record))
                stage_weight = max(float(getattr(stage, "weight", 1.0)), 0.0)
                weighted_score += score * stage_weight
                total_weight += stage_weight
            record.score = weighted_score / total_weight if total_weight > 0 else 0.0
            record.status = StageStatus.COMPLETED
            return record.score
        except Exception as exc:
            record.status = StageStatus.FAILED
            # Prefer the error of the child that was running when the failure
            # happened; scanning all judger records first could pick up a
            # stale or unrelated error instead.
            child_error = active_record.error if active_record is not None else None
            if child_error is None:
                child_error = next(
                    (child.error for child in item.judgers.values() if child.error is not None),
                    None,
                )
            record.error = (
                record.error
                or child_error
                or RolloutError(
                    stage=self.name,
                    category="validate_failed",
                    type=type(exc).__name__,
                    message=str(exc),
                )
            )
            raise
        finally:
            record.finished_at = time.monotonic()
|
|
||
|
|
||
| __all__ = ["LocalhostComposeStage"] |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,117 @@ | ||
| """Localhost judger stages.""" | ||
|
|
||
| from __future__ import annotations | ||
|
|
||
| import time | ||
| from copy import deepcopy | ||
| from typing import Any | ||
|
|
||
| from lagent.utils import create_object | ||
|
|
||
| from xtuner.v1.data_proto.rl_data import RolloutState, Status | ||
| from xtuner.v1.rl.agent_loop.sandbox_agent_loop.schemas import ( | ||
| AgentRolloutItem, | ||
| RolloutError, | ||
| StageRecord, | ||
| StageStatus, | ||
| ) | ||
| from xtuner.v1.rl.judger.native import Judger | ||
|
|
||
|
|
||
class LocalhostJudgerStage:
    """Run one local validation stage.

    Public stage interface is ``run(item, record) -> float``. ``RolloutState``
    is only the internal shape needed to reuse xtuner judgers.
    """

    def __init__(
        self,
        *,
        name: str,
        judger_config: Any | None = None,
        judger: Any | None = None,
        reward_key: str = "score",
        weight: float = 1.0,
    ):
        """Store the judger specification; the judger itself is built lazily.

        Args:
            name: Stage name, used in error records.
            judger_config: Dict specification, config object with ``build()``,
                or an object that already has ``judge``.
            judger: Fallback used only when ``judger_config`` is ``None``.
            reward_key: Key read from the judger's reward payload.
            weight: Weight of this stage; stored, not used here.

        Raises:
            ValueError: If both ``judger_config`` and ``judger`` are ``None``.
        """
        # NOTE(review): a reviewer asked why both ``judger_config`` and
        # ``judger`` are needed; both are kept to leave the interface unchanged.
        effective_config = judger if judger_config is None else judger_config
        if effective_config is None:
            raise ValueError("LocalhostJudgerStage requires judger_config")
        self.name = name
        self.reward_key = reward_key
        self.weight = weight
        self._judger_config = effective_config
        self._judger: Judger | Any | None = None

    @staticmethod
    def _last_trace_messages(messages_artifact: Any) -> list:
        """Validate the messages artifact and return its last segment's messages."""
        if not isinstance(messages_artifact, list) or not messages_artifact:
            raise ValueError("Agent messages artifact must be a non-empty list.")
        segment = messages_artifact[-1]
        if not isinstance(segment, dict) or "messages" not in segment or "tools" not in segment:
            raise ValueError("Agent messages trace segment must contain messages and tools.")
        messages = segment["messages"]
        if not isinstance(messages, list):
            raise TypeError("Agent messages trace segment.messages must be a list.")
        return messages

    async def run(self, item: AgentRolloutItem, record: StageRecord) -> float:
        """Judge ``item`` and store the score on ``record``.

        Returns:
            ``float(reward[self.reward_key])`` from the judger's reward payload.

        Raises:
            Exception: Any failure is re-raised after ``record`` is marked
                failed; ``record.error`` is filled in unless already set.
        """
        record.status = StageStatus.RUNNING
        record.started_at = record.started_at or time.monotonic()
        try:
            reward_model = dict(item.reward_model or {})

            trace_artifact = item.artifacts.get("messages")
            # NOTE(review): a reviewer asked whether a missing (None) trace
            # should really be allowed here; it is currently tolerated.
            if trace_artifact is not None:
                messages = self._last_trace_messages(trace_artifact)
                # Count messages that carry a non-empty ``tool_calls`` list.
                tool_turns = 0
                for message in messages:
                    if not isinstance(message, dict):
                        continue
                    calls = message.get("tool_calls")
                    if isinstance(calls, list) and calls:
                        tool_turns += 1
                # Values already present in reward_model are not overwritten.
                reward_model.setdefault("agent_trace", messages)
                reward_model.setdefault("num_turns", tool_turns)

            infer_completed = item.infer.status == StageStatus.COMPLETED
            rollout_state = RolloutState(
                message=[{"role": "user", "content": item.instruction}],
                response=str(item.artifacts.get("response") or ""),
                reward_model=reward_model,
                status=Status.COMPLETED if infer_completed else Status.FAILED,
            )
            # NOTE(review): a reviewer suggested skipping the judge when the
            # rollout is not in a completed state, to save time if judging is
            # slow. Not done here; the judger is always called.
            judger = self.build()
            judged = await judger.judge(rollout_state)
            reward_payload = judged.reward or {}
            if self.reward_key not in reward_payload:
                raise KeyError(f"judger reward payload has no {self.reward_key!r}: {reward_payload!r}")
            record.metadata["reward"] = reward_payload
            record.score = float(reward_payload[self.reward_key])
            record.status = StageStatus.COMPLETED
            return record.score
        except Exception as exc:
            record.status = StageStatus.FAILED
            if not record.error:
                record.error = RolloutError(
                    stage=self.name,
                    category="judger",
                    type=type(exc).__name__,
                    message=str(exc),
                )
            raise
        finally:
            record.finished_at = time.monotonic()

    def build(self) -> Judger | Any:
        """Return the judger, creating and caching it on first use.

        A dict specification is instantiated via ``create_object`` from a deep
        copy. The resulting object is used through ``build()`` if it has one,
        otherwise directly if it has ``judge``.

        Raises:
            TypeError: If the configuration has neither ``build`` nor ``judge``.
        """
        if self._judger is not None:
            return self._judger
        config = self._judger_config
        if isinstance(config, dict):
            config = create_object(deepcopy(config))
        if hasattr(config, "build"):
            self._judger = config.build()
        elif hasattr(config, "judge"):
            self._judger = config
        else:
            raise TypeError(
                f"judger_config must build a Judger or be a Judger-like object, got {type(config).__name__}"
            )
        return self._judger
|
|
||
|
|
||
| __all__ = ["LocalhostJudgerStage"] | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这个代码可以考虑放到基类里面去,因为是通用的。否则其他agent 也要写一遍