Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions xtuner/v1/rl/agent_loop/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,14 @@
RouterAgentLoop,
get_agent_loop_rollout_ctl,
)
from .localhost_agent_loop.agent_in_localhost_loop import AgentInLocalhostLoop, AgentInLocalhostLoopConfig
from .sandbox_agent_loop.agent_in_sandbox_loop import AgentInSandboxLoop, AgentInSandboxLoopConfig
from .single_turn_agent_loop import SingleTurnAgentLoop, SingleTurnAgentLoopConfig


__all__ = [
"AgentInLocalhostLoop",
"AgentInLocalhostLoopConfig",
"AgentInSandboxLoop",
"AgentInSandboxLoopConfig",
"AgentLoopConfig",
Expand Down
34 changes: 34 additions & 0 deletions xtuner/v1/rl/agent_loop/localhost_agent_loop/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
"""Public surface for the localhost_agent_loop runner."""

from xtuner.v1.rl.agent_loop.localhost_agent_loop.agent_in_localhost_loop import (
AgentInLocalhostLoop,
AgentInLocalhostLoopConfig,
)
from xtuner.v1.rl.agent_loop.localhost_agent_loop.compose import LocalhostComposeStage
from xtuner.v1.rl.agent_loop.localhost_agent_loop.judger import LocalhostJudgerStage
from xtuner.v1.rl.agent_loop.localhost_agent_loop.runner import LocalhostRunner
from xtuner.v1.rl.agent_loop.localhost_agent_loop.schemas import LocalhostAgentSpec
from xtuner.v1.rl.agent_loop.localhost_agent_loop.stage import LocalhostStage
from xtuner.v1.rl.agent_loop.sandbox_agent_loop.schemas import (
AgentRolloutItem,
RolloutError,
RolloutStatus,
StageRecord,
StageStatus,
)


# Names exported by ``from ... import *``; each one is imported above in this module.
__all__ = [
"AgentInLocalhostLoop",
"AgentInLocalhostLoopConfig",
"AgentRolloutItem",
"LocalhostAgentSpec",
"LocalhostComposeStage",
"LocalhostJudgerStage",
"LocalhostRunner",
"LocalhostStage",
"RolloutError",
"RolloutStatus",
"StageRecord",
"StageStatus",
]
Original file line number Diff line number Diff line change
@@ -0,0 +1,156 @@
from __future__ import annotations

import asyncio
import copy
import importlib
import traceback
import uuid
from typing import Any

from lagent.utils import create_object

from xtuner.v1.data_proto.rl_data import RolloutState, SampleParams, Status
from xtuner.v1.rl.agent_loop.sandbox_agent_loop.schemas import (
AgentRolloutItem,
RolloutStatus,
)
from xtuner.v1.rl.judger import Judger
from xtuner.v1.rl.rollout import RolloutController
from xtuner.v1.rl.rollout.trace_store import get_store
from xtuner.v1.rl.utils import create_task

from ..agent_loop import AgentLoop, AgentLoopConfig


def _import_from_path(path: str) -> Any:
    """Import and return the object named by a dotted ``module.attr`` path.

    Everything before the last dot is imported as a module; the text after
    it is looked up as an attribute on that module.

    Args:
        path: Dotted path such as ``"package.module.attr"``.

    Returns:
        The attribute found on the imported module.

    Raises:
        ValueError: If ``path`` contains no dot, or if either side of the
            last dot is empty.
    """
    pieces = path.rsplit(".", 1)
    if len(pieces) != 2 or not all(pieces):
        raise ValueError(f"Invalid import path: {path!r}. Expected 'module.attr'.")
    owner, member = pieces
    return getattr(importlib.import_module(owner), member)


def _resolve_runner(pipeline: Any) -> Any:
    """Turn a pipeline specification into a runner object.

    Args:
        pipeline: One of
            * a dotted import path (``str``), resolved with
              ``_import_from_path``;
            * a config ``dict``, instantiated with ``create_object``;
            * anything else, which is returned untouched.

    Returns:
        The resolved object. A string that imports to a ``dict`` is
        instantiated as well, because the dict check runs on the imported
        value.
    """
    spec = _import_from_path(pipeline) if isinstance(pipeline, str) else pipeline
    if not isinstance(spec, dict):
        return spec
    # ``create_object`` is handed a deep copy rather than the dict itself.
    return create_object(copy.deepcopy(spec))


class AgentInLocalhostLoopConfig(AgentLoopConfig):
    """Configuration that builds an ``AgentInLocalhostLoop``.

    The resulting loop runs a localhost agent runner described in
    ``RolloutState.extra_fields``.
    """

    # Upper bound on samples generated at the same time. The loop only
    # creates a semaphore when this value is truthy, so ``None`` and ``0``
    # both mean "no limit".
    max_concurrent_samples: int | None = None

    def build_local(
        self,
        rollout_controller: RolloutController | None = None,
        judger: Judger | None = None,
        logger=None,
    ) -> AgentInLocalhostLoop:
        """Create the agent loop from this config and the given collaborators.

        Args:
            rollout_controller: Forwarded to the loop as ``rollout_ctl``.
            judger: Forwarded to the loop unchanged.
            logger: Forwarded to the loop unchanged.

        Returns:
            A new ``AgentInLocalhostLoop`` that also receives this config's
            ``sample_params``, ``hf_checkpoint`` and
            ``max_concurrent_samples``.
        """
        loop_kwargs = {
            "rollout_ctl": rollout_controller,
            "sample_params": self.sample_params,
            "hf_checkpoint": self.hf_checkpoint,
            "judger": judger,
            "logger": logger,
            "max_concurrent_samples": self.max_concurrent_samples,
        }
        return AgentInLocalhostLoop(**loop_kwargs)


class AgentInLocalhostLoop(AgentLoop):
    """Adapter that runs localhost_agent_loop runners behind the AgentLoop API."""

    def __init__(
        self,
        rollout_ctl: RolloutController | None = None,
        sample_params: SampleParams | None = None,
        hf_checkpoint: str | None = None,
        judger: Judger | None = None,
        logger=None,
        max_concurrent_samples: int | None = None,
    ):
        """Initialise the base loop and the optional concurrency limiter.

        Args:
            rollout_ctl: Passed positionally to ``AgentLoop.__init__``.
            sample_params: Passed positionally to ``AgentLoop.__init__``.
            hf_checkpoint: Passed positionally to ``AgentLoop.__init__``.
            judger: Passed positionally to ``AgentLoop.__init__``.
            logger: Passed positionally to ``AgentLoop.__init__``.
            max_concurrent_samples: Maximum number of samples generated at
                once. Any falsy value (``None`` or ``0``) disables the limit.
        """
        super().__init__(rollout_ctl, sample_params, hf_checkpoint, judger, logger)
        self.max_concurrent_samples = max_concurrent_samples
        if max_concurrent_samples:
            self._sample_semaphore = asyncio.Semaphore(max_concurrent_samples)
        else:
            self._sample_semaphore = None

    async def generate_group(self, rollout_state: list[RolloutState], **kwargs) -> list[RolloutState]:
        """Generate every sample of a group concurrently.

        Each state first receives this loop's ``sample_params``; one task per
        state is then created and all tasks are awaited together.

        Args:
            rollout_state: The states making up the group.
            **kwargs: Forwarded to ``generate_sample`` for every state.

        Returns:
            The ``asyncio.gather`` result, i.e. one entry per input state in
            input order.
        """
        # NOTE(review): a reviewer on the originating change suggested moving
        # this fan-out into the base class because it is generic; otherwise
        # other agent loops have to re-implement it.

        async def _throttled(state: RolloutState) -> RolloutState:
            gate = self._sample_semaphore
            if gate is not None:
                # Hold a semaphore slot for the whole sample.
                async with gate:
                    return await self.generate_sample(state, **kwargs)
            return await self.generate_sample(state, **kwargs)

        pending = []
        for state in rollout_state:
            state.sample_params = self.sample_params
            pending.append(create_task(_throttled(state)))
        return await asyncio.gather(*pending)

    async def generate_sample(self, rollout_state: RolloutState, **kwargs) -> RolloutState:
        """Run the agent for one state and write the outcome back onto it.

        Any ``Exception`` raised along the way is caught: the state is marked
        ``Status.FAILED`` with ``finish_reason == "error"``, the error text is
        stored in ``error_msg``, and the traceback is logged.

        Args:
            rollout_state: State to process; mutated in place.
            **kwargs: Accepted for interface compatibility; not used here.

        Returns:
            The same ``rollout_state`` object, on success and on failure.
        """
        try:
            item = self._rollout_item(rollout_state)
            if rollout_state.uid is None:
                # No uid assigned yet: mint one from a random UUID.
                rollout_state.uid = uuid.uuid4().int
            item.uid = rollout_state.uid
            item.group_id = rollout_state.message_uid
            finished = await self._run_item(item)
            await self._fill_rollout_state(rollout_state, finished)
        except Exception as exc:
            rollout_state.status = Status.FAILED
            rollout_state.finish_reason = "error"
            rollout_state.error_msg = f"{type(exc).__name__}: {exc}"
            self.logger.error(f"[AgentInLocalhostLoop] failed: {exc}\n{traceback.format_exc()}")
        return rollout_state

    def _rollout_item(self, rollout_state: RolloutState) -> AgentRolloutItem:
        """Return a private deep copy of the state's ``rollout_item``.

        The value under ``extra_fields["rollout_item"]`` is used directly when
        it already is an ``AgentRolloutItem`` and is validated into one
        otherwise. A missing key raises ``KeyError``.
        """
        # NOTE(review): a reviewer asked whether the payload can ever be
        # anything other than an AgentRolloutItem; if it cannot, the
        # validation branch is unnecessary.
        payload = rollout_state.extra_fields["rollout_item"]
        if isinstance(payload, AgentRolloutItem):
            parsed = payload
        else:
            parsed = AgentRolloutItem.model_validate(payload)
        return parsed.model_copy(deep=True)

    async def _run_item(self, item: AgentRolloutItem) -> AgentRolloutItem:
        """Resolve ``item.pipeline`` into a runner and await ``runner.run(item)``.

        Raises:
            ValueError: If the pipeline resolves to ``None``.
        """
        resolved = _resolve_runner(item.pipeline)
        if resolved is None:
            raise ValueError("AgentRolloutItem.pipeline is required.")
        return await resolved.run(item)

    async def _fill_rollout_state(self, rollout_state: RolloutState, item: AgentRolloutItem) -> None:
        """Copy the training data of a finished item onto ``rollout_state``.

        Only the last segment of ``item.artifacts["messages"]`` is used. It is
        rendered with the tokenizer's chat template, and the rendered text
        (minus one trailing newline, if present) is sent to the trace store
        together with the state's uid to obtain token-level data.

        Args:
            rollout_state: State to populate; mutated in place.
            item: The item returned by the runner.

        Raises:
            ValueError: If the messages artifact is not a non-empty list, or
                its last segment is not a dict with ``messages`` and
                ``tools`` keys.
            TypeError: If the segment's ``messages`` value is not a list.
        """
        segments = item.artifacts.get("messages")
        if not (isinstance(segments, list) and segments):
            raise ValueError("Agent artifacts must contain at least one trainable messages trace.")
        last = segments[-1]
        well_formed = isinstance(last, dict) and "messages" in last and "tools" in last
        if not well_formed:
            raise ValueError("Agent messages trace segment must contain messages and tools.")
        messages = last["messages"]
        if not isinstance(messages, list):
            raise TypeError("Agent messages trace segment.messages must be a list.")

        # ``self.tokenizer`` is not assigned in this class; presumably the
        # AgentLoop base provides it -- TODO confirm.
        rendered = self.tokenizer.apply_chat_template(
            messages,
            tools=last["tools"],
            tokenize=False,
            add_generation_prompt=False,
        )
        prompt_text = rendered[:-1] if rendered.endswith("\n") else rendered
        data = await get_store().export_training_trace.remote(str(rollout_state.uid), prompt_text)

        rollout_state.input_ids = data["input_ids"]
        rollout_state.labels = data["labels"]
        # Keep the tokens whose label is not -100 (presumably the ignore
        # marker -- TODO confirm), skipping position 0 of both sequences.
        shifted_pairs = zip(data["input_ids"][1:], data["labels"][1:])
        rollout_state.response_ids = [tok for tok, lab in shifted_pairs if lab != -100]
        rollout_state.logprobs = data["logprobs"]
        rollout_state.routed_experts = data["routed_experts"]
        rollout_state.response = str(item.artifacts.get("response") or "")
        if item.status == RolloutStatus.COMPLETED:
            rollout_state.finish_reason = "stop"
        else:
            rollout_state.finish_reason = "error"
        if item.status == RolloutStatus.COMPLETED:
            rollout_state.status = Status.COMPLETED
        else:
            rollout_state.status = Status.FAILED
        rollout_state.reward = {"score": item.reward}
        rollout_state.extra_fields["raw_prompt"] = prompt_text
        rollout_state.extra_fields["agent_artifacts"] = item.artifacts
        failure = item.error
        if failure is not None:
            rollout_state.error_msg = f"{failure.stage}/{failure.category}: {failure.message}"


# Public names of this module: the loop adapter and its config class.
__all__ = ["AgentInLocalhostLoop", "AgentInLocalhostLoopConfig"]
72 changes: 72 additions & 0 deletions xtuner/v1/rl/agent_loop/localhost_agent_loop/compose.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
"""Composable localhost stages."""

from __future__ import annotations

import time
from typing import Any

from lagent.utils import create_object

from xtuner.v1.rl.agent_loop.sandbox_agent_loop.schemas import (
AgentRolloutItem,
RolloutError,
StageRecord,
StageStatus,
)


class LocalhostComposeStage:
    """Run several validation stages and expose them as a single stage.

    The combined stage follows the same ``run(item, record) -> float``
    interface as its children and returns their weighted average score.
    """

    def __init__(
        self,
        stages: list[Any],
        *,
        name: str = "validate",
        weight: float = 1.0,
    ):
        """Build the child stages.

        Args:
            stages: Child stage specifications; each is passed through
                ``create_object``.
            name: Name of the combined stage, used as the ``stage`` of the
                fallback error created in ``run``.
            weight: Stored on the instance; not used inside this class.

        Raises:
            ValueError: If ``stages`` is empty (or otherwise falsy).
        """
        if not stages:
            raise ValueError("LocalhostComposeStage.stages is empty")
        self.name = name
        self.stages = list(map(create_object, stages))
        self.weight = weight

    async def run(self, item: AgentRolloutItem, record: StageRecord) -> float:
        """Run every child in order and return the weighted mean of their scores.

        Each child gets its own record in ``item.judgers``, keyed by the
        child's ``name`` attribute or, failing that, its class name. Negative
        child weights are clamped to zero; if the weights sum to zero the
        score is ``0.0``.

        Args:
            item: Rollout item handed to every child.
            record: Record of the combined stage; status, timestamps, score
                and error are written to it.

        Returns:
            The weighted average score, also stored in ``record.score``.

        Raises:
            Exception: Whatever a child raised, re-raised after ``record``
                has been marked failed.
        """
        record.status = StageStatus.RUNNING
        begun = record.started_at
        record.started_at = begun if begun else time.monotonic()
        try:
            numerator = 0.0
            denominator = 0.0
            for child in self.stages:
                label = getattr(child, "name", child.__class__.__name__)
                child_record = item.judgers.setdefault(label, StageRecord())
                child_score = float(await child.run(item, child_record))
                share = max(float(getattr(child, "weight", 1.0)), 0.0)
                numerator += child_score * share
                denominator += share
            if denominator > 0:
                record.score = numerator / denominator
            else:
                record.score = 0.0
            record.status = StageStatus.COMPLETED
            return record.score
        except Exception as exc:
            record.status = StageStatus.FAILED
            # First error found on any record in ``item.judgers``; the scan
            # is not limited to the children run by this stage.
            inherited = None
            for judger_record in item.judgers.values():
                if judger_record.error is not None:
                    inherited = judger_record.error
                    break
            # Preference order: an error already on the record, then the
            # inherited child error, then a fresh one describing ``exc``.
            chosen = record.error
            if not chosen:
                chosen = inherited
            if not chosen:
                chosen = RolloutError(
                    stage=self.name,
                    category="validate_failed",
                    type=type(exc).__name__,
                    message=str(exc),
                )
            record.error = chosen
            raise
        finally:
            record.finished_at = time.monotonic()


# Public names of this module.
__all__ = ["LocalhostComposeStage"]
117 changes: 117 additions & 0 deletions xtuner/v1/rl/agent_loop/localhost_agent_loop/judger.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,117 @@
"""Localhost judger stages."""

from __future__ import annotations

import time
from copy import deepcopy
from typing import Any

from lagent.utils import create_object

from xtuner.v1.data_proto.rl_data import RolloutState, Status
from xtuner.v1.rl.agent_loop.sandbox_agent_loop.schemas import (
AgentRolloutItem,
RolloutError,
StageRecord,
StageStatus,
)
from xtuner.v1.rl.judger.native import Judger


class LocalhostJudgerStage:
    """A single local validation stage backed by a judger.

    The public interface is ``run(item, record) -> float``. A
    ``RolloutState`` is built internally only because that is the input the
    judger's ``judge`` method is called with.
    """

    def __init__(
        self,
        *,
        name: str,
        judger_config: Any | None = None,
        judger: Any | None = None,
        reward_key: str = "score",
        weight: float = 1.0,
    ):
        """Store the judger specification; the judger itself is built lazily.

        Args:
            name: Stage name, used as the ``stage`` of errors created in
                ``run``.
            judger_config: Judger specification (see ``build`` for the
                accepted forms).
            judger: Fallback specification, used only when ``judger_config``
                is ``None``.
            reward_key: Key read from the judger's reward payload.
            weight: Stored on the instance; not used inside this class.

        Raises:
            ValueError: If both ``judger_config`` and ``judger`` are ``None``.
        """
        # NOTE(review): a reviewer asked why both ``judger_config`` and
        # ``judger`` are needed. As written, ``judger`` is merely an
        # alternative spelling consulted when ``judger_config`` is None.
        spec = judger_config if judger_config is not None else judger
        if spec is None:
            raise ValueError("LocalhostJudgerStage requires judger_config")
        self.name = name
        self._judger_config = spec
        self._judger: Judger | Any | None = None
        self.reward_key = reward_key
        self.weight = weight

    async def run(self, item: AgentRolloutItem, record: StageRecord) -> float:
        """Judge one rollout item and return its score.

        When the item carries a ``messages`` artifact, its last segment
        supplies ``agent_trace`` and ``num_turns`` defaults for the reward
        model; ``num_turns`` counts the messages that have a non-empty
        ``tool_calls`` list.

        Args:
            item: Rollout item to judge.
            record: Stage record; status, timestamps, score, reward metadata
                and error are written to it.

        Returns:
            ``float(reward[self.reward_key])``, also stored in
            ``record.score``.

        Raises:
            ValueError: If the messages artifact is present but is not a
                non-empty list, or its last segment lacks ``messages`` or
                ``tools``.
            TypeError: If the segment's ``messages`` value is not a list.
            KeyError: If the reward payload has no ``reward_key`` entry.
            Exception: Anything raised by the judger. Every failure is
                recorded on ``record`` before being re-raised.
        """
        record.status = StageStatus.RUNNING
        begun = record.started_at
        record.started_at = begun if begun else time.monotonic()
        try:
            reward_model = dict(item.reward_model or {})

            # NOTE(review): a reviewer asked whether a missing messages
            # artifact should be allowed. As written it is tolerated and the
            # trace-derived fields are simply not added.
            trace = item.artifacts.get("messages")
            if trace is not None:
                if not (isinstance(trace, list) and trace):
                    raise ValueError("Agent messages artifact must be a non-empty list.")
                last = trace[-1]
                well_formed = isinstance(last, dict) and "messages" in last and "tools" in last
                if not well_formed:
                    raise ValueError("Agent messages trace segment must contain messages and tools.")
                messages = last["messages"]
                if not isinstance(messages, list):
                    raise TypeError("Agent messages trace segment.messages must be a list.")
                tool_turns = 0
                for entry in messages:
                    if not isinstance(entry, dict):
                        continue
                    calls = entry.get("tool_calls")
                    if isinstance(calls, list) and calls:
                        tool_turns += 1
                # ``setdefault`` keeps any value the item already supplied.
                reward_model.setdefault("agent_trace", messages)
                reward_model.setdefault("num_turns", tool_turns)

            answer = str(item.artifacts.get("response") or "")
            prompt = [{"role": "user", "content": item.instruction}]
            if item.infer.status == StageStatus.COMPLETED:
                outcome = Status.COMPLETED
            else:
                outcome = Status.FAILED
            judge_input = RolloutState(
                message=prompt,
                response=answer,
                reward_model=reward_model,
                status=outcome,
            )
            # NOTE(review): a reviewer suggested skipping the judge call when
            # inference did not complete, since judging may be slow. As
            # written the judger runs regardless of ``item.infer.status``.
            judged = await self.build().judge(judge_input)
            reward_payload = judged.reward or {}
            if self.reward_key not in reward_payload:
                raise KeyError(f"judger reward payload has no {self.reward_key!r}: {reward_payload!r}")
            record.metadata["reward"] = reward_payload
            record.score = float(reward_payload[self.reward_key])
            record.status = StageStatus.COMPLETED
            return record.score
        except Exception as exc:
            record.status = StageStatus.FAILED
            # Keep an error that is already recorded; otherwise describe ``exc``.
            existing = record.error
            record.error = existing if existing else RolloutError(
                stage=self.name,
                category="judger",
                type=type(exc).__name__,
                message=str(exc),
            )
            raise
        finally:
            record.finished_at = time.monotonic()

    def build(self) -> Judger | Any:
        """Return the judger, creating and caching it on first use.

        A ``dict`` specification is first instantiated with ``create_object``
        on a deep copy. The resulting object is then used as follows:

        * if it has a ``build`` attribute, ``build()`` is called and its
          result becomes the judger;
        * otherwise, if it has a ``judge`` attribute, it is the judger;
        * otherwise a ``TypeError`` is raised.

        Returns:
            The cached judger object.

        Raises:
            TypeError: If the specification neither builds a judger nor
                looks like one.
        """
        if self._judger is not None:
            return self._judger
        config = self._judger_config
        if isinstance(config, dict):
            config = create_object(deepcopy(config))
        if hasattr(config, "build"):
            self._judger = config.build()
        elif hasattr(config, "judge"):
            self._judger = config
        else:
            raise TypeError(
                f"judger_config must build a Judger or be a Judger-like object, got {type(config).__name__}"
            )
        return self._judger


# Public names of this module.
__all__ = ["LocalhostJudgerStage"]
Loading
Loading