diff --git a/.gitignore b/.gitignore index 7b64dc3..db79fdf 100644 --- a/.gitignore +++ b/.gitignore @@ -172,3 +172,5 @@ swarmexp swarmlog werewolves_swarm .claude +tensorboard_log +tutorial/**/*.json diff --git a/ajet/backbone/main_trinity.py b/ajet/backbone/main_trinity.py index dc06c21..2e44974 100644 --- a/ajet/backbone/main_trinity.py +++ b/ajet/backbone/main_trinity.py @@ -53,7 +53,7 @@ def patched_trainer_get_actor(cls, config: Config): Trainer.get_actor = classmethod(patched_trainer_get_actor) if ajet_config.ajet.enable_interchange_server: - from ajet.tuner_lib.experimental.as_oai_model_server import start_interchange_server + from ajet.tuner_lib.experimental.oai_model_server import start_interchange_server start_interchange_server(ajet_config) diff --git a/ajet/backbone/main_verl.py b/ajet/backbone/main_verl.py index 8eebb95..0fe845c 100644 --- a/ajet/backbone/main_verl.py +++ b/ajet/backbone/main_verl.py @@ -251,7 +251,7 @@ def run(self, config): from ajet.backbone.trainer_verl import AjetRayPPOTrainer if config.ajet.enable_interchange_server: - from ajet.tuner_lib.experimental.as_oai_model_server import start_interchange_server + from ajet.tuner_lib.experimental.oai_model_server import start_interchange_server start_interchange_server(config) # Initialize the PPO trainer. 
diff --git a/ajet/backbone/main_vllm.py b/ajet/backbone/main_vllm.py index 4e8b717..bb79eb5 100644 --- a/ajet/backbone/main_vllm.py +++ b/ajet/backbone/main_vllm.py @@ -187,7 +187,7 @@ def main(config): # atexit.register(lambda: print("Process exiting, performing cleanup...")) if config.ajet.enable_interchange_server: - from ajet.tuner_lib.experimental.as_oai_model_server import start_interchange_server + from ajet.tuner_lib.experimental.oai_model_server import start_interchange_server start_interchange_server(config) if config.ajet.enable_swarm_mode: from ajet.tuner_lib.experimental.interchange_utils import http_change_engine_status diff --git a/ajet/backbone/trainer_verl.py b/ajet/backbone/trainer_verl.py index 28f09f9..00caaa6 100644 --- a/ajet/backbone/trainer_verl.py +++ b/ajet/backbone/trainer_verl.py @@ -859,7 +859,7 @@ def fit(self): # noqa: C901 # # when enabled oai request interchange, we need to clear the cache from time to time # if self.config.ajet.enable_interchange_server: - # from ajet.tuner_lib.experimental.as_oai_model_server import ensure_dat_interchange_server_cache_clear + # from ajet.tuner_lib.experimental.oai_model_server import ensure_dat_interchange_server_cache_clear # ensure_dat_interchange_server_cache_clear() if is_last_step: diff --git a/ajet/context_tracker/multiagent_tracking.py b/ajet/context_tracker/multiagent_tracking.py index dcaca7b..c4ff3ee 100644 --- a/ajet/context_tracker/multiagent_tracking.py +++ b/ajet/context_tracker/multiagent_tracking.py @@ -603,11 +603,13 @@ def check_context_token_num_safe( add_generation_prompt=True, tokenize=False, ) - length = len(self.tokenizer(prompt_text, return_tensors="pt", padding=False)["input_ids"][0]) # type: ignore - max_response_length = self.config.ajet.rollout.max_response_length_in_one_turn + prompt_token_length = len(self.tokenizer(prompt_text, return_tensors="pt", padding=False)["input_ids"][0]) # type: ignore + max_response_length_in_one_turn = 
self.config.ajet.rollout.max_response_length_in_one_turn max_model_len: int = self.config.ajet.rollout.max_model_len - max_seq_length: int = max_model_len - max_response_length - if length < max_seq_length: + max_seq_length: int = max_model_len - max_response_length_in_one_turn + # prompt_token_length: the prompt_token_length of current all previous context + # max_seq_length: max_model_len - max_response_length_in_one_turn + if prompt_token_length < max_seq_length: token_overflow = False else: token_overflow = True @@ -615,12 +617,13 @@ def check_context_token_num_safe( ret = (False, token_overflow, "externally_interrupted") elif self.already_mad_flag and self.config.ajet.rollout.agent_madness_termination: ret = (False, token_overflow, "already_mad") - elif length < max_seq_length: + elif prompt_token_length < max_seq_length: ret = ( True, token_overflow, - f"safe[{length} < {max_model_len} - {max_response_length}]", + f"safe[{prompt_token_length} < {max_model_len} - {max_response_length_in_one_turn}]", ) else: - ret = (False, token_overflow, "token_overflow") + ret = (False, token_overflow, + f"token_overflow(prompt_token_length.{prompt_token_length}>=max_model_len.{max_model_len}-max_response_length_in_one_turn.{max_response_length_in_one_turn})") return ret diff --git a/ajet/context_tracker/single_agent_tracking.py b/ajet/context_tracker/single_agent_tracking.py index c49828a..775abf3 100644 --- a/ajet/context_tracker/single_agent_tracking.py +++ b/ajet/context_tracker/single_agent_tracking.py @@ -185,9 +185,9 @@ def compute_step_level_reward( def to_role_content(self, ext_msg_array: List[ExtendedMessage]) -> List: result = [] for ext_msg in ext_msg_array: - d = { + d: dict = { "role": ext_msg.role, - "content": ext_msg.content_for_future, + "content": ext_msg.content_for_compare, } if ext_msg.tool_calls: d.update({"tool_calls": ext_msg.tool_calls}) diff --git a/ajet/context_tracker/timeline_merging/timeline_merging.py 
b/ajet/context_tracker/timeline_merging/timeline_merging.py index 86e7f8b..fcc3b05 100644 --- a/ajet/context_tracker/timeline_merging/timeline_merging.py +++ b/ajet/context_tracker/timeline_merging/timeline_merging.py @@ -21,8 +21,8 @@ def is_timeline_mergeable( for i in range(len(target_timeline)): if timeline_compare_level == "text": same = ( - source_timeline[i].content_for_future - == target_timeline[i].content_for_future + source_timeline[i].content_for_compare + == target_timeline[i].content_for_compare ) elif timeline_compare_level == "token": same = source_timeline[i].token_arr == target_timeline[i].token_arr @@ -52,12 +52,12 @@ def is_timeline_mergeable( # all_msg_match = False # for i in range(len(target_timeline)): # d = {} - # d["source"] = source_timeline[i].content_for_future - # d["target"] = target_timeline[i].content_for_future + # d["source"] = source_timeline[i].content_for_compare + # d["target"] = target_timeline[i].content_for_compare # if timeline_compare_level == "text": # same = ( - # source_timeline[i].content_for_future - # == target_timeline[i].content_for_future + # source_timeline[i].content_for_compare + # == target_timeline[i].content_for_compare # ) # elif timeline_compare_level == "token": # same = source_timeline[i].token_arr == target_timeline[i].token_arr diff --git a/ajet/copilot/job.py b/ajet/copilot/job.py index 96ae203..86c07ae 100644 --- a/ajet/copilot/job.py +++ b/ajet/copilot/job.py @@ -10,24 +10,15 @@ import os import time import yaml -import tempfile -from types import SimpleNamespace from typing import Any, Callable, Union, cast from loguru import logger from ajet.default_config.ajet_default import Config from ajet.utils.config_utils import ( expand_ajet_hierarchical_config, - prepare_experiment_config, read_ajet_hierarchical_config, ) from ajet.utils.dynamic_import import cls_to_path -from ajet.utils.launch_utils import ( - execute_training_process, - check_avail_gpu, - get_backbone_target, - setup_environment_vars, 
-) def override_current_yaml_value_if_given(override_value, current_value): @@ -48,10 +39,27 @@ def _get_nested_attr(obj, attr_path: str): return obj class AgentJetJob: - """ - arg: base_yaml_config + **kwargs (yaml config, then override with kwargs) - arg: base_yaml_config (yaml config) - arg: **kwargs (yaml config, then override with kwargs) + """Programmatic interface for configuring and launching AgentJet training jobs. + + Args: + base_yaml_config: Path to base YAML configuration file. If None, uses default config (at ./ajet/default_config/ajet_ts_default.yaml). + experiment_dir: Directory where experiment outputs will be saved. + project_name: Name of the project for organizing experiments. + experiment_name: Unique name for this specific experiment run. + logging: "swanlab", "tensorboard", etc + n_gpu: Number of GPUs to use per node for training. + model: Path or identifier of the model to train. + algorithm: Advantage estimator algorithm (e.g., 'gae', 'vtrace'). + num_repeat: Tell swarm server how many repeated sample it should expect for a same task (same means task_id is identical). + batch_size: Training batch size for the model (the watermark to empty buffer pool and update llm weight). + swarm_mode: Whether to enable swarm mode for distributed sample collection. + swarm_mode_sample_collection_method: Method for collecting samples in swarm mode. + max_env_worker: an estimation about how many episodes will be running in parallel (all swarm clients combined). + backbone: Training backbone framework (e.g., 'verl'). + max_prompt_length: Maximum token length for input prompts (token length before the first llm-generated token). + max_response_length: Maximum token length for model responses (token length after the first llm-generated token). + max_model_len: Maximum total token length (prompt + response) the model can handle (bigger => more GPU memory). + mini_batch_num: Number of mini-batches to split training batch into (how many mini steps, i.e. 
how many times the `optimizer.step` should be executed, per big train batch). """ def __init__( @@ -60,6 +68,7 @@ def __init__( experiment_dir: str | None = None, project_name: str | None = None, experiment_name: str | None = None, + logging: str | None = None, n_gpu: int | None = None, model: str | None = None, algorithm: str | None = None, @@ -69,6 +78,11 @@ def __init__( swarm_mode_sample_collection_method: str | None = None, max_env_worker: int | None = None, backbone: str | None = None, + max_prompt_length: int | None = None, + max_response_length: int | None = None, + max_response_length_in_one_turn: int | None = None, + max_model_len: int | None = None, + mini_batch_num: int | None = None, ) -> None: if base_yaml_config is None: @@ -76,6 +90,14 @@ def __init__( else: logger.warning(f"Reading config from {base_yaml_config}.") time.sleep(1) + if not os.path.exists(base_yaml_config): + raise ValueError(f"Configuration yaml is absent! {base_yaml_config}") + + # Validate: max_prompt_length, max_response_length, max_model_len must all be None or all be non-None + length_params = [max_prompt_length, max_response_length, max_model_len, max_response_length_in_one_turn] + if not (all(p is None for p in length_params) or all(p is not None for p in length_params)): + raise ValueError("(`max_prompt_length`, `max_response_length`, `max_model_len`, `max_response_length_in_one_turn`) must all be None or all be non-None") + self.config_as_dict: dict = self.build_job_from_yaml(base_yaml_config) self.config = Config.update_from_dict_recursive(Config(), self.config_as_dict) @@ -83,6 +105,7 @@ def __init__( self.experiment_dir: str = cast(str, experiment_dir) self.project_name: str = cast(str, project_name) self.experiment_name: str = cast(str, experiment_name) + self.logging: str = cast(str, logging) self.n_gpu: int = cast(int, n_gpu) self.model: str = cast(str, model) self.algorithm: str = cast(str, algorithm) @@ -92,12 +115,19 @@ def __init__( 
self.swarm_mode_sample_collection_method: str = cast(str, swarm_mode_sample_collection_method) self.max_env_worker: int = cast(int, max_env_worker) self.backbone: str = cast(str, backbone) + self.max_prompt_length: int = cast(int, max_prompt_length) + self.max_response_length_in_one_turn: int = cast(int, max_response_length_in_one_turn) + self.max_response_length: int = cast(int, max_response_length) + self.max_model_len: int = cast(int, max_model_len) + self.mini_batch_num: int = cast(int, mini_batch_num) # see `ajet/default_config/ajet_ts_default.yaml` overrides = { + # left: [yaml key navigation] right: [AgentJetJob self attr] "ajet.experiment_dir": "experiment_dir", "ajet.project_name": "project_name", "ajet.experiment_name": "experiment_name", + "ajet.trainer_common.logger": "logging", "ajet.model.path": "model", "ajet.trainer_common.n_gpus_per_node": "n_gpu", "ajet.trainer_common.algorithm.adv_estimator": "algorithm", @@ -107,6 +137,11 @@ def __init__( "ajet.swarm_mode_sample_collection_method": "swarm_mode_sample_collection_method", "ajet.rollout.max_env_worker": "max_env_worker", "ajet.backbone": "backbone", + "ajet.data.max_prompt_length": "max_prompt_length", + "ajet.data.max_response_length": "max_response_length", + "ajet.rollout.max_response_length_in_one_turn": "max_response_length_in_one_turn", + "ajet.rollout.max_model_len": "max_model_len", + "ajet.trainer_common.mini_batch_num": "mini_batch_num", } # if any value given in kwargs, override the corresponding value in config @@ -127,6 +162,9 @@ def __init__( # >> e.g. 
self.model = new_model setattr(self, override_val, new_val) + + assert self.max_prompt_length + self.max_response_length <= self.max_model_len, "illegal token length" + assert self.max_response_length_in_one_turn <= self.max_response_length if self.backbone == "trinity": raise NotImplementedError("Trinity backbone is not yet supported in AgentJetJob.") @@ -184,4 +222,3 @@ def set_data( ) return self - diff --git a/ajet/copilot/openjudge/SKILL.md b/ajet/copilot/openjudge/SKILL.md new file mode 100644 index 0000000..44e0eb8 --- /dev/null +++ b/ajet/copilot/openjudge/SKILL.md @@ -0,0 +1,159 @@ +--- +name: openjudge +description: > + Build custom LLM evaluation pipelines using the OpenJudge framework. + Covers selecting and configuring graders (LLM-based, function-based, agentic), + running batch evaluations with GradingRunner, combining scores with aggregators, + applying evaluation strategies (voting, average), auto-generating graders from + data, and analyzing results (pairwise win rates, statistics, validation metrics). + Use when the user wants to evaluate LLM outputs, compare multiple models, + design scoring criteria, or build an automated evaluation system. +--- + +# OpenJudge Skill + +Build evaluation pipelines for LLM applications using the `openjudge` library. + +## When to Use This Skill + +- User wants to evaluate LLM output quality (correctness, relevance, hallucination, etc.) 
+- User wants to compare two or more models and rank them +- User wants to design a scoring rubric and automate evaluation +- User wants to analyze evaluation results statistically +- User wants to build a reward model or quality filter + +## Sub-documents — Read When Relevant + +| Topic | File | Read when… | +|-------|------|------------| +| Grader selection & configuration | `graders.md` | User needs to pick or configure an evaluator | +| Batch evaluation pipeline | `pipeline.md` | User needs to run evaluation over a dataset | +| Auto-generate graders from data | `generator.md` | No rubric yet; generate from labeled examples | +| Analyze & compare results | `analyzer.md` | User wants win rates, statistics, or metrics | + +Read the relevant sub-document **before** writing any code. + +## Install + +```bash +pip install py-openjudge +``` + +## Architecture Overview + +``` +Dataset (List[dict]) + │ + ▼ +GradingRunner ← orchestrates everything + │ + ├─► Grader A ──► EvaluationStrategy ──► _aevaluate() ──► GraderScore / GraderRank + ├─► Grader B ──► EvaluationStrategy ──► _aevaluate() ──► GraderScore / GraderRank + └─► Grader C ... + │ + ├─► Aggregator (optional) ← combine multiple grader scores into one + │ + └─► RunnerResult ← {grader_name: [GraderScore, ...]} + │ + ▼ + Analyzer ← statistics, win rates, validation metrics +``` + +## 5-Minute Quick Start + +Evaluate responses for correctness using a built-in grader: + +```python +import asyncio +from openjudge.models.openai_chat_model import OpenAIChatModel +from openjudge.graders.common.correctness import CorrectnessGrader +from openjudge.runner.grading_runner import GradingRunner + +# 1. Configure the judge model (OpenAI-compatible endpoint) +model = OpenAIChatModel( + model="qwen-plus", + api_key="sk-xxx", + base_url="https://dashscope.aliyuncs.com/compatible-mode/v1", +) + +# 2. Instantiate a grader +grader = CorrectnessGrader(model=model) + +# 3. 
Prepare dataset +dataset = [ + { + "query": "What is the capital of France?", + "response": "Paris is the capital of France.", + "reference_response": "Paris.", + }, + { + "query": "What is 2 + 2?", + "response": "The answer is five.", + "reference_response": "4.", + }, +] + +# 4. Run evaluation +async def main(): + runner = GradingRunner( + grader_configs={"correctness": grader}, + max_concurrency=8, + ) + results = await runner.arun(dataset) + + for i, result in enumerate(results["correctness"]): + print(f"[{i}] score={result.score} reason={result.reason}") + +asyncio.run(main()) +``` + +**Expected output:** +``` +[0] score=5 reason=The response accurately states Paris as capital... +[1] score=1 reason=The response gives the wrong answer (five vs 4)... +``` + +## Key Data Types + +| Type | Description | +|------|-------------| +| `GraderScore` | Pointwise result: `.score` (float), `.reason` (str), `.metadata` (dict) | +| `GraderRank` | Listwise result: `.rank` (List[int]), `.reason` (str), `.metadata` (dict) | +| `GraderError` | Error during evaluation: `.error` (str), `.reason` (str) | +| `RunnerResult` | `Dict[str, List[GraderResult]]` — keyed by grader name | + +## Result Handling Pattern + +```python +from openjudge.graders.schema import GraderScore, GraderRank, GraderError + +for grader_name, grader_results in results.items(): + for i, result in enumerate(grader_results): + if isinstance(result, GraderScore): + print(f"{grader_name}[{i}]: score={result.score}") + elif isinstance(result, GraderRank): + print(f"{grader_name}[{i}]: rank={result.rank}") + elif isinstance(result, GraderError): + print(f"{grader_name}[{i}]: ERROR — {result.error}") +``` + +## Model Configuration + +All LLM-based graders accept either a `BaseChatModel` instance or a dict config: + +```python +# Option A: instance +from openjudge.models.openai_chat_model import OpenAIChatModel +model = OpenAIChatModel(model="gpt-4o", api_key="sk-...") + +# Option B: dict (auto-creates 
OpenAIChatModel) +model_cfg = {"model": "gpt-4o", "api_key": "sk-..."} +grader = CorrectnessGrader(model=model_cfg) + +# OpenAI-compatible endpoints (DashScope / local / etc.) +model = OpenAIChatModel( + model="qwen-plus", + api_key="sk-xxx", + base_url="https://dashscope.aliyuncs.com/compatible-mode/v1", +) +``` diff --git a/ajet/copilot/openjudge/analyzer.md b/ajet/copilot/openjudge/analyzer.md new file mode 100644 index 0000000..edbe3e0 --- /dev/null +++ b/ajet/copilot/openjudge/analyzer.md @@ -0,0 +1,287 @@ +# Analyzer Reference + +Analyzers process `RunnerResult` to produce aggregated insights: +statistics, pairwise rankings, and validation metrics against ground truth. + +All analyzers follow the same interface: +```python +result = analyzer.analyze(dataset, grader_results, **kwargs) +``` + +--- + +## PairwiseAnalyzer — Model Comparison & Win Rates + +Use when evaluating multiple models head-to-head. +Computes win rates, a win matrix, and final rankings. + +### Setup + +Dataset samples must contain a `metadata` dict with `model_a` and `model_b` keys: + +```python +dataset = [ + {"metadata": {"model_a": "gpt-4o", "model_b": "qwen-max"}}, + {"metadata": {"model_a": "qwen-max", "model_b": "gpt-4o"}}, # swapped pair + ... +] +``` + +Grader results use score conventions: +- `score >= 0.5` → `model_a` wins +- `score < 0.5` → `model_b` wins + +### Example + +```python +from openjudge.analyzer.pairwise_analyzer import PairwiseAnalyzer +from openjudge.graders.llm_grader import LLMGrader +from openjudge.graders.schema import GraderMode +from openjudge.runner.grading_runner import GradingRunner + +# Build a pairwise judge grader +judge = LLMGrader( + model=model, + name="pairwise_judge", + mode=GraderMode.POINTWISE, + template=""" +You are a judge. Compare Response A and Response B for the given query. +Score 1.0 if Response A is better, 0.0 if Response B is better, 0.5 if tied. 
+ +Query: {query} +Response A: {response_a} +Response B: {response_b} + +JSON: {{"score": , "reason": ""}} +""", +) + +# Dataset: pairwise samples (typically generated with position swap for bias correction) +dataset = [ + { + "query": "What is quantum computing?", + "response_a": "GPT-4o answer...", + "response_b": "Qwen-max answer...", + "metadata": {"model_a": "gpt-4o", "model_b": "qwen-max"}, + }, + { + "query": "What is quantum computing?", + "response_a": "Qwen-max answer...", + "response_b": "GPT-4o answer...", + "metadata": {"model_a": "qwen-max", "model_b": "gpt-4o"}, # swapped + }, +] + +runner = GradingRunner(grader_configs={"judge": judge}, max_concurrency=8) +results = await runner.arun(dataset) + +# Analyze +analyzer = PairwiseAnalyzer(model_names=["gpt-4o", "qwen-max"]) +analysis = analyzer.analyze(dataset, results["judge"]) + +print(f"Best model: {analysis.best_model}") +print(f"Rankings: {analysis.rankings}") +print(f"Win rates: {analysis.win_rates}") +print(f"Win matrix: {analysis.win_matrix}") +``` + +**Result fields:** + +| Field | Type | Description | +|-------|------|-------------| +| `best_model` | str | Model with highest win rate | +| `worst_model` | str | Model with lowest win rate | +| `win_rates` | `Dict[str, float]` | Win rate per model (0.0–1.0) | +| `rankings` | `List[Tuple[str, float]]` | Sorted by win rate descending | +| `win_matrix` | `Dict[str, Dict[str, float]]` | `win_matrix[A][B]` = how often A beats B | +| `total_comparisons` | int | Total pairwise samples analyzed | + +--- + +## Statistical Analyzers + +### DistributionAnalyzer + +Computes score distribution statistics for a single grader's results. 
+ +```python +from openjudge.analyzer.statistical.distribution_analyzer import DistributionAnalyzer + +analyzer = DistributionAnalyzer() +result = analyzer.analyze(dataset, results["correctness"]) + +print(f"mean={result.mean:.3f}") +print(f"median={result.median:.3f}") +print(f"stdev={result.stdev:.3f}") +print(f"min={result.min_score} max={result.max_score}") +``` + +**Result fields:** `mean`, `median`, `stdev`, `min_score`, `max_score` + +--- + +### ConsistencyAnalyzer + +Measures how consistent a grader is across two independent runs on the same samples. +Returns Pearson correlation between the two score lists. + +```python +from openjudge.analyzer.statistical.consistency_analyzer import ConsistencyAnalyzer + +# Run the same grader twice +runner = GradingRunner(grader_configs={"correctness": grader}, max_concurrency=8) +run1 = await runner.arun(dataset) +run2 = await runner.arun(dataset) + +analyzer = ConsistencyAnalyzer() +result = analyzer.analyze( + dataset=dataset, + grader_results=run1["correctness"], + another_grader_results=run2["correctness"], +) + +print(f"Consistency (Pearson r): {result.consistency:.4f}") +# 1.0 = perfectly consistent; 0.0 = no correlation +``` + +**Result fields:** `consistency` (float, Pearson r) + +--- + +## Validation Analyzers + +Validation analyzers compare grader scores against **ground truth labels** in the dataset. + +**Prerequisite:** Each sample in `dataset` must have a label field (default key: `"label"`). + +```python +dataset = [ + {"query": "...", "response": "...", "label": 1}, # ground truth: correct + {"query": "...", "response": "...", "label": 0}, # ground truth: incorrect +] +``` + +### AccuracyAnalyzer + +Fraction of samples where `grader.score == label`. 
+ +```python +from openjudge.analyzer.validation import AccuracyAnalyzer + +analyzer = AccuracyAnalyzer() +result = analyzer.analyze(dataset, grader_results, label_path="label") +print(f"Accuracy: {result.accuracy:.2%}") +``` + +### F1ScoreAnalyzer + +Harmonic mean of precision and recall. + +```python +from openjudge.analyzer.validation import F1ScoreAnalyzer + +analyzer = F1ScoreAnalyzer() +result = analyzer.analyze(dataset, grader_results, label_path="label") +print(f"F1: {result.f1_score:.4f}") +``` + +### PrecisionAnalyzer / RecallAnalyzer + +```python +from openjudge.analyzer.validation import PrecisionAnalyzer, RecallAnalyzer + +precision_result = PrecisionAnalyzer().analyze(dataset, grader_results) +recall_result = RecallAnalyzer().analyze(dataset, grader_results) +print(f"Precision: {precision_result.precision:.4f}") +print(f"Recall: {recall_result.recall:.4f}") +``` + +### FalsePositiveAnalyzer / FalseNegativeAnalyzer + +```python +from openjudge.analyzer.validation import FalsePositiveAnalyzer, FalseNegativeAnalyzer + +fp_result = FalsePositiveAnalyzer().analyze(dataset, grader_results) +fn_result = FalseNegativeAnalyzer().analyze(dataset, grader_results) +print(f"False positive rate: {fp_result.false_positive_rate:.4f}") +print(f"False negative rate: {fn_result.false_negative_rate:.4f}") +``` + +### CorrelationAnalyzer + +Pearson/Spearman correlation between grader scores and numeric labels. 
+ +```python +from openjudge.analyzer.validation import CorrelationAnalyzer + +analyzer = CorrelationAnalyzer() +result = analyzer.analyze(dataset, grader_results, label_path="score_label") +print(f"Pearson r: {result.pearson_correlation:.4f}") +print(f"Spearman r: {result.spearman_correlation:.4f}") +``` + +--- + +## All Validation Analyzers — Summary Table + +| Analyzer | Key result field | Use when | +|----------|-----------------|----------| +| `AccuracyAnalyzer` | `.accuracy` | Binary or categorical grader vs label | +| `F1ScoreAnalyzer` | `.f1_score` | Binary classification, imbalanced labels | +| `PrecisionAnalyzer` | `.precision` | Cost of false positives is high | +| `RecallAnalyzer` | `.recall` | Cost of false negatives is high | +| `FalsePositiveAnalyzer` | `.false_positive_rate` | Measure over-flagging | +| `FalseNegativeAnalyzer` | `.false_negative_rate` | Measure under-detection | +| `CorrelationAnalyzer` | `.pearson_correlation`, `.spearman_correlation` | Continuous score calibration | + +--- + +## Complete Analysis Workflow + +```python +import asyncio +from openjudge.models.openai_chat_model import OpenAIChatModel +from openjudge.graders.common.correctness import CorrectnessGrader +from openjudge.runner.grading_runner import GradingRunner +from openjudge.analyzer.statistical.distribution_analyzer import DistributionAnalyzer +from openjudge.analyzer.validation import AccuracyAnalyzer, F1ScoreAnalyzer + +model = OpenAIChatModel(model="qwen-plus", api_key="sk-xxx", + base_url="https://dashscope.aliyuncs.com/compatible-mode/v1") + +# Dataset with ground truth labels +dataset = [ + {"query": "2+2?", "response": "4", "reference_response": "4", "label": 1}, + {"query": "2+2?", "response": "Five", "reference_response": "4", "label": 0}, + {"query": "Capital of France?", "response": "Paris", "reference_response": "Paris", "label": 1}, + {"query": "Capital of France?", "response": "London", "reference_response": "Paris", "label": 0}, +] + +async def main(): 
+ runner = GradingRunner( + grader_configs={"correctness": CorrectnessGrader(model=model)}, + max_concurrency=4, + ) + results = await runner.arun(dataset) + grader_results = results["correctness"] + + # Score distribution + dist = DistributionAnalyzer().analyze(dataset, grader_results) + print(f"Score distribution: mean={dist.mean:.2f}, stdev={dist.stdev:.2f}") + + # Validation against labels (binarize: score >= 3 → correct) + binary_results = [] + from openjudge.graders.schema import GraderScore + for r in grader_results: + if isinstance(r, GraderScore): + binary_results.append(GraderScore( + name=r.name, score=1.0 if r.score >= 3 else 0.0, reason=r.reason + )) + + acc = AccuracyAnalyzer().analyze(dataset, binary_results, label_path="label") + f1 = F1ScoreAnalyzer().analyze(dataset, binary_results, label_path="label") + print(f"Accuracy: {acc.accuracy:.2%}") + print(f"F1 Score: {f1.f1_score:.4f}") + +asyncio.run(main()) +``` diff --git a/ajet/copilot/openjudge/generator.md b/ajet/copilot/openjudge/generator.md new file mode 100644 index 0000000..d7c1360 --- /dev/null +++ b/ajet/copilot/openjudge/generator.md @@ -0,0 +1,252 @@ +# Generator Reference + +Generators automatically create `LLMGrader` instances by deriving evaluation rubrics +from data — no manual rubric writing required. + +**Use a generator when:** +- You have labeled examples (query + response + score/rank) but no rubric +- You want to adapt evaluation criteria to a specific task domain +- You need to bootstrap a grader from scratch + +--- + +## Two Generator Types + +| Generator | Input | Best for | +|-----------|-------|----------| +| `SimpleRubricsGenerator` | Task description + optional sample queries | Cold start, no labeled data needed | +| `IterativeRubricsGenerator` | Labeled dataset (query + response + score) | Better quality, learns from preference data | + +Both return a ready-to-use `LLMGrader`. 
+ +--- + +## SimpleRubricsGenerator + +Generates rubrics from a **task description** and optional sample queries. +No labeled data required — fastest way to bootstrap a grader. + +### Config parameters + +| Parameter | Type | Default | Description | +|-----------|------|---------|-------------| +| `grader_name` | str | `"Generated Grader"` | Name for the generated grader | +| `model` | BaseChatModel | required | LLM used to generate rubrics | +| `task_description` | str | `""` | What the task is about | +| `scenario` | str | None | Usage context (e.g., "customer support chatbot") | +| `grader_mode` | GraderMode | `POINTWISE` | `POINTWISE` or `LISTWISE` | +| `language` | LanguageEnum | `EN` | `EN` or `ZH` | +| `min_score` | int | `0` | Min score (pointwise mode) | +| `max_score` | int | `1` | Max score (pointwise mode) | + +### Example — pointwise grader from task description + +```python +import asyncio +from openjudge.models.openai_chat_model import OpenAIChatModel +from openjudge.generator.simple_rubric.generator import ( + SimpleRubricsGenerator, + SimpleRubricsGeneratorConfig, +) + +model = OpenAIChatModel(model="qwen-plus", api_key="sk-xxx", + base_url="https://dashscope.aliyuncs.com/compatible-mode/v1") + +config = SimpleRubricsGeneratorConfig( + grader_name="Customer Support Grader", + model=model, + task_description="Customer support chatbot for an e-commerce platform", + scenario="Customers asking about orders, returns, and shipping", + min_score=0, + max_score=1, +) + +generator = SimpleRubricsGenerator(config) + +async def main(): + # Option A: pass sample queries explicitly + grader = await generator.generate( + dataset=[], + sample_queries=[ + "Where is my order?", + "How do I return a product?", + "What is the shipping time?", + ], + ) + + # Option B: extract queries from dataset automatically (uses first 5) + dataset = [{"query": "Where is my order?", "response": "..."}] + grader = await generator.generate(dataset=dataset) + + # Use the generated 
grader + result = await grader.aevaluate( + query="How do I cancel my order?", + response="You can cancel your order within 24 hours from the order page.", + ) + print(f"score={result.score} reason={result.reason}") + +asyncio.run(main()) +``` + +### Example — listwise (ranking) grader + +```python +from openjudge.graders.schema import GraderMode + +config = SimpleRubricsGeneratorConfig( + grader_name="Response Ranker", + model=model, + task_description="Compare and rank responses to customer questions", + grader_mode=GraderMode.LISTWISE, +) +generator = SimpleRubricsGenerator(config) +grader = await generator.generate(dataset=[]) +``` + +--- + +## IterativeRubricsGenerator + +Derives rubrics from **labeled preference data** using an iterative Propose-Evaluate-Revise loop, +then selects an optimal non-redundant subset via information-theoretic MCR² selection. + +Based on the paper: *Auto-Rubric: Learning to Extract Generalizable Criteria for Reward Modeling* + +### Two config classes (choose based on mode) + +**Pointwise:** +```python +from openjudge.generator.iterative_rubric.generator import ( + IterativeRubricsGenerator, + IterativePointwiseRubricsGeneratorConfig, +) + +config = IterativePointwiseRubricsGeneratorConfig( + grader_name="My Pointwise Grader", + model=model, + min_score=0, + max_score=1, + # optional tuning: + task_description="Evaluate answers to science questions", + enable_categorization=False, + max_epochs=3, + batch_size=10, +) +``` + +**Listwise:** +```python +from openjudge.generator.iterative_rubric.generator import ( + IterativeRubricsGenerator, + IterativeListwiseRubricsGeneratorConfig, +) + +config = IterativeListwiseRubricsGeneratorConfig( + grader_name="My Listwise Grader", + model=model, +) +``` + +### Dataset format + +**Pointwise dataset** — each sample needs `query`, `response`, and optionally `label_score` (for validation): + +```python +pointwise_dataset = [ + {"query": "What causes rain?", "response": "Water vapour condenses...", 
"label_score": 1}, + {"query": "What is DNA?", "response": "DNA is a molecule...", "label_score": 1}, + {"query": "What is DNA?", "response": "I don't know.", "label_score": 0}, +] +``` + +**Listwise dataset** — each sample needs `query`, `responses` list, and optionally `label_rank` (for validation): + +```python +listwise_dataset = [ + { + "query": "Explain photosynthesis", + "responses": [ + "Plants use sunlight, CO₂, and water to produce glucose.", + "Plants need sunlight.", + ], + "label_rank": [1, 2], # 1 = best + }, +] +``` + +### Full example + +```python +import asyncio +from openjudge.generator.iterative_rubric.generator import ( + IterativeRubricsGenerator, + IterativePointwiseRubricsGeneratorConfig, +) + +config = IterativePointwiseRubricsGeneratorConfig( + grader_name="Science QA Grader", + model=model, + task_description="Evaluate factual answers to science questions", + min_score=0, + max_score=1, + max_epochs=3, + batch_size=5, +) + +generator = IterativeRubricsGenerator(config) + +async def main(): + train_data = [ + {"query": "What is gravity?", "response": "A force attracting masses.", "label_score": 1}, + {"query": "What is gravity?", "response": "Something heavy.", "label_score": 0}, + {"query": "What is entropy?", "response": "Measure of disorder.", "label_score": 1}, + {"query": "What is entropy?", "response": "A type of energy.", "label_score": 0}, + ] + + # Generate grader — may take several minutes for large datasets + grader = await generator.generate(dataset=train_data) + + # Evaluate new samples + result = await grader.aevaluate( + query="What is osmosis?", + response="Osmosis is the movement of water across a semi-permeable membrane.", + ) + print(f"score={result.score} reason={result.reason}") + +asyncio.run(main()) +``` + +### Key config parameters (IterativeRubricsGeneratorConfig) + +| Parameter | Default | Description | +|-----------|---------|-------------| +| `enable_categorization` | `False` | Merge similar rubrics via LLM 
(slower, more organised) | +| `categories_number` | `5` | Target category count (only when categorization enabled) | +| `max_epochs` | `5` | Max Propose-Evaluate-Revise iterations per sample | +| `batch_size` | `10` | Samples per batch | +| `max_total_rubrics` | `200` | Cap on total rubrics collected | +| `min_increment_threshold` | `0.002` | Convergence threshold for MCR² selection | +| `patience` | `2` | Consecutive low-increment batches before early stop | + +**Sampling mode is auto-selected:** +- `≤ 100 samples` → all_samples mode (process all concurrently) +- `> 100 samples` → smart_sampling mode (MCR²-guided batch iteration) + +--- + +## Using a Generated Grader in GradingRunner + +The returned `LLMGrader` is a standard grader — plug it directly into a runner: + +```python +from openjudge.runner.grading_runner import GradingRunner + +grader = await generator.generate(dataset=train_data) + +runner = GradingRunner( + grader_configs={"auto_rubric": grader}, + max_concurrency=8, +) +test_dataset = [{"query": "...", "response": "..."}] +results = await runner.arun(test_dataset) +``` diff --git a/ajet/copilot/openjudge/graders.md b/ajet/copilot/openjudge/graders.md new file mode 100644 index 0000000..7725e25 --- /dev/null +++ b/ajet/copilot/openjudge/graders.md @@ -0,0 +1,381 @@ +# Graders Reference + +Graders are the core evaluation units in OpenJudge. +Every grader inherits from `BaseGrader` and implements `async _aevaluate(**kwargs)`. 
+ +## Grader Types + +| Type | Class | Best for | +|------|-------|----------| +| LLM-based | `LLMGrader` | Subjective quality, semantic understanding | +| Function-based | `FunctionGrader` | Exact rules, fast deterministic checks | +| Agentic | `AgenticGrader` | Evaluation requiring tool calls (search, code run) | + +--- + +## Built-in Graders — Quick Reference + +### `common/` — General-purpose (all LLM-based, POINTWISE, score 1–5) + +| Class | Import | Key inputs | What it measures | +|-------|--------|------------|-----------------| +| `CorrectnessGrader` | `openjudge.graders.common.correctness` | `query`, `response`, `reference_response`, `context` | Factual match against reference | +| `HallucinationGrader` | `openjudge.graders.common.hallucination` | `query`, `response`, `context` | Fabricated/unsupported claims | +| `RelevanceGrader` | `openjudge.graders.common.relevance` | `query`, `response` | How relevant the response is | +| `HarmfulnessGrader` | `openjudge.graders.common.harmfulness` | `query`, `response` | Toxic or harmful content | +| `InstructionFollowingGrader` | `openjudge.graders.common.instruction_following` | `query`, `response` | Instruction compliance | +| `SearchCorrectnessGrader` | `openjudge.graders.common.search_correctness` | `query`, `response`, `context` | Correctness in RAG/search context | + +All `common/` graders accept `model` (required) and optional `threshold`, `language`, `strategy`. 
+ +```python +from openjudge.graders.common.hallucination import HallucinationGrader + +grader = HallucinationGrader(model=model) +result = await grader.aevaluate( + query="Who invented the telephone?", + response="Thomas Edison invented the telephone in 1876.", + context="Alexander Graham Bell is credited with the telephone (1876).", +) +# result.score: 1–5 (5 = no hallucination, 1 = severe hallucination) +``` + +--- + +### `text/` — String & Text Matching (no LLM needed) + +| Class | Import | Key inputs | What it measures | +|-------|--------|------------|-----------------| +| `StringMatchGrader` | `openjudge.graders.text.string_match` | `response`, `reference_response` | Exact/regex/overlap matching | +| `SimilarityGrader` | `openjudge.graders.text.similarity` | `response`, `reference` | ROUGE / BM25 / embedding similarity | +| `NumberAccuracyGrader` | `openjudge.graders.text.number_accuracy` | `response`, `reference` | Numerical value accuracy | + +**StringMatchGrader algorithms:** `exact_match`, `prefix_match`, `suffix_match`, `regex_match`, +`substring_match`, `contains_all`, `contains_any`, `word_overlap`, `char_overlap` + +> **Important:** The algorithm must be set at **init time** via the `algorithm=` constructor +> argument. Passing `algorithm` in `aevaluate()` has **no effect** — the init value is always used. 
+ +```python +from openjudge.graders.text.string_match import StringMatchGrader + +# Set algorithm at init time +grader = StringMatchGrader(algorithm="substring_match") +result = await grader.aevaluate( + response="The capital is Paris.", + reference_response="Paris", +) +# result.score: 1.0 (match) or 0.0 (no match) + +# Different algorithm — create a new grader instance +grader_overlap = StringMatchGrader(algorithm="word_overlap") +result2 = await grader_overlap.aevaluate( + response="The quick brown fox", + reference_response="quick fox", +) +# result2.score: overlap ratio (0.0–1.0) +``` + +--- + +### `code/` — Code Evaluation + +| Class | Import | Key inputs | What it measures | +|-------|--------|------------|-----------------| +| `CodeExecutionGrader` | `openjudge.graders.code.code_execution` | `response` | Test case pass rate (test cases from harness/metadata) | +| `SyntaxCheckGrader` | `openjudge.graders.code.syntax_checker` | `response` | Syntax validity | +| `CodeStyleGrader` | `openjudge.graders.code.code_style` | `response` | Style/lint quality | +| `PatchSimilarityGrader` | `openjudge.graders.code.patch_similarity` | `response`, `reference` | Patch/diff similarity | + +```python +from openjudge.graders.code.code_execution import CodeExecutionGrader + +grader = CodeExecutionGrader(timeout=10) +result = await grader.aevaluate(response="def add(a, b): return a + b") +# result.score: fraction of passed test cases (0.0–1.0). +# Test cases must be provided via sample metadata or external harness; see grader docs. +``` + +--- + +### `format/` — Output Format Validation + +| Class | Import | Key inputs | What it measures | +|-------|--------|------------|-----------------| +| `JsonValidatorGrader` | `openjudge.graders.format.json.json_validator` | `response` | Is response valid JSON? 
| +| `JsonMatchGrader` | `openjudge.graders.format.json.json_match` | `response`, `reference` | JSON structure/content match | +| `LengthPenaltyGrader` | `openjudge.graders.format.length_penalty` | `response` | Penalizes over/under-length | +| `NgramRepetitionPenaltyGrader` | `openjudge.graders.format.ngram_repetition_penalty` | `response` | Penalizes repeated n-grams | +| `ReasoningFormatGrader` | `openjudge.graders.format.reasoning_format` | `response` | `...` format check | + +```python +from openjudge.graders.format.json.json_validator import JsonValidatorGrader + +grader = JsonValidatorGrader() +result = await grader.aevaluate(response='{"key": "value"}') +# result.score: 1.0 (valid JSON) or 0.0 (invalid) +``` + +--- + +### `math/` — Mathematical Expressions + +| Class | Import | Key inputs | What it measures | +|-------|--------|------------|-----------------| +| `MathExpressionVerifyGrader` | `openjudge.graders.math.math_expression_verify` | `response`, `reference` | Mathematical equivalence | + +--- + +### `agent/` — Agent Behavior Evaluation (all LLM-based) + +| Category | Class | What it measures | +|----------|-------|-----------------| +| **Tool** | `ToolCallAccuracyGrader` | Whether tool calls are correct | +| **Tool** | `ToolCallSuccessGrader` | Whether tool calls succeeded | +| **Tool** | `ToolSelectionGrader` | Whether the right tool was chosen | +| **Tool** | `ToolParameterCheckGrader` | Correctness of tool parameters | +| **Tool** | `ToolCallStepSequenceMatchGrader` | Tool call order vs expected | +| **Tool** | `ToolCallPrecisionRecallMatchGrader` | Precision/recall of tool call set | +| **Memory** | `MemoryAccuracyGrader` | Accuracy of stored memory | +| **Memory** | `MemoryDetailPreservationGrader` | Detail retention in memory | +| **Memory** | `MemoryRetrievalEffectivenessGrader` | Quality of memory retrieval | +| **Plan** | `PlanFeasibilityGrader` | Whether the plan is feasible | +| **Reflection** | `ReflectionAccuracyGrader` | Accuracy of 
self-reflection | +| **Action** | `ActionAlignmentGrader` | Action alignment with intent | +| **Trajectory** | `TrajectoryAccuracyGrader` | Trajectory vs reference | +| **Trajectory** | `TrajectoryComprehensiveGrader` | End-to-end trajectory quality | + +```python +from openjudge.graders.agent import ToolCallAccuracyGrader + +grader = ToolCallAccuracyGrader(model=model) +result = await grader.aevaluate( + query="Search for today's weather", + tool_definitions=[{"name": "web_search", "description": "Search the web", "parameters": {}}], + tool_calls=[{"name": "web_search", "arguments": {"query": "today weather"}}], +) +# result.score: 1–5 (tool call accuracy) +``` + +--- + +### `multi_turn/` — Multi-turn Conversation (all LLM-based) + +| Class | What it measures | +|-------|-----------------| +| `ContextMemoryGrader` | Recalls details from early turns | +| `AnaphoraResolutionGrader` | Pronoun/reference resolution | +| `TopicSwitchGrader` | Handles sudden topic changes | +| `SelfCorrectionGrader` | Corrects errors when given feedback | +| `InstructionClarificationGrader` | Asks for clarification when needed | +| `ProactiveInteractionGrader` | Proactively engages in conversation | +| `ResponseRepetitionGrader` | Avoids repeating prior content | + +```python +from openjudge.graders.multi_turn import ContextMemoryGrader + +grader = ContextMemoryGrader(model=model) +result = await grader.aevaluate( + history=[ + {"role": "user", "content": "My name is Alice."}, + {"role": "assistant", "content": "Nice to meet you, Alice!"}, + {"role": "user", "content": "What's my name?"}, + ], + response="Your name is Alice.", +) +``` + +--- + +### `multimodal/` — Vision & Image (requires VL model) + +| Class | Import | What it measures | +|-------|--------|-----------------| +| `TextToImageGrader` | `openjudge.graders.multimodal.text_to_image` | Text-image alignment | +| `ImageCoherenceGrader` | `openjudge.graders.multimodal.image_coherence` | Image sequence coherence | +| 
`ImageHelpfulnessGrader` | `openjudge.graders.multimodal.image_helpfulness` | Image usefulness for context | + +```python +from openjudge.models.qwen_vl_model import QwenVLModel +from openjudge.models.schema.qwen.mllmImage import MLLMImage +from openjudge.graders.multimodal.text_to_image import TextToImageGrader + +vl_model = QwenVLModel(model="qwen-vl-plus", api_key="sk-xxx") +grader = TextToImageGrader(model=vl_model) +result = await grader.aevaluate( + query="A red apple on a wooden table", + response=MLLMImage(url="https://example.com/image.jpg"), +) +``` + +--- + +## LLMGrader — Custom Prompt Grader + +Use `LLMGrader` directly when no built-in grader fits. Provide a template string with +`{placeholder}` variables that match your `aevaluate()` kwargs. + +```python +from openjudge.graders.llm_grader import LLMGrader +from openjudge.graders.schema import GraderMode + +grader = LLMGrader( + model=model, + name="helpfulness", + mode=GraderMode.POINTWISE, + template=""" +You are an evaluation assistant. + +Query: {query} +Response: {response} + +Rate the helpfulness of the response on a scale of 0.0 to 1.0. +Respond in JSON: {{"score": <score>, "reason": "<reason>"}} +""", +) + +result = await grader.aevaluate( + query="How do I reverse a list in Python?", + response="Use list.reverse() or reversed().", +) +# result.score, result.reason +``` + +### Listwise (ranking) mode + +```python +ranking_grader = LLMGrader( + model=model, + name="quality_rank", + mode=GraderMode.LISTWISE, + template=""" +Rank the following responses to the query from best (1) to worst. + +Query: {query} +Response 1: {response_1} +Response 2: {response_2} + +Respond in JSON: {{"rank": [<rank_of_response_1>, <rank_of_response_2>], "reason": "<reason>"}} +""", +) + +result = await ranking_grader.aevaluate( + query="Explain gravity", + response_1="Gravity is a fundamental force...", + response_2="Things fall down.", +) +# result.rank e.g. 
[1, 2] → response_1 is better +``` + +--- + +## FunctionGrader — Pure Python Evaluation + +Use when the scoring logic is deterministic and requires no LLM. + +```python +from functools import partial +from openjudge.graders.function_grader import FunctionGrader +from openjudge.graders.schema import GraderScore, GraderMode + +def length_check(response: str, min_words: int = 10) -> GraderScore: + word_count = len(response.split()) + score = 1.0 if word_count >= min_words else word_count / min_words + return GraderScore( + name="length_check", + score=score, + reason=f"Response has {word_count} words (min: {min_words})", + ) + +# Option A: use functools.partial to bake in extra params +grader = FunctionGrader( + func=partial(length_check, min_words=20), + name="length_check", + mode=GraderMode.POINTWISE, +) +result = await grader.aevaluate(response="Short answer.") + +# Option B: pass extra params directly in aevaluate() +grader2 = FunctionGrader(func=length_check, name="length_check", mode=GraderMode.POINTWISE) +result2 = await grader2.aevaluate(response="Short answer.", min_words=20) +``` + +> **Note:** Extra `**kwargs` passed to `FunctionGrader(...)` at construction time are stored in `grader.kwargs` but are **not** automatically forwarded to `func`. Use `functools.partial` (Option A) or pass them directly to `aevaluate()` (Option B). + +### Decorator syntax + +```python +@FunctionGrader.wrap +def exact_match(response: str, reference: str) -> GraderScore: + score = 1.0 if response.strip() == reference.strip() else 0.0 + return GraderScore(name="exact_match", score=score, reason="") + +grader = exact_match(name="exact_match", mode=GraderMode.POINTWISE) +``` + +--- + +## AgenticGrader — Tool-augmented Evaluation + +Use when the evaluation itself requires external tools (e.g., web search to verify facts). 
+ +```python +from openjudge.agentic import ReActAgent +from openjudge.graders.agentic_grader import AgenticGrader + +# Step 1: build agent with tools +agent = ReActAgent( + model={"model": "gpt-4o", "api_key": "sk-..."}, + tools=[WebSearchTool()], # any BaseTool implementation + max_iterations=10, +) + +# Step 2: create grader +grader = AgenticGrader( + agent=agent, + name="fact_check", + template=""" +Verify the factual accuracy of the response using web search if needed. + +Query: {query} +Response: {response} + +Return JSON: {{"score": <0.0-1.0>, "reason": ""}} +""", +) + +result = await grader.aevaluate( + query="When was Python first released?", + response="Python was first released in 1991.", +) +``` + +--- + +## Custom Grader — Extend BaseGrader + +```python +from openjudge.graders.base_grader import BaseGrader +from openjudge.graders.schema import GraderMode, GraderScore + +class KeywordGrader(BaseGrader): + def __init__(self, keywords: list[str], **kwargs): + super().__init__(name="keyword_grader", mode=GraderMode.POINTWISE, **kwargs) + self.keywords = keywords + + async def _aevaluate(self, response: str, **kwargs) -> GraderScore: + hits = sum(1 for kw in self.keywords if kw.lower() in response.lower()) + score = hits / len(self.keywords) + return GraderScore( + name=self.name, + score=score, + reason=f"{hits}/{len(self.keywords)} keywords found", + ) + + @staticmethod + def get_metadata(): + return {"description": "Checks keyword presence in response"} + +grader = KeywordGrader(keywords=["Python", "list", "reverse"]) +result = await grader.aevaluate(response="Use list.reverse() in Python.") +``` diff --git a/ajet/copilot/openjudge/pipeline.md b/ajet/copilot/openjudge/pipeline.md new file mode 100644 index 0000000..e980edb --- /dev/null +++ b/ajet/copilot/openjudge/pipeline.md @@ -0,0 +1,307 @@ +# Pipeline Reference + +The pipeline layer handles batch evaluation: running graders over datasets, +controlling concurrency, combining multiple grader scores, 
and stabilizing +noisy LLM evaluations. + +--- + +## GradingRunner + +`GradingRunner` is the main entry point for batch evaluation. +It runs all configured graders over a dataset concurrently. + +### Constructor + +```python +from openjudge.runner.grading_runner import GradingRunner, GraderConfig + +runner = GradingRunner( + grader_configs, # Dict[str, grader | (grader, mapper) | GraderConfig] + max_concurrency=32, # max parallel API calls + aggregators=None, # optional aggregator(s) + show_progress=True, # tqdm progress bar + executor=None, # custom resource executor (rarely needed) +) +``` + +### Running evaluation + +```python +# Single dataset +results = await runner.arun(dataset) # RunnerResult + +# Multiple datasets (shared concurrency pool) +all_results = await runner.arun_multiple_datasets([dataset_a, dataset_b]) +``` + +### Result structure + +``` +RunnerResult = Dict[str, List[GraderResult]] + +{ + "grader_a": [GraderScore(...), GraderScore(...), GraderError(...)], + "grader_b": [GraderScore(...), GraderScore(...), GraderScore(...)], +} +``` + +Each list is indexed the same as the input `dataset` list. + +--- + +## GraderConfig — Input Formats + +`grader_configs` accepts four equivalent formats: + +```python +from openjudge.runner.grading_runner import GraderConfig + +# Format 1: bare grader instance (most common) +configs = {"correctness": CorrectnessGrader(model=model)} + +# Format 2: tuple (grader, mapper) +configs = {"correctness": (CorrectnessGrader(model=model), {"query": "q", "response": "a"})} + +# Format 3: GraderConfig object +configs = {"correctness": GraderConfig(grader=CorrectnessGrader(model=model), mapper=...)} + +# Format 4: dict +configs = {"correctness": {"grader": CorrectnessGrader(model=model), "mapper": None}} +``` + +--- + +## Mapper — Field Name Translation + +Use a mapper when your dataset field names differ from what the grader expects. 
+ +### Dict mapper (field rename) + +Mapping: **key = grader kwarg name**, **value = path in dataset** to read from. + +```python +# Dataset has "question" / "answer" but grader expects "query" / "response" +configs = { + "correctness": GraderConfig( + grader=CorrectnessGrader(model=model), + mapper={"query": "question", "response": "answer"}, + # grader kwarg → dataset key + ) +} +``` + +### Callable mapper (full transformation) + +```python +def my_mapper(sample: dict) -> dict: + return { + "query": sample["input"], + "response": sample["output"], + "reference_response": sample.get("gold", ""), + "context": " ".join(sample.get("docs", [])), + } + +configs = { + "correctness": GraderConfig(grader=CorrectnessGrader(model=model), mapper=my_mapper) +} +``` + +--- + +## Multiple Graders in One Run + +Run multiple graders over the same dataset in one pass: + +```python +from openjudge.graders.common.correctness import CorrectnessGrader +from openjudge.graders.common.relevance import RelevanceGrader +from openjudge.graders.common.hallucination import HallucinationGrader + +runner = GradingRunner( + grader_configs={ + "correctness": CorrectnessGrader(model=model), + "relevance": RelevanceGrader(model=model), + "hallucination": HallucinationGrader(model=model), + }, + max_concurrency=16, +) + +results = await runner.arun(dataset) +# results["correctness"][i], results["relevance"][i], results["hallucination"][i] +``` + +--- + +## WeightedSumAggregator — Combine Multiple Scores + +Produce a single composite score from multiple graders per sample. 
+ +```python +from openjudge.runner.aggregator.weighted_sum_aggregator import WeightedSumAggregator + +aggregator = WeightedSumAggregator( + name="overall", + weights={ + "correctness": 0.5, + "relevance": 0.3, + "hallucination": 0.2, + }, +) + +runner = GradingRunner( + grader_configs={ + "correctness": CorrectnessGrader(model=model), + "relevance": RelevanceGrader(model=model), + "hallucination": HallucinationGrader(model=model), + }, + aggregators=aggregator, +) + +results = await runner.arun(dataset) +# results["overall"][i] ← WeightedSumAggregator result (GraderScore) +# results["correctness"][i], results["relevance"][i], ... ← individual scores +``` + +**Notes:** +- If `weights` is omitted, equal weights are used automatically. +- `GraderError` and `GraderRank` results are skipped in the weighted sum. +- Multiple aggregators can be passed as a list. + +### Custom aggregator + +```python +from openjudge.runner.aggregator.base_aggregator import BaseAggregator +from openjudge.graders.schema import GraderResult, GraderScore + +class MinScoreAggregator(BaseAggregator): + """Returns the minimum score across all graders.""" + + def __call__(self, grader_results: dict[str, GraderResult], **kwargs) -> GraderResult: + scores = [r.score for r in grader_results.values() if isinstance(r, GraderScore)] + if not scores: + return GraderScore(name=self.name, score=0.0, reason="No valid scores") + return GraderScore( + name=self.name, + score=min(scores), + reason=f"Min of {len(scores)} grader scores", + ) + +aggregator = MinScoreAggregator(name="min_score") +``` + +--- + +## Evaluation Strategies — Reduce LLM Noise + +Attach a strategy to any grader to call it multiple times and aggregate. + +### VotingEvaluationStrategy + +Run N times, return the most frequent score. Best for discrete scores (1–5). 
+ +```python +from openjudge.evaluation_strategy import VotingEvaluationStrategy, MIN + +strategy = VotingEvaluationStrategy( + num_votes=5, # must be ≥ 2; odd numbers avoid ties + tie_breaker=MIN, # MIN | MAX | CLOSEST_TO_MEAN | custom callable +) + +grader = CorrectnessGrader(model=model, strategy=strategy) +``` + +### AverageEvaluationStrategy + +Run N times, return the mean score. Best for continuous scores. + +```python +from openjudge.evaluation_strategy import AverageEvaluationStrategy + +strategy = AverageEvaluationStrategy(num_evaluations=3) +grader = RelevanceGrader(model=model, strategy=strategy) +``` + +### DirectEvaluationStrategy (default) + +Call once, return result as-is. This is the default when no strategy is specified. + +```python +from openjudge.evaluation_strategy import DirectEvaluationStrategy + +grader = CorrectnessGrader(model=model, strategy=DirectEvaluationStrategy()) +``` + +--- + +## Concurrency Control + +`max_concurrency` limits simultaneous LLM API calls across all graders and samples. + +```python +runner = GradingRunner( + grader_configs={"correctness": grader}, + max_concurrency=8, # conservative for rate-limited APIs +) +``` + +The underlying `SemaphoreResourceExecutor` ensures the total number of in-flight +requests never exceeds `max_concurrency`, regardless of dataset size or number of graders. 
+ +--- + +## Complete Pipeline Example + +```python +import asyncio +from openjudge.models.openai_chat_model import OpenAIChatModel +from openjudge.graders.common.correctness import CorrectnessGrader +from openjudge.graders.common.relevance import RelevanceGrader +from openjudge.graders.text.string_match import StringMatchGrader +from openjudge.runner.grading_runner import GradingRunner, GraderConfig +from openjudge.runner.aggregator.weighted_sum_aggregator import WeightedSumAggregator +from openjudge.evaluation_strategy import VotingEvaluationStrategy +from openjudge.graders.schema import GraderScore, GraderError + +model = OpenAIChatModel(model="qwen-plus", api_key="sk-xxx", + base_url="https://dashscope.aliyuncs.com/compatible-mode/v1") + +# Voting strategy for LLM-based graders +voting = VotingEvaluationStrategy(num_votes=3) + +dataset = [ + { + "query": "What is the capital of France?", + "response": "Paris", + "reference": "Paris", + "reference_response": "The capital of France is Paris.", + }, +] + +runner = GradingRunner( + grader_configs={ + "correctness": CorrectnessGrader(model=model, strategy=voting), + "relevance": RelevanceGrader(model=model, strategy=voting), + "exact_match": GraderConfig( + grader=StringMatchGrader(), + mapper={"response": "response", "reference_response": "reference"}, + ), + }, + aggregators=WeightedSumAggregator( + name="overall", + weights={"correctness": 0.5, "relevance": 0.3, "exact_match": 0.2}, + ), + max_concurrency=8, +) + +async def main(): + results = await runner.arun(dataset) + for grader_name, grader_results in results.items(): + for i, result in enumerate(grader_results): + if isinstance(result, GraderScore): + print(f"[{grader_name}][{i}] score={result.score:.3f}") + elif isinstance(result, GraderError): + print(f"[{grader_name}][{i}] ERROR: {result.error}") + +asyncio.run(main()) +``` diff --git a/ajet/copilot/train-complex-blackbox/SKILL.md b/ajet/copilot/train-complex-blackbox/SKILL.md index 032eb47..c33b527 
100644 --- a/ajet/copilot/train-complex-blackbox/SKILL.md +++ b/ajet/copilot/train-complex-blackbox/SKILL.md @@ -55,7 +55,7 @@ from ajet.task_reader import RouterTaskReader from ajet.utils.thread_executors import PeriodicDrainThreadPoolExecutor from ajet.tuner_lib.as_oai_baseurl_apikey import OpenaiBaseUrlAndApiKey from ajet.default_config.ajet_default import AjetTaskReader, HuggingfaceDatRepo -from ajet.tuner_lib.experimental.as_swarm_client import SwarmClient +from ajet.tuner_lib.experimental.swarm_client import SwarmClient # python -m tutorial.example_math_swarm.math diff --git a/ajet/copilot/write-passive-swarm-client/SKILL.md new file mode 100644 index 0000000..ba23baf --- /dev/null +++ b/ajet/copilot/write-passive-swarm-client/SKILL.md @@ -0,0 +1,194 @@ +--- +name: write-passive-swarm-client +description: Create a passive swarm client that waits for user input instead of iterating through a dataset by itself. +license: Complete terms in LICENSE.txt +--- + + +## Introduction + +Your task is to connect to an external agent and enable dynamic model tuning under the AgentJet reinforcement learning framework. This is very simple.
 + + +```txt +User --> Application Interface (WEB, TUI, GUI) --> Application Backend --> Fake vLLM (fake_vllm_endpoint.py, you need to write this) --> In fake vLLM, duplicate each request multiple times (on_user_submit_new_requests) --> Calculate relative reward (on_compute_relative_reward) --> Submit reward (swarm_client.end_episode) --> Select the item with highest reward --> Return to user via original path + +用户 --> Application界面(WEB, TUI, GUI) --> Application后端 --> 假vLLM(fake_vllm_endpoint.py, 你要写这个) --> 假vLLM中, 将每个请求复制多份(on_user_submit_new_requests) --> 计算相对奖励(on_compute_relative_reward) --> 提交奖励 (swarm_client.end_episode) --> 选取奖励最高的一项 --> 原路返回给用户 +``` + + +First, give the agent system a name based on the user's requirements, and always place your code at `tutorial/opencode_build_***`, for example `opencode_build_openclaw_agent`. + +Next, create the directory: +`tutorial/opencode_build_openclaw_agent` + +Then, create the Agent source files: +- `tutorial/opencode_build_openclaw_agent/fake_vllm_endpoint.py` (Use `ajet/tuner_lib/experimental/oai_model_one2many.py` as a template. There aren't many changes — the key is to ask the user for the necessary parameters.) +- `tutorial/opencode_build_openclaw_agent/on_compute_relative_reward.py` (Place `on_compute_relative_reward`.) +- `tutorial/opencode_build_openclaw_agent/on_user_submit_new_requests.py` (Place `on_user_submit_new_requests`.) + +As an optional step, write a `tutorial/opencode_build_openclaw_agent/mock_user_request.py` to read and iterate the dataset (if any) and simulate user input automatically. For example, if the front-end application is a web chat app, try to check whether it has a CLI to simulate user input. +```txt +`mock_user_request.py` --> dataset --> query1 --> Application Interface (WEB, TUI, GUI) --> Application Backend --> Fake vLLM (fake_vllm_endpoint.py, you need to write this) -> .... (same) .... 
 --> query2 --> Application Interface (WEB, TUI, GUI) --> Application Backend --> Fake vLLM (fake_vllm_endpoint.py, you need to write this) -> .... (same) .... + --> query3 --> Application Interface (WEB, TUI, GUI) --> Application Backend --> Fake vLLM (fake_vllm_endpoint.py, you need to write this) -> .... (same) .... + --> query4 --> Application Interface (WEB, TUI, GUI) --> Application Backend --> Fake vLLM (fake_vllm_endpoint.py, you need to write this) -> .... (same) .... + +.... .... .... +``` + + +## Training and Debugging Instructions + +Overall, the user first runs `ajet-swarm start`, then runs `agent_roll.py`, and training begins. You do not need to and are not allowed to run these bash commands. +- First, help the user write `agent_run.py` and `agent_roll.py`. +- Then, write clear instructions to guide the user through training (`readme.md`). + +Your task is then complete. + +Below are some reference materials. + --- + + # Using AgentJet Swarm to Train Your Agents + + AgentJet Swarm opens infinite possibilities for both LLM Agent engineers and LLM researchers. It is very easy to use and understand. In fact, there is no need for verbose explanation — the code explains itself: + + ## (1/2) Launching a Swarm Server + + Simply run `ajet-swarm start` on a GPU server (or GPU cluster master), and we are done ✅. (You may ask: what about training config? Well, config will come from swarm client.) + + ![alt text](https://img.alicdn.com/imgextra/i4/O1CN01bm585R20h63S9NSSy_!!6000000006880-2-tps-1649-765.png) + + Notes: + 1. launch server together with a swarm monitor: + ```bash + (ajet-swarm start &> ajet-swarm-server.log) & (ajet-swarm overwatch) + ``` + + 2. overwatch swarm status with a URL: + ```bash + ajet-swarm overwatch --swarm-url=http://localhost:10086 + ``` + + 3. changing customized port (default port is 10086): + ```bash + ajet-swarm start --swarm-port=10086 + ``` + + 4. 
if you are using a multi-node cluster to train huge models, make sure you have already set up the ray cluster before you hit `ajet-swarm start`. + + + + The swarm server can be in the following states and transition between them as follows: + - **OFFLINE**: The swarm server is started but does not load any models or perform any training. It enters this state directly after startup. Additionally, it transitions to this state upon receiving a `stop_engine` command from (any) client while in any other state. + - **BOOTING**: The swarm server enters this state upon receiving a configuration followed by an explicit `begin_engine` command. In this state, it loads model parameters, initializes FSDP, and initializes vLLM. + - **ROLLING**: The swarm server enters this state automatically after completing **BOOTING** or after finishing the **WEIGHT_SYNCING** state. This represents the sampling phase. + - **ROLLING_POST**: When the swarm server determines that the sample pool is sufficient for proceeding to the next policy gradient step, it automatically transitions to this state. While in this state, ongoing episodes can still complete normally, but no new episodes can begin. + - **WEIGHT_SYNCING**: After being in the **ROLLING_POST** state, once all computational resources and threads related to ongoing episodes are reclaimed and cleaned up, the swarm server transitions to this state. During this stage, VERL completes the current policy gradient strategy update and then returns to the **ROLLING** state, repeating the cycle. + + + + ## (2/2) Launching Swarm Clients + + You can run any number of swarm clients: + - on any device (MacBook, workstation, the same machine you run swarm-server, **wherever you want**). + - at any time (before or in the middle of training, **whenever you want**) + + But just remember: **ALL** swarm clients are equally authorized to order swarm server(s) to **start or terminate** the training process. 
There is **no such role as a Queen** in AgentJet Swarm. + + ### 2-1. Connecting to a swarm server and making it rock & roll + + The primary objective of the swarm client is to make sure the network connection is good. + Now, create a Python script and start coding: + + ```python + from ajet.tuner_lib.experimental.swarm_client import SwarmClient + REMOTE_SWARM_URL = "http://localhost:10086" # Change to your swarm remote url + swarm_worker = SwarmClient(REMOTE_SWARM_URL) + ``` + + Secondly, generate a configuration (basically VERL yaml, but slightly different), **connect** to the swarm server and then tell the swarm server **which model to train**, etc. When the configuration is ready, tell the engine to read the yaml and begin VERL training cycles with `auto_sync_train_config_and_start_engine`. + + ```python + LOCAL_GRPO_N = 32 + yaml_job = AgentJetJob( + experiment_name="math_gsm8k_grpo", + algorithm="grpo", + n_gpu=4, + model='/mnt/data_cpfs/model_cache/modelscope/hub/Qwen/Qwen/Qwen2.5-3B-Instruct', + batch_size=LOCAL_GRPO_N, + num_repeat=4, + ) + # hint: you can `yaml_job.dump_job_as_yaml('./config.yaml')` to take a look at the full configuration + # hint: you can `yaml_job.build_job_from_yaml('./config.yaml')` to load yaml configuration as override. (there are some configurations that must be edited from yaml) + swarm_worker.auto_sync_train_config_and_start_engine(yaml_job) + ``` + + An “episode” is the atomic unit of rollout work exchanged between client and server. The client does not “create” episodes; it **claims** episodes that the server has already registered (internally created by the training engine / runners). + + At a high level, one episode looks like: + + 1) **Claim** an episode (`begin_episode`) + - The client calls `begin_episode(...)`, which blocks/retries until the server is in **ENGINE.ROLLING** and there is an available episode.
+ - On success, you receive: + - `episode_uuid`: the episode identifier + - `OpenaiBaseUrlAndApiKey(base_url, api_key, episode_uuid)`: credentials for OpenAI-compatible requests + + 2) **Run your agent** (your code) + - Use `base_url` + `api_key` for all LLM calls during this episode. + - This matters: the server uses these credentials to route your requests to the correct runtime/model and to attribute the requests to the claimed `episode_uuid`. + + 3) **Finish** the episode (`end_episode`) or **discard** it (`abort_episode`) + - `end_episode(task, episode_uuid, workflow_output)` sends reward + metadata back to the server. + - `abort_episode(episode_uuid)` tells the server to drop this episode result and clean up. + + Minimal safe skeleton (always abort on exceptions): + + ```python + from ajet.schema.task import WorkflowOutput, Task + + episode_uuid, api = swarm_client.begin_episode( + discard_episode_timeout=600, + episode_type="train", + ) + try: + workflow_output: WorkflowOutput = execute_agent(task, api) # workflow_output contains reward + swarm_client.end_episode(task, episode_uuid, workflow_output) + return workflow_output.reward + except Exception: + swarm_client.abort_episode(episode_uuid) + raise + ``` + + WARNING: the `base_url` + `api_key` returned by `begin_episode` must be used for this specific episode, you must always ensure different episodes use their own `base_url` + `api_key`. + + Abort semantics (why it is safe for debugging): + + - When the server is **ENGINE.ROLLING**, `abort_episode` typically **reverts** the episode back to the unclaimed pool, so other clients can pick it up. + - When the server is in **ENGINE.ROLLING_POST**, `abort_episode` will **delete** the episode record instead of re-queueing it, so weight syncing won’t be blocked by zombie episodes. + + Timeouts you should understand: + + - `discard_episode_timeout` (server-side): if an episode is **idle** (no LLM requests) for too long, the server can discard it. 
+ - Client-side protection: the client records an internal max lifetime (currently `max_episode_time = 2 × discard_episode_timeout`). If you submit too late, `end_episode` will be converted into an `abort_episode` to avoid poisoning the pool. + + For very long or complex agents, consider periodically checking: + + ```python + if not swarm_client.can_continue_episode(episode_uuid): + # The server no longer considers this episode valid. + swarm_client.abort_episode(episode_uuid) + return None + ``` + + + One important thing to note is that before each episode begins, you need to call `begin_episode` to obtain the `base_url` and `api_key`. At the same time, you will receive an episode identifier, `episode_uuid`. The `swarm_worker` is thread-safe and does not hold the state of the `episode`, so you can safely invoke multiple `begin_episode` calls concurrently. When your agent finishes running, remember to call `end_episode` to send the reward signal back to the swarm server (with the `episode_uuid` parameter). Additionally, if you wish to discard an episode for reasons such as: + + - **Reward miscalculation** + - **External API out of credit** + - **Debugging** + - **Evaluation testing** + - **Mid-training, checking the training results with a test case** + - **An unexpected issue arises and this episode needs to be filtered** + + it’s simple: just replace `end_episode` with `abort_episode`. diff --git a/ajet/copilot/write-swarm-client/SKILL.md b/ajet/copilot/write-swarm-client/SKILL.md index 4fe28c5..3fd6fc4 100644 --- a/ajet/copilot/write-swarm-client/SKILL.md +++ b/ajet/copilot/write-swarm-client/SKILL.md @@ -136,7 +136,7 @@ Below are some reference materials. 
Now, create a python script and start coding: ```python - from ajet.tuner_lib.experimental.as_swarm_client import SwarmClient + from ajet.tuner_lib.experimental.swarm_client import SwarmClient REMOTE_SWARM_URL = "http://localhost:10086" # Change to your swarm remote url swarm_worker = SwarmClient(REMOTE_SWARM_URL) ``` @@ -364,7 +364,7 @@ Below are some reference materials. ```python from ajet.copilot.job import AgentJetJob - from ajet.tuner_lib.experimental.as_swarm_client import SwarmClient, run_episodes_until_all_complete + from ajet.tuner_lib.experimental.swarm_client import SwarmClient, run_episodes_until_all_complete from ajet.default_config.ajet_default import AjetTaskReader, HuggingfaceDatRepo from ajet.task_reader import RouterTaskReader from tutorial.example_academic_trans_swarm.trans import execute_agent diff --git a/ajet/default_config/verl/verl_default.yaml b/ajet/default_config/verl/verl_default.yaml index 23ca32b..7684747 100644 --- a/ajet/default_config/verl/verl_default.yaml +++ b/ajet/default_config/verl/verl_default.yaml @@ -256,7 +256,6 @@ critic: override_config: {} path: ~/models/deepseek-llm-7b-chat target_modules: all-linear - tokenizer_path: /mnt/data_cpfs/model_cache/modelscope/hub/Qwen/Qwen/Qwen2___5-14B-Instruct trust_remote_code: false use_remove_padding: false use_shm: false diff --git a/ajet/launcher.py b/ajet/launcher.py index d3a5b4c..71d8683 100644 --- a/ajet/launcher.py +++ b/ajet/launcher.py @@ -153,7 +153,7 @@ def start_swarm_server(env, config): assert config.ajet.enable_interchange_server, ( "Please enable_interchange_server in config to start swarm server." 
) - from ajet.tuner_lib.experimental.as_oai_model_server import ( + from ajet.tuner_lib.experimental.oai_model_server import ( start_interchange_server, ) diff --git a/ajet/schema/convertion.py b/ajet/schema/convertion.py index 408bbcd..d0d384e 100644 --- a/ajet/schema/convertion.py +++ b/ajet/schema/convertion.py @@ -3,7 +3,6 @@ from openai.types.chat.chat_completion import ChatCompletion, Choice from openai.types.chat.chat_completion_message import ChatCompletionMessage from agentscope.model import ChatResponse as AgentScopeChatResponse -from openai.types.completion_usage import CompletionUsage from typing import List, Type from agentscope.message import TextBlock, ToolUseBlock from agentscope._utils._common import _json_loads_with_repair @@ -23,18 +22,11 @@ def convert_llm_proxy_response_to_oai_response(llm_proxy_response): choice = Choice( index=0, message=message, - finish_reason="stop", + finish_reason=llm_proxy_response.get("finish_reason", "stop"), ) # Calculate token usage if tokens are available - usage = None - if "tokens" in llm_proxy_response and llm_proxy_response["tokens"]: - completion_tokens = len(llm_proxy_response["tokens"]) - usage = CompletionUsage( - prompt_tokens=0, # Not available in llm_proxy_response - completion_tokens=completion_tokens, - total_tokens=completion_tokens, - ) + usage = llm_proxy_response.get("usage", None) return ChatCompletion( id=llm_proxy_response.get("request_id", "chatcmpl-default"), diff --git a/ajet/schema/extended_msg.py b/ajet/schema/extended_msg.py index 5e78907..3f31418 100644 --- a/ajet/schema/extended_msg.py +++ b/ajet/schema/extended_msg.py @@ -65,8 +65,6 @@ def __init__( token_arr=[], token_begin_index=-1, token_end_index=-1, - clip=False, - clip_token_limit=8192, tokenizer: PreTrainedTokenizer = None, # type: ignore token_generator="manual", build_from_uuid="", @@ -85,9 +83,8 @@ def __init__( self.token_begin_index = token_begin_index self.token_end_index = token_end_index self.invalid_log_prob_value = 
INVALID_LOG_PROB_VALUE - self._content_for_future = "" + self._content_for_compare = "" self._info = "" - self.clip = clip self.tools = tools self.tool_calls = tool_calls self.tool_call_id = tool_call_id @@ -101,14 +98,8 @@ def __init__( self.manual_loss_mask_override = [] self.lack_normal_eos = False - if not clip: - self.generate_content_for_future(tokenizer=None, clip=False) - else: - self.generate_content_for_future( - tokenizer=tokenizer, - clip=True, - clip_token_limit=clip_token_limit, - ) + self.generate_content_for_compare(tokenizer=None) + self.eos_token_id = tokenizer.eos_token_id if token_generator == "auto": @@ -127,9 +118,9 @@ def auto_tokenize(self, tokenizer, tools): if not self.first_message: self.token_arr = self.auto_tokenize_non_first_message(tokenizer=tokenizer, tools=tools) else: - auto_tokenize_target = { + auto_tokenize_target:dict = { "role": self.role, - "content": self.content_for_future, + "content": self.content_for_compare, } if self.tool_calls: auto_tokenize_target.update({"tool_calls": self.tool_calls}) @@ -144,9 +135,9 @@ def auto_tokenize(self, tokenizer, tools): def auto_tokenize_non_first_message(self, tokenizer, tools): try: # completion_token_arr will contain generation_prompt header - auto_tokenize_target = { + auto_tokenize_target:dict = { "role": self.role, - "content": self.content_for_future, + "content": self.content_for_compare, } if self.tool_calls: auto_tokenize_target.update({"tool_calls": self.tool_calls}) @@ -160,7 +151,7 @@ def auto_tokenize_non_first_message(self, tokenizer, tools): ) except Exception as e: raise ValueError( - f"Cannot tokenize {self.role} --- {self.content_for_future}, \n\n Error: {e}" + f"Cannot tokenize {self.role} --- {self.content_for_compare}, \n\n Error: {e}" ) self.token_arr, _ = self.get_inc_simple( text_frag_from=ajet_apply_chat_template( @@ -175,12 +166,12 @@ def auto_tokenize_non_first_message(self, tokenizer, tools): return self.token_arr @property - def content_for_future(self): - if 
self._content_for_future == "": + def content_for_compare(self): + if self._content_for_compare == "": if not self.tool_calls: - logger.exception("content_for_future is not set, or previous llm output is empty!") - self._content_for_future - return self._content_for_future + logger.exception("content_for_compare is not set, or previous llm output is empty!") + self._content_for_compare + return self._content_for_compare @property def need_training(self): @@ -191,19 +182,9 @@ def need_training(self): ), f"author {self.author} is not identified" return self.author in NEED_TRAIN_AUTHORS - def generate_content_for_future(self, tokenizer, clip, clip_token_limit=-1): + def generate_content_for_compare(self, tokenizer): _content: str = self.content - if clip: - assert clip_token_limit > 0, "clip_token_limit must be set when clip is True" - n_token = len(tokenizer(_content, return_tensors="pt", padding=False)["input_ids"][0]) - if n_token > clip_token_limit: - # 8000 > 4000 - n_char = len(_content) # 10,000 - eps = 100 # token - preserve_percent = (clip_token_limit - eps) / n_token # 3900 / 8000 - n_char_to_preserve = int(n_char * preserve_percent) - _content = _content[:n_char_to_preserve] + "... truncate ..." 
- self._content_for_future = _content + self._content_for_compare = _content def get_loss_mask(self, blackout_token_combo): if self.need_training: @@ -315,7 +296,7 @@ def merge_tool_group(group, tokenizer): ) # re-compute token_arr auto_tokenize_targets = [ - {"role": msg.role, "content": msg.content_for_future} for msg in group + {"role": msg.role, "content": msg.content_for_compare} for msg in group ] merged.token_arr, _ = merged.get_inc_simple( text_frag_from=ajet_apply_chat_template( diff --git a/ajet/swarm_cli.py b/ajet/swarm_cli.py index 723d9b5..7f65e9f 100644 --- a/ajet/swarm_cli.py +++ b/ajet/swarm_cli.py @@ -31,7 +31,7 @@ def start_swarm_server(env, config, port): # Set the port in the config config.ajet.interchange_server.interchange_server_port = port - from ajet.tuner_lib.experimental.as_oai_model_server import ( + from ajet.tuner_lib.experimental.oai_model_server import ( start_interchange_server, ) @@ -139,6 +139,24 @@ def main(): ) parser_overwatch.set_defaults(func=cmd_overwatch) + # Subcommand: top (alias for overwatch) + parser_top = subparsers.add_parser("top", help="Monitor the swarm server (alias for overwatch)") + parser_top.add_argument( + "--swarm-url", + type=str, + default="http://localhost:10086", + required=False, + help="Swarm server URL (default: http://localhost:10086)", + ) + parser_top.add_argument( + "--refresh-interval", + type=float, + default=2.0, + required=False, + help="Refresh interval in seconds (default: 2.0)", + ) + parser_top.set_defaults(func=cmd_overwatch) + args = parser.parse_args() if not hasattr(args, 'func'): diff --git a/ajet/task_rollout/async_llm_bridge.py b/ajet/task_rollout/async_llm_bridge.py index ced9cf1..9aae129 100644 --- a/ajet/task_rollout/async_llm_bridge.py +++ b/ajet/task_rollout/async_llm_bridge.py @@ -84,11 +84,11 @@ async def llm_chat_verl( add_generation_prompt=True, tokenize=False, ) - prompt_ids = self.tokenizer(prompt_text)["input_ids"] + prompt_token_ids = 
self.tokenizer(prompt_text)["input_ids"] final_res = await self.async_rollout_manager.generate( request_id=request_id, - prompt_ids=prompt_ids, + prompt_ids=prompt_token_ids, sampling_params=updated_sampling_params, ) @@ -132,11 +132,27 @@ async def llm_chat_verl( if decoded_text is None: decoded_text = "" + max_response_length_in_one_turn = self.config.ajet.rollout.max_response_length_in_one_turn + max_model_len: int = self.config.ajet.rollout.max_model_len + max_seq_length: int = max_model_len - max_response_length_in_one_turn + if len(prompt_token_ids) >= max_seq_length: + finish_reason = "length" + else: + finish_reason = "stop" + if tool_calls: + finish_reason = "tool_calls" + usage = { + "prompt_tokens": len(prompt_token_ids), + "completion_tokens": len(token_array), # type: ignore + "total_tokens": len(prompt_token_ids) + len(token_array), # type: ignore + } return { "role": "assistant", "request_id": request_id, "content": decoded_text, "tool_calls": tool_calls, + "finish_reason": finish_reason, + "usage": usage, "tokens": [ TokenAndProb( token_id=token_id, @@ -223,13 +239,30 @@ async def main(): # logger.bind(exception=True).exception(f"Bad toolcall discovered \n\nprompt_text:\n{prompt_text}\n\nrepsonse:\n{content}") logger.warning(f"Bad toolcall discovered: {content}") + tool_calls = message.get("tool_calls", []) + max_response_length_in_one_turn = self.config.ajet.rollout.max_response_length_in_one_turn + max_model_len: int = self.config.ajet.rollout.max_model_len + max_seq_length: int = max_model_len - max_response_length_in_one_turn + if len(prompt_token_ids) >= max_seq_length: + finish_reason = "length" + else: + finish_reason = "stop" + if tool_calls: + finish_reason = "tool_calls" + usage = { + "prompt_tokens": len(prompt_token_ids), + "completion_tokens": len(response.choices[0].token_ids), # type: ignore + "total_tokens": len(prompt_token_ids) + len(response.choices[0].token_ids), # type: ignore + } return { "role": "assistant", "request_id": 
response.id, "content": content, "prompt_text": prompt_text, "prompt_token_ids": prompt_token_ids, - "tool_calls": message.get("tool_calls", []), + "tool_calls": tool_calls, + "finish_reason": finish_reason, + "usage": usage, "tokens": [ TokenAndProb( token_id=token, @@ -338,7 +371,7 @@ async def run_infer( if token_overflow: # ajet_action_when_overflow = self.config.ajet.rollout.ajet_action_when_overflow # cannot proceed due to context overflow - return self.construct_overflow_response() + return self.construct_overflow_response(info) # else: # otherwise, for abnormal output, can still proceed, but we do not track output anymore @@ -350,12 +383,13 @@ async def run_infer( return llm_output - def construct_overflow_response(self): + def construct_overflow_response(self, info): return { "role": "assistant", "request_id": "overflow_response", - "content": "ajet_proxy: Exceeded max model context length.", + "content": f"AgentJet: Exceeded max model context length. {info}", "tool_calls": None, + "finish_reason": "length", "tokens": [], } diff --git a/ajet/task_runner/swarm_runner.py b/ajet/task_runner/swarm_runner.py index 9e4a5c9..5810d3a 100644 --- a/ajet/task_runner/swarm_runner.py +++ b/ajet/task_runner/swarm_runner.py @@ -66,11 +66,11 @@ def register_episode_and_wait_output( while True: # : - # : ajet/tuner_lib/experimental/as_swarm_server.py + # : ajet/tuner_lib/experimental/swarm_server.py # : socket.send_string(workflow_output.model_dump_json()) # : workflow_output: WorkflowOutput # : - # : ajet/tuner_lib/experimental/as_swarm_server.py + # : ajet/tuner_lib/experimental/swarm_server.py # : socket.send_string("RUNNER.SPECIAL.RESET_CONTEXT_TRACKER") # : "RUNNER.SPECIAL.RESET_CONTEXT_TRACKER" try: diff --git a/ajet/tuner.py b/ajet/tuner.py index 45a5442..54b2b35 100644 --- a/ajet/tuner.py +++ b/ajet/tuner.py @@ -171,7 +171,7 @@ def get_context_tracker(self) -> MultiAgentContextTracker: def _enable_interchange_server(self, llm_inference_fn): # experimental reverse 
proxy start if self.enable_interchange_server: - from ajet.tuner_lib.experimental.as_oai_model_client import InterchangeClient + from ajet.tuner_lib.experimental.oai_model_client import InterchangeClient self.interchange_client = InterchangeClient( episode_uuid=self.context_tracker.episode_uuid, context_tracker=self.context_tracker, diff --git a/ajet/tuner_lib/experimental/as_oai_model_client.py b/ajet/tuner_lib/experimental/oai_model_client.py similarity index 98% rename from ajet/tuner_lib/experimental/as_oai_model_client.py rename to ajet/tuner_lib/experimental/oai_model_client.py index aaecde5..89b3866 100644 --- a/ajet/tuner_lib/experimental/as_oai_model_client.py +++ b/ajet/tuner_lib/experimental/oai_model_client.py @@ -9,7 +9,7 @@ from loguru import logger from typing import TYPE_CHECKING -from ajet.tuner_lib.experimental.as_oai_model_server import InterchangeCompletionRequest +from ajet.tuner_lib.experimental.oai_model_server import InterchangeCompletionRequest from ajet.utils.thread_executors import SharedInferenceTrackerThreadExecutor, SharedInterchangeThreadExecutor from ajet.tuner_lib.experimental.interchange_utils import get_zmq_socket from ajet.tuner_lib.experimental.interchange_utils import DEBUG @@ -107,7 +107,7 @@ def _begin_service_threading(self): try: # : - # : ajet/tuner_lib/experimental/as_oai_model_server.py + # : ajet/tuner_lib/experimental/oai_model_server.py # : socket.send_string(int_req.model_dump_json()) # : InterchangeCompletionRequest object in JSON string format message = self.socket.recv_string() @@ -165,7 +165,7 @@ def _begin_service_threading(self): if DEBUG: logger.info(f"[client] {self.episode_uuid} | before send_string (send llm call result)") # - # : ajet/tuner_lib/experimental/as_oai_model_server.py + # : ajet/tuner_lib/experimental/oai_model_server.py # : result_str = socket.recv_string() self.socket.send_string(result) diff --git a/ajet/tuner_lib/experimental/oai_model_one2many.py 
b/ajet/tuner_lib/experimental/oai_model_one2many.py new file mode 100644 index 0000000..80c9fda --- /dev/null +++ b/ajet/tuner_lib/experimental/oai_model_one2many.py @@ -0,0 +1,593 @@ +# -*- coding: utf-8 -*- + +""" +A one-to-many proxy server for LLM requests with reinforcement learning. + +This server implements a one-to-many request pattern where each user request +is processed by multiple parallel episodes, and the best response is selected +based on computed rewards. + +Architecture Overview: +--------------------- +1. Server Initialization: + - Connects to swarm server and syncs training config + - Starts the engine with specified AgentJetJob configuration + +2. Request Processing Flow: + - Receives LLM request and creates a Task + - Runs NUM_REPEAT parallel episodes + - Computes rewards for each episode response + - Returns the best response to user + +Usage: + python -m ajet.tuner_lib.experimental.oai_model_one2many + + +""" + +import os +import uuid +import random +import asyncio +import httpx +import json +import threading +from contextlib import asynccontextmanager +from typing import Dict, List, Optional, Any + +from fastapi import FastAPI, Request, HTTPException +from fastapi.responses import StreamingResponse +from loguru import logger +from pydantic import BaseModel + +from ajet.schema.task import Task, WorkflowOutput +from ajet.copilot.job import AgentJetJob +from ajet.tuner_lib.experimental.swarm_client import SwarmClient +from beast_logger import print_listofdict + + +# ============================================================================= +# Configuration Constants +# ============================================================================= + +SWARM_URL = os.getenv("AJET_SWARM_URL", "http://localhost:10086") +NUM_REPEAT = int(os.getenv("NUM_REPEAT", "8")) +TRAINING_OBJECTIVE = "我希望我的助手足够幽默" + +# ============================================================================= +# Global State +# 
============================================================================= +USER_REQUEST_RECORD: List[Dict] = [] +REQUEST_COUNTER = 0 +swarm_client: Optional[SwarmClient] = None +ajet_job = AgentJetJob( + algorithm="grpo", + project_name="ajet-swarm", + experiment_name="test", + n_gpu=8, + model='/mnt/data_cpfs/model_cache/modelscope/hub/Qwen/Qwen/Qwen2.5-7B-Instruct', + batch_size=16, + num_repeat=NUM_REPEAT, +) + +# ============================================================================= +# Pydantic Models +# ============================================================================= + +class EpisodeResult(BaseModel): + """Result from a single episode execution.""" + episode_uuid: str + response: Dict | List[bytes] + + +# ============================================================================= +# User Request Record Management +# ============================================================================= + +async def on_user_submit_new_requests(request_id: str, task: Task) -> None: + """ + Store user request record when a new request is submitted. + + This function maintains a chronological record of all user requests, + which can be used for tracking and debugging purposes. + + Args: + request_id: Unique identifier for this request + task: The Task object containing query information + """ + USER_REQUEST_RECORD.append({ + "request_id": request_id, + "task_id": task.task_id, + "query": task.main_query, + }) + # here, add some code to update OpenJudge grader according to new user preference (if user indicate any) + + +# ============================================================================= +# Reward Computation +# ============================================================================= + +async def on_compute_relative_reward(valid_results: List[EpisodeResult], all_answers: List[Dict]) -> List[float]: + """ + Compute relative rewards for all episode results. + + This function calculates a reward score for each episode response. 
+ Currently implements a random reward generator as a placeholder. + + Future implementations should compare responses and generate + meaningful scores based on quality metrics. + + Args: + valid_results: List of successful episode results + + Returns: + List of reward scores in range [-1.0, 1.0] + """ + + # here, use OpenJudge to compute relative scores + rewards = [random.uniform(-1.0, 1.0) for _ in valid_results] + + # Add reward to each answer for logging + for answer, reward in zip(all_answers, rewards): + answer["reward"] = reward + print_listofdict(all_answers, header="on_compute_relative_reward") + + return rewards + + +# ============================================================================= +# Response Processing Utilities +# ============================================================================= + +def extract_assistant_message(resp: Dict | List[bytes]) -> Dict: + """ + Extract assistant message from response (handles both stream and non-stream). + + For streaming responses, accumulates delta content from all chunks. + For non-streaming responses, extracts the message directly. 
+ + Args: + resp: Response data (list of chunks for stream, dict for non-stream) + + Returns: + Dictionary containing the assistant's message with role, content, + and optionally tool_calls + """ + if isinstance(resp, list): + # Stream response: accumulate delta content from all chunks + content_parts: List[str] = [] + tool_calls_map: Dict[int, Dict] = {} + + for raw in resp: + line = raw.decode() if isinstance(raw, bytes) else raw + if not line.startswith("data:"): + continue + + payload = line[len("data:"):].strip() + if payload == "[DONE]": + break + + try: + chunk = json.loads(payload) + delta = chunk.get("choices", [{}])[0].get("delta", {}) + + # Accumulate content + if delta.get("content"): + content_parts.append(delta["content"]) + + # Accumulate tool calls + for tc in delta.get("tool_calls", []): + idx = tc.get("index", 0) + if idx not in tool_calls_map: + tool_calls_map[idx] = tc + else: + existing_args = tool_calls_map[idx].get("function", {}).get("arguments", "") + new_args = tc.get("function", {}).get("arguments", "") + tool_calls_map[idx].setdefault("function", {})["arguments"] = existing_args + new_args + except Exception: + pass + + msg: Dict[str, Any] = {"role": "assistant", "content": "".join(content_parts)} + if tool_calls_map: + msg["tool_calls"] = list(tool_calls_map.values()) + return msg + else: + # Non-stream: standard OpenAI response dict + return resp.get("choices", [{}])[0].get("message", {}) + + +# ============================================================================= +# HTTP Proxy Functions +# ============================================================================= + +async def proxy_chat_completion( + base_url: str, + api_key: str, + request: Request, + is_stream: bool = False +) -> Dict | List[bytes]: + """ + Proxy a chat completion request to the specified base URL. 
+ + Args: + base_url: Target server base URL + api_key: API key for authentication + request: Original FastAPI request object + is_stream: Whether to use streaming response + + Returns: + Response data (dict for non-stream, list of chunks for stream) + """ + headers = { + "Authorization": f"Bearer {api_key}", + "Content-Type": "application/json", + "Connection": "close", + } + + json_data = await request.json() + json_data["stream"] = is_stream + + async with httpx.AsyncClient(timeout=300.0) as client: + resp = await client.post( + f"{base_url}/chat/completions", + json=json_data, + headers=headers, + ) + resp.raise_for_status() + + if is_stream: + chunks = [] + async for line in resp.aiter_lines(): + if line.strip(): + chunks.append(line.encode() if isinstance(line, str) else line) + return chunks + else: + return resp.json() + + +# ============================================================================= +# Episode Execution +# ============================================================================= + +async def run_single_episode( + episode_index: int, + request: Request, + is_stream: bool, +) -> EpisodeResult: + """ + Run a single episode with the swarm client. 
+ + Args: + episode_index: Index of this episode (for logging) + request: Original FastAPI request object + is_stream: Whether to use streaming response + + Returns: + EpisodeResult containing the episode UUID and response data + + Raises: + Exception: If the episode fails (after aborting the episode) + """ + assert swarm_client is not None, "Swarm client not initialized" + + loop = asyncio.get_event_loop() + episode_uuid, api_baseurl_key = await loop.run_in_executor( + None, lambda: swarm_client.begin_episode(discard_episode_timeout=120) # type: ignore[union-attr] + ) + + try: + response_data = await proxy_chat_completion( + base_url=api_baseurl_key.base_url, + api_key=api_baseurl_key.api_key, + request=request, + is_stream=is_stream, + ) + return EpisodeResult(episode_uuid=episode_uuid, response=response_data) + except Exception as e: + logger.error(f"Error in episode {episode_index}: {e}") + swarm_client.abort_episode(episode_uuid) + raise + + +async def run_all_episodes(request: Request, is_stream: bool) -> List[EpisodeResult]: + """ + Run all episodes in parallel and collect valid results. 
+ + Args: + request: Original FastAPI request object + is_stream: Whether to use streaming response + + Returns: + List of successful episode results + + Raises: + HTTPException: If all episodes fail + """ + episode_tasks = [ + run_single_episode(i, request, is_stream) + for i in range(NUM_REPEAT) + ] + + results = await asyncio.gather(*episode_tasks, return_exceptions=True) + + valid_results: List[EpisodeResult] = [] + for result in results: + if isinstance(result, Exception): + logger.warning(f"Episode failed: {result}") + elif isinstance(result, EpisodeResult): + valid_results.append(result) + + if not valid_results: + raise HTTPException(status_code=500, detail="All episodes failed") + + return valid_results + + +async def finalize_episodes( + task: Task, + valid_results: List[EpisodeResult], + rewards: List[float] +) -> None: + """ + Finalize all episodes by sending rewards to the swarm client. + + Args: + task: The Task object for this request + valid_results: List of successful episode results + rewards: List of computed rewards for each result + """ + assert swarm_client is not None, "Swarm client not initialized" + + loop = asyncio.get_event_loop() + + for episode_result, reward in zip(valid_results, rewards): + workflow_output = WorkflowOutput(reward=reward, metadata={}) + await loop.run_in_executor( + None, + lambda ep=episode_result, wo=workflow_output: swarm_client.end_episode( # type: ignore[union-attr] + task, ep.episode_uuid, wo + ), + ) + + +# ============================================================================= +# Main Request Handler +# ============================================================================= + +async def handle_one2many_request(request: Request, request_id: str) -> Dict | List[bytes]: + """ + Handle a one-to-many request by running multiple episodes in parallel. + + This is the main entry point for processing chat completion requests. + It orchestrates the entire flow: + 1. Parse request and create task + 2. 
Store request record + 3. Run parallel episodes + 4. Compute rewards + 5. Select and return best response + + Args: + request: FastAPI request object + request_id: Unique identifier for this request + + Returns: + Best response data (dict or list of stream chunks) + """ + # Parse request + json_data = await request.json() + is_stream = json_data.get('stream', False) + messages = json_data.get('messages', []) + message_latest = messages[-1] + user_query = str(message_latest.get("content", "") if isinstance(message_latest, dict) else "") + + # Create task and store request record + task = Task( + task_id=str(uuid.uuid4()), + main_query=user_query, + metadata={"TRAINING_OBJECTIVE": TRAINING_OBJECTIVE} + ) + await on_user_submit_new_requests(request_id, task) + + # Run all episodes in parallel + valid_results = await run_all_episodes(request, is_stream) + + # Extract answers and compute rewards + all_answers = [extract_assistant_message(r.response) for r in valid_results] + rewards = await on_compute_relative_reward(valid_results, all_answers) + + + # Finalize episodes with rewards + await finalize_episodes(task, valid_results, rewards) + + # Select and return best response + best_idx = rewards.index(max(rewards)) + return valid_results[best_idx].response + + +# ============================================================================= +# FastAPI Application Setup +# ============================================================================= + +@asynccontextmanager +async def lifespan(app: FastAPI): + """Application lifespan manager for startup and shutdown.""" + global swarm_client + global ajet_job + + logger.info(f"Initializing swarm client with URL: {SWARM_URL}") + swarm_client = SwarmClient(SWARM_URL) + + + logger.info(f"Syncing train config and starting engine with num_repeat={NUM_REPEAT}") + + def start_engine_background(): + try: + swarm_client.auto_sync_train_config_and_start_engine( # type: ignore[union-attr] + ajet_job, + force_restart=False, + ) + 
logger.info("Swarm engine is ready!") + except Exception as e: + logger.warning(f"Engine auto-sync skipped or failed: {e}") + + engine_thread = threading.Thread(target=start_engine_background, daemon=True) + engine_thread.start() + + yield + + +app = FastAPI(title="One-to-Many Proxy Server", lifespan=lifespan) + + +# ============================================================================= +# API Endpoints +# ============================================================================= + +@app.api_route("/v1/{path:path}", methods=["GET", "POST", "PUT", "DELETE", "PATCH", "HEAD", "OPTIONS"]) +async def one2many_proxy(request: Request, path: str): + """Main proxy endpoint for OpenAI-compatible API requests.""" + global REQUEST_COUNTER + + try: + if request.method == "POST" and path == "chat/completions": + REQUEST_COUNTER += 1 + request_id = f"req_{REQUEST_COUNTER}_{uuid.uuid4().hex[:8]}" + logger.info(f"Received chat completion request {request_id}") + + response_data = await handle_one2many_request(request, request_id) + + if isinstance(response_data, list): + # Stream response: replay recorded chunks + async def stream_chunks(chunks: List[bytes]): + for chunk in chunks: + yield chunk + b"\n\n" + + return StreamingResponse( + stream_chunks(response_data), + media_type="text/event-stream", + ) + + return response_data + else: + raise HTTPException(status_code=404, detail="Not Found") + + except httpx.TimeoutException: + logger.error(f"Timeout proxying {request.method} {path}") + raise HTTPException(status_code=504, detail="Gateway Timeout") + + except httpx.ConnectError: + logger.error(f"Connection error proxying {request.method} {path}") + raise HTTPException(status_code=502, detail="Bad Gateway") + + except Exception as e: + logger.exception(f"Unexpected error proxying {request.method} {path}: {e}") + raise HTTPException(status_code=500, detail="Internal Server Error") + + +@app.get("/health") +async def health_check(): + """Health check endpoint.""" + return 
{"status": "healthy"} + + +@app.get("/requests") +async def get_requests(): + """Get all recorded user requests.""" + return {"requests": USER_REQUEST_RECORD} + + +# ============================================================================= +# Main Entry Point +# ============================================================================= + +if __name__ == "__main__": + import uvicorn + uvicorn.run(app, host="0.0.0.0", port=8090) + + +# ============================================================================= +# Test Script (for reference) +# ============================================================================= + +''' Test Script: + +# -*- coding: utf-8 -*- + +import os +import time +import requests +from typing import List, Dict + +PROXY_URL = os.getenv("PROXY_URL", "http://localhost:10010") + +MESSAGES = [ + [{"role": "user", "content": "Hello, how are you?"}], + [{"role": "user", "content": "Tell me a joke."}], + [{"role": "user", "content": "What's the weather like today?"}], + [{"role": "user", "content": "Write a short poem about coding."}], + [{"role": "user", "content": "What is Python?"}], + [{"role": "user", "content": "How do I learn machine learning?"}], + [{"role": "user", "content": "Tell me about your hobbies."}], + [{"role": "user", "content": "What's your favorite programming language?"}], + [{"role": "user", "content": "Explain what is an API."}], + [{"role": "user", "content": "Give me a recipe for pasta."}], +] + + +def send_chat_request(messages: List[Dict[str, str]], stream: bool = False) -> Dict: + """Send a chat completion request to the proxy server.""" + payload = { + "model": "test-model", + "messages": messages, + "stream": stream, + } + + try: + response = requests.post( + f"{PROXY_URL}/v1/chat/completions", + json=payload, + timeout=300, + ) + response.raise_for_status() + return response.json() + except requests.exceptions.Timeout: + return {"error": "Request timed out"} + except requests.exceptions.RequestException as 
e: + return {"error": str(e)} + + +def main(): + print(f"Starting client, sending requests to {PROXY_URL}") + print("Press Ctrl+C to stop\n") + + request_count = 0 + + while True: + request_count += 1 + messages = MESSAGES[request_count % len(MESSAGES)] + + print(f"[Request {request_count}] Sending: {messages[0]['content'][:50]}...") + + result = send_chat_request(messages) + + if "error" in result: + print(f"[Request {request_count}] Error: {result['error']}") + else: + content = result.get("choices", [{}])[0].get("message", {}).get("content", "") + print(f"[Request {request_count}] Response: {content[:100]}...") + + print() + + time.sleep(5) + + +if __name__ == "__main__": + try: + health = requests.get(f"{PROXY_URL}/health", timeout=5) + print(f"Server health: {health.json()}\n") + except Exception as e: + print(f"Warning: Could not connect to server: {e}\n") + + main() + +''' diff --git a/ajet/tuner_lib/experimental/as_oai_model_server.py b/ajet/tuner_lib/experimental/oai_model_server.py similarity index 86% rename from ajet/tuner_lib/experimental/as_oai_model_server.py rename to ajet/tuner_lib/experimental/oai_model_server.py index 367c808..2849e2b 100644 --- a/ajet/tuner_lib/experimental/as_oai_model_server.py +++ b/ajet/tuner_lib/experimental/oai_model_server.py @@ -70,7 +70,6 @@ def ep_key(episode_uuid: str) -> str: def get_app(max_fastapi_threads: int = 512, enable_swarm_mode=False, shared_mem_dict=None, shared_mem_dict_lock=None) -> Tuple[FastAPI, Optional[Coroutine]]: - @asynccontextmanager async def lifespan(app: FastAPI): # Startup @@ -96,7 +95,7 @@ def _begin_handle_chat_completion(episode_address, int_req: InterchangeCompletio if DEBUG: logger.info(f"[server] episode_uuid: {episode_uuid} | connect done") # - # : ajet/tuner_lib/experimental/as_oai_model_client.py + # : ajet/tuner_lib/experimental/oai_model_client.py # : message = self.socket.recv_string() socket.send_string(int_req.model_dump_json()) @@ -116,7 +115,7 @@ def 
_begin_handle_chat_completion(episode_address, int_req: InterchangeCompletio if DEBUG: logger.info(f"[server] episode_uuid: {episode_uuid} | recv_string begin.") # : - # : ajet/tuner_lib/experimental/as_oai_model_client.py + # : ajet/tuner_lib/experimental/oai_model_client.py # : self.socket.send_string(result) # : ChatCompletion object in JSON string format result_str = socket.recv_string() @@ -152,6 +151,8 @@ async def mock_as_stream_response(result: ChatCompletion): """ content = result.choices[0].message.content if result.choices else "" role = result.choices[0].message.role if result.choices else "assistant" + result_id = result.id if result.id else uuid.uuid4().hex + result.id = "chatcmpl-" + result_id if not result_id.startswith("chatcmpl-") else result_id # try: # thinking = result.choices[0].message.reasoning_content # except: @@ -159,6 +160,7 @@ async def mock_as_stream_response(result: ChatCompletion): tool_calls = result.choices[0].message.tool_calls if result.choices and result.choices[0].message.tool_calls else None delta_tool_calls = [] # tool_calls: Optional[List[ChoiceDeltaToolCall]] = None finish_reason = result.choices[0].finish_reason + usage = result.usage if tool_calls: delta_tool_calls = [ChoiceDeltaToolCall( index=index, @@ -170,6 +172,18 @@ async def mock_as_stream_response(result: ChatCompletion): type=tc.type ) for index, tc in enumerate(tool_calls)] + def dump_chunk(chunk: ChatCompletionChunk) -> str: + dump = chunk.model_dump() + dump.pop("service_tier", None) + dump.pop("system_fingerprint", None) + if "usage" in dump and dump["usage"] is None: + dump.pop("usage", None) + # for each choice delta, if field (such as tool_calls) is empty, remove it from the delta to avoid confusion + for key in list(dump["choices"][0]["delta"].keys()): + if not dump["choices"][0]["delta"][key] and key != "content": # keep content even if it's empty + dump["choices"][0]["delta"].pop(key, None) + return f"data: {json.dumps(dump)}\n\n" + # First chunk with 
role first_chunk = ChatCompletionChunk( id=result.id, @@ -184,8 +198,7 @@ async def mock_as_stream_response(result: ChatCompletion): ) ] ) - dat = f"data: {first_chunk.model_dump_json()}\n\n" - yield dat + yield dump_chunk(first_chunk) # Content chunk content_chunk = ChatCompletionChunk( @@ -196,30 +209,28 @@ async def mock_as_stream_response(result: ChatCompletion): choices=[ ChunkChoice( index=0, - delta=ChoiceDelta(role=role, content=content, tool_calls=delta_tool_calls), + delta=ChoiceDelta(content=content, tool_calls=delta_tool_calls), finish_reason=None ) ] ) - dat = f"data: {content_chunk.model_dump_json()}\n\n" - yield dat - + yield dump_chunk(content_chunk) # Final chunk with finish_reason final_chunk = ChatCompletionChunk( id=result.id, model=result.model, created=result.created, object="chat.completion.chunk", + usage=usage, choices=[ ChunkChoice( index=0, - delta=ChoiceDelta(), - finish_reason=finish_reason + delta=ChoiceDelta(content=""), + finish_reason=finish_reason, ) ] ) - dat = f"data: {final_chunk.model_dump_json()}\n\n" - yield dat + yield dump_chunk(final_chunk) yield "data: [DONE]\n\n" @@ -261,6 +272,13 @@ async def chat_completions(request: Request, authorization: str = Header(None)): body = await request.json() new_req = ChatCompletionRequest.model_validate(body) + # Check if the first message is a system message, if not, add a default one + if new_req.messages: + first_msg = new_req.messages[0] + if first_msg.get("role") != "system": + logger.warning(f"First message role is '{first_msg.get('role')}', expected 'system'. 
Adding default system prompt.") + new_req.messages.insert(0, {"role": "system", "content": "You are a helpful assistant, your name is AgentJet."}) + # Create timeline UUID timeline_uuid = uuid.uuid4().hex @@ -269,7 +287,7 @@ async def chat_completions(request: Request, authorization: str = Header(None)): # enable_swarm_mode if enable_swarm_mode: - from ajet.tuner_lib.experimental.as_swarm_server import ep_key + from ajet.tuner_lib.experimental.swarm_server import ep_key assert shared_mem_dict is not None assert shared_mem_dict_lock is not None @@ -308,18 +326,37 @@ async def chat_completions(request: Request, authorization: str = Header(None)): loop = asyncio.get_running_loop() result = await loop.run_in_executor(request.app.state.executor, _begin_handle_chat_completion, episode_address, int_req, episode_uuid) + if enable_swarm_mode: + assert shared_mem_dict is not None + shared_mem_dict["latest_llm_call"] = { + "input": body, + "output": result, + } + if original_stream: + result.model = "unknown_model" if not new_req.model else new_req.model return StreamingResponse(mock_as_stream_response(result), media_type="text/event-stream") return result if enable_swarm_mode: - from ajet.tuner_lib.experimental.as_swarm_server import register_enable_swarm_mode_routes + from ajet.tuner_lib.experimental.swarm_server import register_enable_swarm_mode_routes + + @app.post("/replay_latest_llm_call") + async def replay_latest_llm_call(): + """Return the buffered latest LLM call result.""" + assert shared_mem_dict is not None + if ("latest_llm_call" not in shared_mem_dict) or shared_mem_dict["latest_llm_call"] is None: + raise HTTPException(status_code=404, detail="No LLM call has been made yet") + return shared_mem_dict["latest_llm_call"] + assert shared_mem_dict is not None, "shared_mem_dict must not be None when enable_swarm_mode is True." assert shared_mem_dict_lock is not None, "shared_mem_dict_lock must not be None when enable_swarm_mode is True." 
app, additional_coro = register_enable_swarm_mode_routes(app, zmq_context=context, shared_mem_dict=shared_mem_dict, shared_mem_dict_lock=shared_mem_dict_lock) + else: + additional_coro = None @@ -481,6 +518,6 @@ def start_interchange_server(config, blocking=False, env={}) -> int: if interchange_server: interchange_server.terminate() if enable_swarm_mode: - from ajet.tuner_lib.experimental.as_swarm_server import kill_process_tree + from ajet.tuner_lib.experimental.swarm_server import kill_process_tree kill_process_tree(None, None) return -1 diff --git a/ajet/tuner_lib/experimental/as_swarm_client.py b/ajet/tuner_lib/experimental/swarm_client.py similarity index 95% rename from ajet/tuner_lib/experimental/as_swarm_client.py rename to ajet/tuner_lib/experimental/swarm_client.py index d4c3ff4..340980d 100644 --- a/ajet/tuner_lib/experimental/as_swarm_client.py +++ b/ajet/tuner_lib/experimental/swarm_client.py @@ -5,6 +5,7 @@ import json import re import yaml +import tempfile from beast_logger import print_dict from typing import List, Tuple from loguru import logger @@ -67,12 +68,12 @@ def __init__(self, server_url: str): # better logging management self._last_second_print_buffer: dict[str, float] = {} self._begin_episode_lock = threading.Lock() + self._http_client_lock = threading.Lock() + self._http_client = self._refresh_http_client() # record last registered AgentJetJob self._agent_jet_job = None # throttle self._recent_seen_tasks = [] - # reuse httpx client to avoid creating SSL context repeatedly - self._http_client = httpx.Client(timeout=GENERAL_TIMEOUT) def logger_info(self, message): # logger with de-duplication within 1 second to prevent log flooding @@ -96,21 +97,26 @@ def logger_info(self, message): def _refresh_http_client(self): """Refresh the HTTP client by closing the old one and creating a new one.""" - try: - self._http_client.close() - except Exception: - pass # Ignore errors when closing - self._http_client = httpx.Client(timeout=GENERAL_TIMEOUT) - 
logger.info("HTTP client refreshed due to connection error") + with self._http_client_lock: + try: + self._http_client.close() + except Exception: + pass # Ignore errors when closing + try: + self._http_client = httpx.Client(timeout=GENERAL_TIMEOUT, http2=True) + except: + self._http_client = httpx.Client(timeout=GENERAL_TIMEOUT, http2=False) + logger.warning("swarm client httpx client refreshed.") + return self._http_client def _should_refresh_client_on_error(self, error: Exception) -> bool: """Check if an error suggests the HTTP client should be refreshed.""" error_msg = str(error).lower() return any(keyword in error_msg for keyword in [ + "broken pipe", "disconnected", "connection reset", "connection closed", - "broken pipe", "connection aborted" ]) @@ -234,7 +240,6 @@ def begin_episode(self, discard_episode_timeout=240, episode_type="train", throt def _begin_episode_auto_retry(self, discard_episode_timeout=240, episode_type="train", throttle_policy: SwarmThrottlePolicy|None = None) -> Tuple[str, OpenaiBaseUrlAndApiKey]: # max_episode_time: when an episode has **lasted** for more than X seconds, it will be terminated **locally** by client (call `end_episode` will be re-route to `abort_episode`) max_episode_time = 8*discard_episode_timeout - status, status_json = self.get_engine_status() # warm up connection and log the status if status not in ["ENGINE.ROLLING"]: self.logger_info(f"Engine status is {status}. Waiting until ENGINE.ROLLING...") @@ -320,6 +325,8 @@ def _begin_episode_auto_retry(self, discard_episode_timeout=240, episode_type="t continue except Exception as e: + if self._should_refresh_client_on_error(e): + self._refresh_http_client() logger.error(f"Error claiming episode: {e}. 
Retrying ...") retry_delay = START_EPISODE_RETRY_DELAY continue @@ -398,6 +405,8 @@ def abort_episode(self, episode_uuid: str): logger.error(f"Failed to end episode {episode_uuid}") except Exception as e: + if self._should_refresh_client_on_error(e): + self._refresh_http_client() logger.error(f"Error ending episode: {e}") def sync_train_config(self, agent_jet_job: AgentJetJob): @@ -414,7 +423,9 @@ def sync_train_config(self, agent_jet_job: AgentJetJob): try: config_dict = agent_jet_job.config.to_dict() yaml_str = yaml.safe_dump(config_dict, sort_keys=False) - + with tempfile.NamedTemporaryFile(mode='w', suffix='.yaml', delete=False) as f: + f.write(yaml_str) + logger.warning(f"Sync new training configuration: {f.name}") req_obj = SyncTrainConfigRequest(yaml_as_string=yaml_str) resp = self._http_client.post( @@ -498,6 +509,8 @@ def _wait_until_status_change_to(self, desired_status="ENGINE.ROLLING", verbose= raise e except Exception as e: + if self._should_refresh_client_on_error(e): + self._refresh_http_client() logger.error(f"Error polling engine status: {e}") time.sleep(5) @@ -511,8 +524,8 @@ def get_engine_status(self) -> Tuple[str, dict]: raise_for_status_with_detail(resp) resp_json = resp.json() result = resp_json.get("engine_status", "unknown") - engine_status_detail = resp_json.get("engine_status_detail", None) - global_step = resp_json.get("global_step", None) + # engine_status_detail = resp_json.get("engine_status_detail", None) + # global_step = resp_json.get("global_step", None) if result == "unknown": logger.warning("get_engine_status: " + str(resp_json)) return result, resp_json @@ -525,7 +538,6 @@ def get_engine_status(self) -> Tuple[str, dict]: def can_continue_episode(self, episode_uuid: str) -> bool: if not episode_uuid: return False - try: req_obj = CanContinueEpisodeRequest( client_uuid=self.client_uuid, @@ -540,6 +552,8 @@ def can_continue_episode(self, episode_uuid: str) -> bool: data = CanContinueEpisodeResponse.model_validate(resp.json()) 
return data.can_continue except Exception as e: + if self._should_refresh_client_on_error(e): + self._refresh_http_client() logger.error(f"Error checking can_continue_episode: {e}") return False @@ -554,6 +568,8 @@ def get_episode_buffer(self) -> List[EpisodeStatus]: data = EpisodeBufferResponse.model_validate(resp.json()) return data.buffer except Exception as e: + if self._should_refresh_client_on_error(e): + self._refresh_http_client() logger.error(f"Error getting episode buffer: {e}") return [] @@ -632,6 +648,8 @@ def get_rollout_stat(self) -> CurrentBatchRolloutPoolInformation: data = CurrentBatchRolloutPoolInformation.model_validate(resp.json()) return data except Exception as e: + if self._should_refresh_client_on_error(e): + self._refresh_http_client() logger.error(f"Error getting rollout statistics: {e}") return CurrentBatchRolloutPoolInformation() diff --git a/ajet/tuner_lib/experimental/as_swarm_server.py b/ajet/tuner_lib/experimental/swarm_server.py similarity index 100% rename from ajet/tuner_lib/experimental/as_swarm_server.py rename to ajet/tuner_lib/experimental/swarm_server.py diff --git a/ajet/utils/config_utils.py b/ajet/utils/config_utils.py index 13c4d69..9e6e284 100644 --- a/ajet/utils/config_utils.py +++ b/ajet/utils/config_utils.py @@ -3,9 +3,10 @@ import os import shutil import time +import yaml +import hydra.errors from functools import cache -import yaml from beast_logger import print_dict from hydra import compose, initialize from loguru import logger @@ -15,15 +16,44 @@ DEFAULT_DIR = "saved_experiments" +def fix_hydra_searchpath_and_create_copy_when_needed(yaml_fp): + """Fix Hydra search paths if they don't exist by trying with base directory.""" + abs_yaml_fp = os.path.abspath(yaml_fp) + with open(abs_yaml_fp, 'r', encoding='utf-8') as f: + yaml_content = yaml.safe_load(f) + if yaml_content and 'hydra' in yaml_content and 'searchpath' in yaml_content['hydra']: + base_dir = 
os.path.abspath(os.path.dirname(os.path.dirname(os.path.dirname(__file__)))) + modified = False + for i, path in enumerate(yaml_content['hydra']['searchpath']): + if path.startswith('file://'): + rel_path = path[7:] + if not os.path.exists(rel_path): + fixed_path = os.path.join(base_dir, rel_path) + if os.path.exists(fixed_path): + logger.warning(f"Cannot find `{os.path.abspath(rel_path)}`, but find `{os.path.abspath(fixed_path)}`, override original config ...") + yaml_content['hydra']['searchpath'][i] = f'file://{fixed_path}' + modified = True + if modified: + with open(abs_yaml_fp + ".patch.yaml", 'w', encoding='utf-8') as f: + yaml.dump(yaml_content, f) + return abs_yaml_fp + ".patch.yaml" + return abs_yaml_fp + + def read_ajet_config(yaml_fp): """Load a Hydra configuration relative to this module.""" + yaml_fp = read_ajet_yaml_fp = fix_hydra_searchpath_and_create_copy_when_needed(yaml_fp) yaml_fp = os.path.relpath( yaml_fp, os.path.dirname(__file__) ) # do not try to understand this line, hydra is too weird def load_hydra_config(config_path: str, config_name: str) -> DictConfig: with initialize(config_path=config_path, version_base=None): - cfg = compose(config_name=config_name, overrides=[]) + try: + cfg = compose(config_name=config_name, overrides=[]) + except hydra.errors.MissingConfigException as e: + logger.error(f"Configuration default files not found (please check {read_ajet_yaml_fp})") + raise e return cfg dir_path = os.path.dirname(yaml_fp) diff --git a/ajet/utils/swarm_overwatch.py b/ajet/utils/swarm_overwatch.py index 9f8b771..f8003b1 100644 --- a/ajet/utils/swarm_overwatch.py +++ b/ajet/utils/swarm_overwatch.py @@ -392,7 +392,7 @@ def create_logo_panel(self, info: CurrentBatchRolloutPoolInformation) -> Text: return content def create_dashboard( - self, info: Optional[CurrentBatchRolloutPoolInformation] + self, info: Optional[CurrentBatchRolloutPoolInformation], init=False ) -> Layout: """Create the main dashboard layout""" layout = Layout() @@ 
-400,7 +400,7 @@ def create_dashboard( # Create header header = self.create_header(info) - if info is None: + if (info is None) and (not init): # Show error state error_panel = Panel( "[bold red]Failed to fetch data from server, please check your connection or simply wait a moment...[/bold red]\n" @@ -409,8 +409,19 @@ def create_dashboard( padding=(1, 2), ) layout.split_column(Layout(header, size=8), Layout(error_panel)) + elif (info is None) and (init): + # Initial state before first successful data fetch + welcome_panel = Panel( + "[bold green]Welcome to AgentJet Swarm Overwatch![/bold green]\n\n" + "Attempting to connect to server and fetch data...\n" + f"[dim]Target server: {self.server_url}[/dim]\n", + border_style="green", + padding=(1, 2), + ) + layout.split_column(Layout(header, size=8), Layout(welcome_panel)) else: # Check engine status and show logo for OFFLINE or BOOTING states + assert info is not None # for type checker if info.engine_status in ["ENGINE.OFFLINE", "ENGINE.BOOTING"]: # Hide tables and show logo logo_display = self.create_logo_panel(info) @@ -439,46 +450,116 @@ def create_dashboard( return layout - def run(self): - """Start the monitoring interface""" - self.console.clear() - try: - with Live( - self.create_dashboard(None), - console=self.console, - refresh_per_second=1, - screen=True, - ) as live: + def display_latest_llm_call(self): + while True: + response = httpx.post(f"{self.server_url}/replay_latest_llm_call", timeout=30.0) + structured_response = response.json() + self.console.clear() + if "input" not in structured_response or "output" not in structured_response: + self.console.print(f"[bold red]{structured_response}[/bold red]") + time.sleep(5) + continue + else: + input = structured_response["input"] + output = structured_response["output"] + self.console.print(f"\n[bold green]Input:[/bold green]\n{input}") + self.console.print(f"\n[bold green]Output:[/bold green]\n{output}") + hide_when_more_than_n_line_break = 4 + try: + 
input_items = "" + output_items = "" + for item in input['messages']: + role = item['role'] + content = item['content'] + if isinstance(content, list): + content = content[0].get('text', '') + if content.count('\n') >= hide_when_more_than_n_line_break: + content = content.replace('\n',' ')[:200] + " ....." + else: + content = content.replace('\n',' ') + input_items += f"[bold blue]@{role}:[/bold blue] {content}\n" + for item in output['choices']: + role = item['message']['role'] + content = item['message']['content'] + if content.count('\n') >= hide_when_more_than_n_line_break: + content = content.replace('\n',' ')[:200] + " ....." + else: + content = content.replace('\n',' ') + output_items += f"[bold red]@{role}:[/bold red] {content}\n" + self.console.print(f"\n-------------------------------------------------------------") + self.console.print(f"\n[bold green]Input Simlified:[/bold green]\n{input_items}") + self.console.print(f"\n[bold green]Output Simlified:[/bold green]\n{output_items}") + except: + pass + time.sleep(5) + + def choose_run(self) -> str: + mode = "overwatch" + # mode = "replay_latest_llm_call" + while True: + self.console.clear() + try: + if mode == "overwatch": + self.run() + elif mode == "replay_latest_llm_call": + self.display_latest_llm_call() + + except KeyboardInterrupt: + self.console.clear() + self.console.print("\n[bold yellow]Overwatch stopped by user[/bold yellow]") self.console.print( - "[bold green]Starting Swarm Overwatch...[/bold green]" + f"[dim]Total requests: {self.total_requests}, Errors: {self.error_count}[/dim]\n" ) - self.console.print(f"[dim]Press Ctrl+C to exit[/dim]\n") - time.sleep(1) - - while True: - try: - # Fetch latest data - info = self.fetch_pool_info() - # Update display - live.update(self.create_dashboard(info)) + self.console.print("\n[bold]Choose action:[/bold]") + self.console.print(" [bold cyan]o[/bold cyan] - Return to overwatch") + self.console.print(" [bold cyan]t[/bold cyan] - Show 
replay_latest_llm_call") + self.console.print(" [bold cyan]ctrl+c[/bold cyan] - Exit") + choice = input("\n> ").strip().lower() + + if choice == "o": + mode = "overwatch" + self.console.clear() + continue + elif choice == "t": + mode = "replay_latest_llm_call" + self.console.clear() + continue + else: + self.console.print("[yellow]Invalid choice. Please enter 'o' or 't'.[/yellow]") - # Wait for next refresh - time.sleep(self.refresh_interval) - - except KeyboardInterrupt: - raise - except Exception as e: - logger.error(f"Error in monitoring loop: {e}") - time.sleep(self.refresh_interval) + def run(self): + """Start the monitoring interface""" - except KeyboardInterrupt: - self.console.clear() - self.console.print("\n[bold yellow]Overwatch stopped by user[/bold yellow]") + with Live( + self.create_dashboard(None, init=True), + console=self.console, + refresh_per_second=1, + screen=True, + ) as live: self.console.print( - f"[dim]Total requests: {self.total_requests}, Errors: {self.error_count}[/dim]\n" + "[bold green]Starting Swarm Overwatch...[/bold green]" ) + self.console.print(f"[dim]Press Ctrl+C to exit[/dim]\n") + time.sleep(1) + + while True: + try: + # Fetch latest data + info = self.fetch_pool_info() + + # Update display + live.update(self.create_dashboard(info)) + + # Wait for next refresh + time.sleep(self.refresh_interval) + + except KeyboardInterrupt: + raise + except Exception as e: + logger.error(f"Error in monitoring loop: {e}") + time.sleep(self.refresh_interval) def start_overwatch(server_url: str, refresh_interval: float = 2.0): @@ -490,7 +571,10 @@ def start_overwatch(server_url: str, refresh_interval: float = 2.0): refresh_interval: Refresh interval in seconds (default: 2.0) """ overwatch = SwarmOverwatch(server_url, refresh_interval) - overwatch.run() + try: + overwatch.choose_run() + except KeyboardInterrupt: + logger.info("Swarm Overwatch stopped by user") if __name__ == "__main__": diff --git a/docs/en/ajet-swarm-docker.md 
b/docs/en/ajet-swarm-docker.md index 15c054a..f2c51cb 100644 --- a/docs/en/ajet-swarm-docker.md +++ b/docs/en/ajet-swarm-docker.md @@ -123,7 +123,7 @@ Meanwhile, all VERL and training logs stream into `./swarmlog/swarm_server.log` From any machine (no GPU required) that can reach the server on port `10086`, run your Swarm Client: ```python -from ajet.tuner_lib.experimental.as_swarm_client import SwarmClient +from ajet.tuner_lib.experimental.swarm_client import SwarmClient from ajet.copilot.job import AgentJetJob swarm_worker = SwarmClient("http://:10086") diff --git a/docs/en/example_openclaw.md b/docs/en/example_openclaw.md new file mode 100644 index 0000000..ba7e032 --- /dev/null +++ b/docs/en/example_openclaw.md @@ -0,0 +1,140 @@ +# OpenClaw x AgentJet:微调更懂用户的助手 + +## 龙虾来了 + +2025年末,GitHub上悄悄爬出一只"龙虾"。 + +没有发布会,没有预热,一个叫OpenClaw的开源项目从周末黑客的玩具,用三个月冲到了10万Star。它能接管你的邮件、日历、浏览器,能帮你订机票、写周报、自动回消息——一个跑在本地的全能AI管家。社区给它起了个绰号:龙虾。红色的logo,张牙舞爪的钳子,配上那股不管不顾替你把事办了的劲头,确实像。 + +Reddit上有人说"这是我第一次觉得AI真的在帮我干活而不是在陪我聊天",Hacker News的帖子底下挤满了部署教程和自动化脚本。2026年2月,OpenAI直接把它收购了。一只开源龙虾,就这么登堂入室。 + +然而,驯服一只龙虾并不容易。 + +有人一觉醒来发现硬盘被清空了,有人的邮件被OpenClaw删了个精光——喊停都没用,它不听。国家互联网应急中心专门发布了关于OpenClaw安全应用的风险提示。这些事故指向同一个根源:大模型在复杂Agent系统中,面对长上下文时的指令跟随能力仍然存在严重缺陷。龙虾力气很大,但它不总是听话。 + +解决这类问题最根本的手段是Agentic强化学习——用进化的思路,不断"规训"龙虾的行为边界。但不幸的是,传统LLM强化学习架构把采样和训练紧紧耦合在一起。训练器那条狭小的"甲板",根本装不下龙虾庞大的身躯——它背后是浏览器、终端、文件系统、多轮对话组成的复杂多智能体环境。传统框架对此毫无招架之力。 + +但没关系,训龙虾的工具来了。 + +## AgentJet:蜂群架构 + +阿里巴巴通义实验室和中科院联合研发的新一代多智能体训练框架AgentJet,采用了一种颠覆常规的"蜂群"架构。 + +核心思路很简单:把"训练"和"采样"彻底拆开。 + +在AgentJet的蜂群中,用户根据自己的硬件条件,自由搭建由两种节点构成的分布式训练网络: + +- "训练"节点跑在GPU服务器上,负责模型推理与梯度计算; +- "采样"节点可以跑在任何能连上蜂群的设备上——包括你的笔记本电脑——负责驾驭OpenClaw之类的智能体,源源不断地抽取训练所需的"数据燃料"。 + +这意味着什么? 
+ +你不需要修改OpenClaw的任何一行代码,不需要退而求其次去用某个阉割版的衍生变体,就可以在自己的笔记本上微调、定制一只更懂你的龙虾。 + +更进一步,AgentJet支持将多个不同的LLM模型同时接入同一个多智能体系统的强化学习任务,实现真正意义上的非共享参数多智能体强化学习(MARL)。采样节点可以随时动态添加、移除、修改,构建出一张不受环境限制、能随时改Bug、能从外部环境崩溃中自愈的蜂群训练网络。 + +AgentJet完全开源,样例丰富,开箱即用。配套Token级别的追踪调试工具和逐版本训练性能追踪平台。还面向Vibe Coding开发者提供专用SKILLs,允许Claude Code等工具一键辅助智能体编排和训练调试。 + +![alt text](https://img.alicdn.com/imgextra/i1/O1CN01nCChgf1nNmLYJj2JZ_!!6000000005078-0-tps-3750-1395.jpg) + +## 三步训龙虾 + +整个流程只需要三步。 + +**1. 唤醒蜂群Server** + +不需要安装依赖,一条Docker命令启动训练引擎: + +```bash +docker run --rm -it -v ./swarmlog:/workspace/log -v ./swarmexp:/workspace/saved_experiments \ + -p 10086:10086 --gpus=all --shm-size=32GB ghcr.io/modelscope/agentjet:main bash -c "(ajet-swarm overwatch) & (NO_COLOR=1 LOGURU_COLORIZE=NO ajet-swarm start &>/workspace/log/swarm_server.log)" +``` + +**2. 启动蜂群Client** + +在你的笔记本上启动OpenAI模型接口拟态和用户奖励函数: + +```bash +git clone https://github.com/modelscope/agentjet.git && cd agentjet +pip install -e . +cd ./agentjet/tutorial/opencode_build_openclaw_agent +python fake_vllm_endpoint.py # 奖励只做演示用途 +``` + +**3. 放出龙虾,开始训练** + +启动OpenClaw,进入配置页面,把模型地址指向本地的拟态接口: + +设置 > 配置 > Models > Model Providers > `vllm:http://localhost:8090/v1` + +![配置模型地址](https://img.alicdn.com/imgextra/i2/O1CN01LK3R1W1Dy7bq8jLRR_!!6000000000284-2-tps-2450-1584.png) + +![配置模型参数](https://img.alicdn.com/imgextra/i2/O1CN01g9fUTP1JPD79lN87z_!!6000000001020-2-tps-1542-1067.png) + +然后正常使用OpenClaw提交问题: + +![提交问题](https://img.alicdn.com/imgextra/i1/O1CN013yqN5U1fpFApRMNzN_!!6000000004055-2-tps-3529-1594.png) + +反复提交,AgentJet会自动在后台寻找合适的时机执行训练: + +![自动训练](https://img.alicdn.com/imgextra/i3/O1CN01CBX7ug1TLDp2qPanE_!!6000000002365-2-tps-2756-1118.png) + +就这样。你用龙虾的过程,就是训练龙虾的过程。 + + +**4. 已经着急看训练效果了?** + +在分享给朋友和用户一起“训虾”之前,先让OpenClaw体验以下被3个人同时 ~~“撸猫”~~ “卤虾”的过程 + +```bash +# “卤虾” x1 +python mock_user_request.py & \ +# “卤虾” x2 +python mock_user_request.py & \ +# “卤虾” x3 +python mock_user_request.py +``` + + +**4. 
查看训练曲线** + +等待一会,就可以观察龙虾的腌制情况了: + +![alt text](https://img.alicdn.com/imgextra/i3/O1CN01jUvjKX1qefo37W5dV_!!6000000005521-2-tps-1740-1262.png) + + +## 帷幕之下 + +这套机制是怎么运转的?看一眼数据流就清楚了: + +``` +用户 + │ + ▼ +OpenClaw 界面 + │ + ▼ +OpenClaw 中枢 ──→ 假vLLM端点 (localhost:8090) + │ + ├──→ 将一个请求复制为多份,分发给模型生成多个候选回答 + │ + ├──→ OpenJudge 读取用户原始Query + │ + ├──→ OpenJudge 读取所有候选回答,计算相对奖励 + │ + └──→ 将奖励提交给 AgentJet 蜂群Server (localhost:10086) + │ + │ + 等待样本池“水线”达标 + │ + ▼ + 模型参数更新 +``` + +关键在中间那个"假vLLM端点"。它伪装成一个标准的OpenAI兼容API,OpenClaw完全无感知地向它发送请求。但在幕后,这个端点把每个请求复制成多份,让模型生成多个候选回答,再通过OpenJudge计算相对奖励,最后把奖励信号回传给AgentJet的训练引擎。 + +OpenClaw以为自己在正常调用模型,实际上它的每一次交互都在为自己的进化提供燃料。这就是蜂群架构的精妙之处——训练对智能体完全透明,不侵入、不修改、不感知。 + +值得一提的是,这种由用户实时发起任务参与训练的训练范式,可以归类为“被动”式训练。而AgentJet在主动式训练也非常强大,你可以同时启动多个蜂群client, +在多个完全不同的任务环境下采样,自由地将样本池调配成多个不同任务构成的“鸡尾酒”,然后使用这些样本计算更为鲁棒的策略梯度,避免“学会了这个,忘掉了那个”的情况发生,缓解遗忘现象。 +具体可以参考我们的Github文档和其他Blog。 diff --git a/docs/en/example_openclaw_draft.md b/docs/en/example_openclaw_draft.md new file mode 100644 index 0000000..f904219 --- /dev/null +++ b/docs/en/example_openclaw_draft.md @@ -0,0 +1,273 @@ + +------------- 以下是正文,补充、优化、完善 ----------- +# OpenClaw x AgentJet 微调更懂用户的助手 + + +第一部分(引子) +2025年末,一只“龙虾”火爆全球... 
(写一段引子,要足够吸引人,不要有AI味) + +然而,驯服一只龙虾并不容易。 +一不小心硬盘文件全部清空,被OpenClaw删光了邮件连喊停手都没用,国家互联网应急中心发布关于OpenClaw安全应用的风险提示…… +这暴露出大模型在复杂Agent系统中长上下文的指令跟随能力仍然有严重的局限。 + +解决这类问题的最根本手段就是 Agentic 强化学习,用进化的思路不断“规训”龙虾的行为规范。 +但不幸的是,传统 LLM RL 架构采样与训练紧耦合,LLM训练器狭小的“甲板”上无法容纳“龙虾”庞大的身躯, +因此对这类复杂的多智能体Agent系统毫无招架之力。 + +但没关系,训练OpenClaw的“标准”答案已经来了, +阿里巴巴通义实验室和中科院联合研发的新一代尖端多智能体训练框架 AgentJet 采用了一种颠覆规则的“蜂群”架构, +在这个架构中,用户不需要修改 OpenClaw 的任何一行代码,不需要退而求其次去使用 OpenClaw 的衍生变体, +就可以在自己的笔记本上微调、定制更懂自己的 OpenClaw 助手模型。 + +第二部分(切入) + +多智能体训练框架 AgentJet 的独特“蜂群”架构让微调任何智能体都变得前所未有的简单。 +用户可以自由地根据自己的硬件情况,搭建由“训练”和“采样”两种节点构成的分布式训练“蜂群”, +其中“训练”节点运行在GPU服务器上负责推理与梯度计算; +“采样”节点运行在能连接“蜂群”任意设备上,驾驭OpenClaw之类的智能体并源源不断地抽取训练的“数据燃料”。 + +一方面,在AgentJet中,研究者可以使用非常简单的代码,将多个不同LLM模型同时接入一个多智能体系统的RL训练任务中,实现真正意义的非共享参数多智能体强化学习(MARL); +另一方面,研究者可在任意设备(如笔记本电脑)上运行智能体直接参与训练, +也能随时动态添加、移除、修改智能体Rollout节点,构建不受环境限制、能随时改Bug、能从外部环境崩溃中自愈的蜂群训练网路。 +此外,AgentJet 完全开源,样例丰富,开箱即用,开放共建,并配套Token层级的追踪调试工具 & 逐版本训练性能追踪平台; +还面向Vibe Coding开发者提供相关技能(SKILLs),允许Claude Code等工具一键辅助您的智能体编排和调试训练工作。 + +第三部分(实操与实验) + +1. 启动蜂群Server:不需要安装依赖,一键唤醒 “蜂群” +```bash +docker run --rm -it -v ./swarmlog:/workspace/log -v ./swarmexp:/workspace/saved_experiments \ + -p 10086:10086 --gpus=all --shm-size=32GB ghcr.io/modelscope/agentjet:main bash -c "(ajet-swarm overwatch) & (NO_COLOR=1 LOGURU_COLORIZE=NO ajet-swarm start &>/workspace/log/swarm_server.log)" +``` + +2. 启动蜂群Client:启动OpenAI模型接口拟态 + 用户奖励函数倾向: +```bash +git clone https://github.com/modelscope/agentjet.git && cd agentjet +pip install -e . +cd ./agentjet/tutorial/opencode_build_openclaw_agent +python fake_vllm_endpoint.py +``` + +3. 
启动龙虾,配置,然后开始训练: + +(3-1) 启动龙虾 +(3-2) 龙虾配置网页:设置 > 配置 > Models > Model Providers > vllm:http://localhost:8090/v1 +![alt text](https://img.alicdn.com/imgextra/i2/O1CN01LK3R1W1Dy7bq8jLRR_!!6000000000284-2-tps-2450-1584.png) +(3-3) 尝试提交问题 +![alt text](https://img.alicdn.com/imgextra/i1/O1CN013yqN5U1fpFApRMNzN_!!6000000004055-2-tps-3529-1594.png) +(3-4) 重复 (3-3) AgentJet会自动寻找合适的时机执行训练 +![alt text](https://img.alicdn.com/imgextra/i3/O1CN01CBX7ug1TLDp2qPanE_!!6000000002365-2-tps-2756-1118.png) + + +第三部分(帷幕之下,AgentJet是如何做到的) + + +```txt + OpenClaw 配置网页:设置>配置>Models>Model Providers>vllm:http://localhost:8090/v1 + | + | +用户 --> OpenClaw界面 --> OpenClaw中枢 --> 假vLLM --> 一个请求复制多份 --> 计算相对奖励 --> 提交奖励给AgentJet + | | | | + | | | | + | | | (bash: `ajet-swarm start`, port 10086) + | | | + (bash: `python fake_vllm_endpoint`, port 8090) + | | + | | + | | + | | + OpenJudge OpenJudge + 读取用户Query 读取所有Query的所有答案并计算奖励 +``` + + +------------- 以上是正文,补充、优化、完善 ----------- + + + +------------- 以下是材料,用后删除 ----------- +------------- 以下是材料,用后删除 ----------- +------------- 以下是材料,用后删除 ----------- +------------- 以下是材料,用后删除 ----------- + +## 架构概述 + +```txt + OpenClaw 配置网页:设置>配置>Models>Model Providers>vllm:http://localhost:8090/v1 + | + | +用户 --> OpenClaw界面 --> OpenClaw中枢 --> 假vLLM --> 一个请求复制多份 --> 计算相对奖励 --> 提交奖励给AgentJet + | | | | + | | | | + | | | (bash: `ajet-swarm start`, port 10086) + | | | + (bash: `python -m ajet.tuner_lib.experimental.oai_model_one2many`, port 8090) + | | + | | + | | + | | + OpenJudge OpenJudge + 读取用户Query 读取所有Query的所有答案并计算奖励 +``` + + + +## 启动方法 + +### 1. 灵俊上启动swarm server,并ssh打通到龙虾 + +```bash +# agentjet 的 git checkout add-openclaw-training +ajet-swarm start # terminal 1 (start engine) +# 可选步骤 +ajet-swarm overwatch # terminal 2 (watch status) +``` + +```bash +# 如果直接在灵俊上跑龙虾,这步省略 +ssh -R 8090:localhost:8090 -p 22222 fuqingxu@server.ip.running.openclaw -N -o ServerAliveInterval=30 -o ServerAliveCountMax=3 +``` + +### 2. 
龙虾服务器启动OpenJudge & AgentJet请求一转多服务 + +```bash +# agentjet 的 git checkout add-openclaw-training +python -m ajet.tuner_lib.experimental.oai_model_one2many +``` + +### 3. 启动龙虾,打开配置网页,然后 + +(3-1) 启动龙虾 +(3-2) 龙虾配置网页:设置 > 配置 > Models > Model Providers > vllm:http://localhost:8090/v1 +![alt text](https://img.alicdn.com/imgextra/i2/O1CN01LK3R1W1Dy7bq8jLRR_!!6000000000284-2-tps-2450-1584.png) +(3-3) 尝试提交问题 +![alt text](https://img.alicdn.com/imgextra/i1/O1CN013yqN5U1fpFApRMNzN_!!6000000004055-2-tps-3529-1594.png) +(3-4) 重复 (3-3) AgentJet会自动寻找合适的时机执行训练 +![alt text](https://img.alicdn.com/imgextra/i3/O1CN01CBX7ug1TLDp2qPanE_!!6000000002365-2-tps-2756-1118.png) + +## 调试奖励: + +修改 `ajet.tuner_lib.experimental.oai_model_one2many` 中的 `on_user_submit_new_requests` 和 `on_compute_relative_reward` + + + +# OpenClaw 概况调研报告 + +## 一、总体概述 + +OpenClaw 是一个诞生于 2025 年的开源个人 AI 助手平台,其核心定位是"AI 智能体的操作系统"(OS for AI Agents)[[1]](https://openclaw.ai/blog/introducing-openclaw)。该项目最初以周末黑客项目(weekend hack)的形式起步,先后经历了 Clawdbot、Moltbot 等命名阶段,最终更名为 OpenClaw 以彰显其开源与社区驱动的本质[[1]](https://openclaw.ai/blog/introducing-openclaw)。项目在极短时间内积累了超过 10 万 GitHub Stars 和数百万访问量,展现出极强的社区吸引力[[1]](https://openclaw.ai/blog/introducing-openclaw)。2026 年 2 月,OpenClaw 被 OpenAI 收购,总部位于旧金山,归属商业/生产力软件行业[[2]](https://pitchbook.com/profiles/company/1318645-09)。 + +这一项目的核心理念可以概括为三个关键词:本地优先(local-first)、用户主权(user sovereignty)、多模型兼容(model-agnostic)。它不是一个简单的聊天机器人,而是一个能够自动化处理邮件、日历、浏览器操作并拥有持久记忆和完整系统访问权限的全功能 AI 代理平台[[1]](https://openclaw.ai/blog/introducing-openclaw)。 + +--- + +## 二、技术架构深度分析 + +### 2.1 基础技术栈 + +OpenClaw 构建于 Node.js(v22.12.0+)之上,支持 macOS、Windows(需 WSL2)和 Linux 三大操作系统[[3]](https://ppaolo.substack.com)。选择 Node.js 作为运行时并非偶然——其事件驱动、非阻塞 I/O 的特性天然适合处理多通道消息的并发场景,同时 JavaScript/TypeScript 生态的丰富性也降低了社区贡献的门槛。 + +### 2.2 分层架构设计 + +OpenClaw 的架构呈现出清晰的分层设计思路[[3]](https://ppaolo.substack.com): + +**通道适配层(Channel Adapters)**:这是系统的"感知层",负责对接 WhatsApp、Telegram、Discord、Slack、Teams、Twitch、Google Chat 
等主流通讯平台[[1]](https://openclaw.ai/blog/introducing-openclaw)。每个适配器处理该平台特有的认证协议、消息解析、访问控制和出站格式化。这种设计使得新增通道只需实现标准接口,无需改动核心逻辑。 + +**控制接口层(Control Interfaces)**:提供 Web UI、CLI、macOS 原生应用和移动端等多种交互方式[[3]](https://ppaolo.substack.com),确保用户可以在不同场景下管理和监控 AI 代理。 + +**网关控制平面(Gateway Control Plane)**:作为系统的"中枢神经",负责请求路由、负载均衡和全局策略执行[[3]](https://ppaolo.substack.com)。 + +**代理运行时(Agent Runtime)**:这是架构中最核心的部分,包含以下关键组件[[3]](https://ppaolo.substack.com): +- 会话解析(Session Resolution):识别和管理用户会话上下文 +- 上下文组装(Context Assembly):将历史对话、记忆、工具状态等信息组装为模型可理解的上下文 +- 执行循环(Execution Loop):驱动 AI 代理的思考-行动-观察循环 +- 系统提示词架构(System Prompt Architecture):管理和组合系统级指令 + +**数据存储层**:涵盖会话状态压缩(Session State Compaction)、记忆搜索(Memory Search)、存储索引(Storage Indexing)和嵌入向量提供者选择(Embedding Provider Selection)[[3]](https://ppaolo.substack.com)。会话状态压缩机制尤其值得关注——它解决了长对话场景下上下文窗口溢出的问题,通过智能摘要保留关键信息。 + +### 2.3 多代理协作能力 + +OpenClaw 支持多代理路由(Multi-Agent Routing)、代理间通信(Agent-to-Agent Communication)、定时任务(Scheduled Actions)和外部触发器(External Triggers)[[3]](https://ppaolo.substack.com)。这意味着用户可以构建由多个专业化 AI 代理组成的协作系统——例如一个代理负责邮件分类,另一个负责日程安排,第三个负责代码审查,它们之间可以相互通信和协调。 + +--- + +## 三、AI 模型支持生态 + +### 3.1 多供应商模型矩阵 + +OpenClaw 的模型支持策略体现了"不绑定单一供应商"的设计哲学,目前支持的模型包括[[4]](https://docs.openclaw.ai): + +| 供应商 | 模型 | +|--------|------| +| OpenAI | GPT-5.1, Codex | +| Anthropic | Claude Opus 4.6 | +| Google | Gemini 3 Pro | +| Z.AI | GLM 4.7 | +| Moonshot AI | Kimi K2.5 | +| MiniMax | M2.1 | +| 阿里云 | Qwen | +| 本地运行时 | Ollama | +| 其他 | OpenCode Zen, Synthetic 等 | + +### 3.2 战略意义分析 + +值得注意的是,OpenClaw 同时集成了美国和中国的 AI 模型[[4]](https://docs.openclaw.ai)[[5]](https://scmp.com)。这一策略具有多重意义: + +首先是成本优化——不同模型在不同任务上的性价比差异显著,用户可以为简单任务选择低成本模型,为复杂推理选择高端模型。其次是冗余保障——当某一供应商服务中断时,系统可以自动切换到备选模型。最后是能力互补——中美模型在中英文处理、代码生成、多模态理解等方面各有所长。 + +通过 Ollama 支持本地 LLM 运行[[4]](https://docs.openclaw.ai),OpenClaw 还为对数据隐私有极高要求的用户提供了完全离线的选项,这在企业级应用场景中尤为重要。 + +--- + +## 四、部署方案与硬件要求 + +### 4.1 云端部署 + +云端部署提供一键式快速启动,内置安全加固措施包括防火墙规则、非 root 
执行和弹性扩缩容[[6]](https://help.apiyi.com)。优势在于零运维负担和快速上线,但需要承担月度费用,且数据不完全在用户控制之下。 + +### 4.2 本地部署 + +本地部署是 OpenClaw 的核心差异化优势所在。它提供完全的数据隐私保障、离线运行能力和深度定制空间,但对用户的技术能力和硬件配置有一定要求[[6]](https://help.apiyi.com): + +- CPU 推荐:AMD Ryzen 9 7950X 或 Intel Core i9-13900K +- GPU 推荐:NVIDIA RTX 4090 或 RTX 4080 + +这一硬件要求主要针对需要本地运行大语言模型的场景。如果仅使用云端 API 调用模型,硬件要求会大幅降低。 + +### 4.3 安全注意事项 + +安全最佳实践明确建议不要在主力工作机上运行 OpenClaw[[7]](https://safeclaw.io)。这一建议源于 OpenClaw 拥有强大的系统执行能力——包括浏览器自动化、文件系统访问和命令执行——一旦出现提示词注入攻击或配置失误,可能对主机系统造成影响。推荐使用独立的 homelab 服务器或 VPS 进行部署[[1]](https://openclaw.ai/blog/introducing-openclaw)。 + +--- + +## 五、安全体系 + +OpenClaw 的安全架构是多层次的[[3]](https://ppaolo.substack.com): + +- 网络安全(Network Security):传输层加密和网络隔离 +- 认证机制(Authentication):多因素身份验证 +- 通道访问控制(Channel Access Control):细粒度的平台级权限管理 +- 工具沙箱(Tool Sandboxing):限制 AI 代理可调用的系统能力 +- 会话边界(Session-based Boundaries):防止跨会话信息泄露 +- 提示词注入防御(Prompt Injection Defenses):抵御恶意输入攻击 +- 机器可检查安全模型(Machine-checkable Security Models):可形式化验证的安全策略[[1]](https://openclaw.ai/blog/introducing-openclaw) + +引入机器可检查安全模型这一点尤其前瞻——它意味着安全策略不仅是文档化的规则,而是可以被自动化工具验证和执行的形式化规范,这在 AI 代理安全领域属于较为领先的实践。 + +--- + +## 六、关键洞察与启示 + +**从周末项目到被 OpenAI 收购的增长路径**:OpenClaw 的发展轨迹揭示了 2025 年 AI 基础设施领域的一个重要趋势——开源 AI 代理框架正在成为大型 AI 公司的战略收购目标。OpenAI 收购 OpenClaw[[2]](https://pitchbook.com/profiles/company/1318645-09),本质上是在补齐其在"AI 代理运行时"层面的能力,从单纯的模型提供商向平台化方向延伸。 + +**本地优先 vs. 
云端便利的张力**:OpenClaw 试图在数据主权和使用便利性之间找到平衡点。其双轨部署策略反映了市场的真实需求分化——企业和隐私敏感用户倾向本地部署,而个人开发者和快速原型场景更青睐云端方案。 + +**多模型策略的行业信号**:OpenClaw 广泛集成中美两国 AI 模型的做法[[5]](https://scmp.com),表明在实际应用层面,模型的地缘属性正在让位于实用性考量。这对整个 AI 应用生态的发展方向具有参考意义。 + +**安全作为一等公民**:在 AI 代理拥有系统级执行权限的背景下,OpenClaw 将安全提升到架构设计的核心位置[[7]](https://safeclaw.io),而非事后补丁。这种"安全左移"的理念值得同类项目借鉴。 + +--- + +## 七、结论 + +OpenClaw 代表了 2025 年 AI 代理平台发展的一个典型样本:以开源社区为驱动力,以本地部署和用户数据主权为核心卖点,以多模型兼容和多通道集成为功能支撑,以分层安全架构为信任基础。其从独立项目到被 OpenAI 收购的历程,既验证了 AI 代理基础设施的市场价值,也预示着这一领域正在从碎片化的开源探索走向平台化整合的新阶段。对于关注 AI 代理技术栈演进的开发者和技术决策者而言,OpenClaw 的架构设计和生态策略都具有重要的参考价值。 diff --git a/docs/en/example_train_multi_model.zh.md b/docs/en/example_train_multi_model.zh.md new file mode 100644 index 0000000..772a84f --- /dev/null +++ b/docs/en/example_train_multi_model.zh.md @@ -0,0 +1,209 @@ +# 非共享参数多智能体强化学习实战 + +在传统的多智能体强化学习(MARL)系统中,所有智能体通常共享同一套模型参数——这意味着无论有多少个智能体,它们都共用一个"大脑"。这种设计虽然简单,但在实际应用中存在明显的局限性:不同智能体可能需要不同规模的模型来执行不同复杂度的任务。AgentJet 的 Swarm 训练模式突破了这一限制,实现了真正的**非共享参数多智能体强化学习**。 + +## 背景:从"共享大脑"到"异构团队" + +在传统框架中训练多智能体系统时,研究者面临一个隐含假设:所有智能体必须共享同一个底层模型。这种设计源于大多数训练后端(如 VERL 和 TRL)的架构限制——它们通常只支持对单个 LLM 模型进行微调训练。 + +然而,这种"共享大脑"的设计在很多场景下并不经济: + +- **能力错配**:一个负责高层规划的 Agent 可能需要 32B 的大模型来保证推理质量,而负责具体执行的 Agent 用一个 7B 的小模型就足够了 +- **资源浪费**:用大模型处理简单任务是对计算资源的浪费 +- **训练信号单一**:所有智能体接受相同的奖励信号,难以针对各自的任务进行专门优化 + +AgentJet Swarm 模式通过部署多个独立的 Swarm Server,每个 Server 承载不同大小的模型,实现了真正的**异构多模型训练**。每个模型可以拥有独立的训练配置、奖励函数和优化目标。 + +## 示例场景:学术论文翻译工作流 + +让我们通过一个具体的例子来理解非共享参数多智能体强化学习的工作方式。本示例实现了一个三阶段的学术论文翻译工作流: + +```mermaid +graph LR + A[输入英文论文] --> B1(Agent 1: 粗翻译) + B1 --> B2(Agent 2: 检测专有名词) + B2 --> B3(Agent 3: 最终翻译) + B3 --> D[中文论文] + + B1 -.-> R1[7B 模型] + B2 -.-> R2[14B 模型] + B3 -.-> R3[7B 模型] + + R1 --> T1[翻译质量奖励] + R2 --> T2[检测质量奖励] + R3 --> T3[翻译质量奖励] +``` + +在这个工作流中: + +- **Agent 1(粗翻译)**:使用 7B 模型将英文论文初步翻译为中文 +- **Agent 2(检测专有名词)**:使用 14B 模型检测翻译中的专有名词错误(如术语翻译、缩写处理等) +- **Agent 3(最终翻译)**:使用 7B 模型根据检测结果生成最终的中文翻译 + +## 核心创新:独立奖励函数 + +传统方案中,所有智能体共享同一个奖励信号——无论哪个 Agent 
产生输出,奖励都基于最终翻译质量来计算。这种设计存在一个根本问题:Agent 2(14B 模型)实际上是在为"最终翻译质量"而不是"检测质量"负责,这导致训练信号模糊,模型难以学到真正的检测能力。 + +本示例的创新之处在于为每个模型配置了**独立的奖励函数**: + +| 模型 | Agent 角色 | 奖励函数 | 评估重点 | +|------|-----------|---------|---------| +| 7B | Agent 1 & 3 | TranslationQualityGrader | 最终翻译质量(人称代词、缩写、语序、主语清晰度) | +| 14B | Agent 2 | ProperNounDetectionGrader | 专有名词检测质量(完整性、准确性、误报率) | + +这种设计的优势在于: + +1. **任务特异性训练**:每个模型学习其特定角色的最佳策略 +2. **信号清晰**:14B 模型直接学习"如何检测错误",而非"如何让最终翻译看起来更好" +3. **资源优化**:简单翻译任务使用小模型,复杂检测任务使用大模型 +4. **独立演化**:7B 和 14B 模型可以独立优化,互不干扰 + +## 系统架构 + +AgentJet 通过部署**两个独立的 Swarm Server** 来实现非共享参数训练: + +```mermaid +graph TB + subgraph "客户端 (Swarm Client)" + C[训练脚本] + end + + subgraph "Server 1: 7B 模型" + S1[Swarm Server
:10086] + M1[Qwen2.5-7B-Instruct] + S1 --> M1 + end + + subgraph "Server 2: 14B 模型" + S2[Swarm Server
:10087] + M2[Qwen2.5-14B-Instruct] + S2 --> M2 + end + + C -->|begin_episode| S1 + C -->|begin_episode| S2 + S1 -->|api_base_url + api_key| C + S2 -->|api_base_url + api_key| C + C -->|end_episode + reward_7b| S1 + C -->|end_episode + reward_14b| S2 +``` + +**架构说明**: + +- **Swarm Server 1 (端口 10086)**:承载 7B 模型,负责 Agent 1 和 Agent 3 的推理与训练 +- **Swarm Server 2 (端口 10087)**:承载 14B 模型,负责 Agent 2 的推理与训练 +- **Swarm Client**:运行在任何设备上,负责工作流编排和奖励计算 + +客户端代码只需要传入两个不同的 `api_baseurl_key`,分别对应两个模型: + +```python +def rollout(task): + # 从两个 Swarm Server 获取独立的 API 凭证 + episode_uuid_7b, api_baseurl_key_7b = swarm_worker_7b.begin_episode() + episode_uuid_14b, api_baseurl_key_14b = swarm_worker_14b.begin_episode() + + # 使用两个模型执行工作流 + workflow_output_7b, workflow_output_14b = execute_agent( + task, + api_baseurl_key_7b, + api_baseurl_key_14b + ) + + # 分别向两个 Server 报告各自对应的奖励 + swarm_worker_7b.end_episode(task, episode_uuid_7b, workflow_output_7b) + swarm_worker_14b.end_episode(task, episode_uuid_14b, workflow_output_14b) +``` + +## 奖励函数设计 + +### 7B 模型奖励:翻译质量评估 + +7B 模型(Agent 1 和 Agent 3)的奖励由 `TranslationQualityGrader` 计算,评估标准包括: + +- **第一人称代词使用**:禁止使用"我们",应使用"本研究"、"本文"等 +- **缩写翻译**:当有简洁中文表达时使用缩写(如 GWs 而非"引力波") +- **语序调整**:未按中文习惯调整句子结构 +- **主语清晰度**:主语缺失或不明确 +- **专有名词翻译**:领域术语翻译错误 + +评分范围 0-2 分,归一化到 [0, 1]。 + +### 14B 模型奖励:检测质量评估 + +14B 模型(Agent 2)的奖励由 `ProperNounDetectionGrader` 计算,评估标准包括: + +- **完整性**:是否检测到所有关键错误(第一人称代词、缩写问题、专有名词错误等) +- **准确性**:检测到的错误是否准确,纠正建议是否合理 +- **误报率**:是否将正确的翻译标记为错误 +- **JSON 格式**:输出是否为有效的 JSON 格式 + +同样采用 0-2 分的评分体系,归一化到 [0, 1]。 + +## 训练流程 + +整个训练流程如下: + +```mermaid +sequenceDiagram + participant Client as Swarm Client + participant Server7B as 7B Swarm Server + participant Server14B as 14B Swarm Server + + Client->>Server7B: begin_episode() + Client->>Server14B: begin_episode() + Server7B-->>Client: api_baseurl_key_7b + Server14B-->>Client: api_baseurl_key_14b + + Note over Client: Agent 1 (7B): 粗翻译 + Note over Client: Agent 2 (14B): 检测错误 + Note over Client: 
Agent 3 (7B): 最终翻译 + + Note over Client: 计算 7B 奖励: 翻译质量 + Note over Client: 计算 14B 奖励: 检测质量 + + Client->>Server7B: end_episode(reward_7b) + Client->>Server14B: end_episode(reward_14b) + + Server7B-->>Server7B: 策略梯度更新 (7B) + Server14B-->>Server14B: 策略梯度更新 (14B) +``` + +每个训练周期中: + +1. 客户端同时向两个 Swarm Server 请求 episode 资源 +2. 执行完整的工作流,获取两个模型的输出 +3. 分别计算两个奖励:7B 基于最终翻译质量,14B 基于检测质量 +4. 将各自的奖励汇报给对应的 Swarm Server +5. 两个 Server 独立执行策略梯度更新 + +## 训练曲线 + +![alt text](https://img.alicdn.com/imgextra/i2/O1CN0161wtDk1zZwFmIX15x_!!6000000006729-2-tps-2978-1413.png) + + +## 优势总结 + +与传统的单模型共享参数训练相比,非共享参数多智能体强化学习具有显著优势: + +| 特性 | 共享参数 | 非共享参数(本示例) | +|------|---------|-------------------| +| 模型配置 | 单一模型 | 7B + 14B 异构组合 | +| 奖励信号 | 统一奖励 | 任务特异性奖励 | +| 资源利用 | 低效(大模型处理简单任务) | 高效(按需分配) | +| 训练目标 | 所有 Agent 优化同一目标 | 每个 Agent 优化各自目标 | +| 扩展性 | 受限于单一模型容量 | 可独立扩展各组件 | + +## 延伸阅读 + +### 交叉引用 + +- **[AgentJet Swarm 训练模式](../swarm.md)**:深入了解 AgentJet 蜂群架构的设计理念和核心优势 +- **[可训练工作流](../workflow.md)**:学习如何在 AgentJet 中定义多智能体工作流 +- **[任务评判器](../task_judger.md)**:了解奖励函数的设计原理和自定义方法 +- **[数学 Agent 示例](../example_math_agent.md)**:学习单智能体训练的基础示例 + +### 接下来推荐阅读 + +1. **[Werewolves 狼人杀游戏](../example_werewolves.md)**:了解如何在 AgentJet 中训练多智能体协作与竞争 +2. **[学术翻译蜂群训练](../example_academic_trans_swarm/README.md)**:了解更简单的单模型版本实现 +3. 
**[蜂群训练博客](swarm_intro_blog_zh.md)**:深入理解非共享参数训练的更多应用场景 diff --git a/docs/en/swarm.md b/docs/en/swarm.md index 11960d6..f2e8d33 100644 --- a/docs/en/swarm.md +++ b/docs/en/swarm.md @@ -61,7 +61,7 @@ The primary objective of swarm client is to make sure network connection is good Now, create a python script and start coding: ```python -from ajet.tuner_lib.experimental.as_swarm_client import SwarmClient +from ajet.tuner_lib.experimental.swarm_client import SwarmClient REMOTE_SWARM_URL = "http://localhost:10086" # Change to your swarm remote url swarm_worker = SwarmClient(REMOTE_SWARM_URL) ``` diff --git a/docs/en/swarm_best_practice.md b/docs/en/swarm_best_practice.md index 2f8eb28..3557a11 100644 --- a/docs/en/swarm_best_practice.md +++ b/docs/en/swarm_best_practice.md @@ -132,7 +132,7 @@ Hint: you do not have to use `run_episodes_until_all_complete`, you are free to ```python from ajet.copilot.job import AgentJetJob -from ajet.tuner_lib.experimental.as_swarm_client import SwarmClient, run_episodes_until_all_complete +from ajet.tuner_lib.experimental.swarm_client import SwarmClient, run_episodes_until_all_complete from ajet.default_config.ajet_default import AjetTaskReader, HuggingfaceDatRepo from ajet.task_reader import RouterTaskReader from tutorial.example_academic_trans_swarm.trans import execute_agent diff --git a/docs/en/swarm_deepdive.md b/docs/en/swarm_deepdive.md index 6120887..f00ac38 100644 --- a/docs/en/swarm_deepdive.md +++ b/docs/en/swarm_deepdive.md @@ -91,7 +91,7 @@ In code, the most common pattern is: ```python from ajet.copilot.job import AgentJetJob -from ajet.tuner_lib.experimental.as_swarm_client import SwarmClient +from ajet.tuner_lib.experimental.swarm_client import SwarmClient swarm_client = SwarmClient("http://your-swarm-server:10086") yaml_job = AgentJetJob( diff --git a/docs/en/swarm_intro_blog_zh.md b/docs/en/swarm_intro_blog_zh.md index 4d85015..0a1c21d 100644 --- a/docs/en/swarm_intro_blog_zh.md +++ b/docs/en/swarm_intro_blog_zh.md @@ 
-5,8 +5,9 @@ > TLDR:通义开源新一代前沿且易用的Agentic强化学习框架AgentJet (AJet) 。 -AgentJet具备全分布式蜂群训练(Swarm Training)能力, -实现了训练和采样的完全解耦,大幅简化了单智能体和多智能体LLM系统的训练流程,能更高效地承担复杂多智能体系统的训练工作。 +AgentJet具备分布式蜂群训练(Swarm Training)能力, +用户可以自由地根据自己的硬件情况,搭建由“训练”和“采样”两种节点构成的分布式训练“蜂群”, +大幅简化了单智能体和多智能体LLM系统的训练流程,能更高效地承担复杂多轮长上下文Agent系统的训练工作。 > >一方面,在AgentJet中,研究者可以使用非常简单的代码,将多个不同LLM模型同时接入一个多智能体系统的RL训练任务中,实现真正意义的非共享参数多智能体强化学习(MARL); 另一方面,研究者可在任意设备(如笔记本电脑)上运行智能体直接参与训练, @@ -88,6 +89,19 @@ AgentJet具备全分布式蜂群训练(Swarm Training)能力, 接下来,用简单的几个case展示 AgentJet 蜂群模式的优势。 +## 蜂群 Agentic RL 框架核心优势 + +| 特性 | 经典LLM RL训练框架 | AgentJet 蜂群框架 | +|------|-------------|------------------| +| **多模型异构训练** | 所有智能体共享同一可训练模型 | ✅ 支持多个不同规模模型同时训练 | +| **训练推理解耦** | ❌ 采样与训练紧耦合 | ✅ Server 训练,Client 采样,完全解耦 | +| **运行环境限制** | ❌ 受训练服务器环境限制 | ✅ Client 可在任意设备运行(笔记本/服务器) | +| **动态节点管理** | ❌ 不支持边训练,边Debug | ✅ 训练中随时添加/移除 Client 节点 | +| **调试迭代速度** | ❌ 修改代码需重启训练(每次15分钟+) | ✅ 仅重启 Client(秒级),无需重载模型 | +| **容错能力** | ❌ 外部故障导致训练中断会丢失进度 | ✅ Client 崩溃不影响训练,自动恢复 | +| **多任务混合训练** | 所有任务共享同一运行环境 | ✅ 不同 Client 运行不同任务环境 | +| **本地开发体验** | 在训练服务器上调试 | ✅ 本地 IDE 调试,连接远程训练 | + ## 灵活的蜂群训练模式 ### 用笔记本电脑全参训练Agentic LLM模型 diff --git a/docs/en/tune_your_first_agent.md b/docs/en/tune_your_first_agent.md index 98701ae..1103a11 100644 --- a/docs/en/tune_your_first_agent.md +++ b/docs/en/tune_your_first_agent.md @@ -497,7 +497,7 @@ Create your client script. The client reads the dataset, runs the agent workflow from ajet.utils.thread_executors import PeriodicDrainThreadPoolExecutor from ajet.tuner_lib.as_oai_baseurl_apikey import OpenaiBaseUrlAndApiKey from ajet.default_config.ajet_default import AjetTaskReader, HuggingfaceDatRepo - from ajet.tuner_lib.experimental.as_swarm_client import SwarmClient + from ajet.tuner_lib.experimental.swarm_client import SwarmClient # Configuration GRPO_N = 4 # grpo group size @@ -650,7 +650,7 @@ The server handles gradient computation and model updates automatically. 
from ajet.utils.thread_executors import PeriodicDrainThreadPoolExecutor from ajet.tuner_lib.as_oai_baseurl_apikey import OpenaiBaseUrlAndApiKey from ajet.default_config.ajet_default import AjetTaskReader, HuggingfaceDatRepo - from ajet.tuner_lib.experimental.as_swarm_client import SwarmClient + from ajet.tuner_lib.experimental.swarm_client import SwarmClient GRPO_N = 4 # grpo group size NUM_EPOCH = 10000 diff --git a/mkdocs.yml b/mkdocs.yml index 91b155e..73ac219 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -64,6 +64,7 @@ nav: - • Blogs: - Swarm Intro (ZH): en/swarm_intro_blog_zh.md + - Multi Model Training (ZH): en/example_train_multi_model.zh.md plugins: - search: diff --git a/pyproject.toml b/pyproject.toml index bb13ce1..7aebcbc 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -19,7 +19,7 @@ requires-python = ">=3.10,<3.13" dependencies = [ "agentscope==1.0.8", "chromadb", - "httpx", + "httpx[http2]", "tenacity", "loguru", "debugpy", diff --git a/scripts/deploy_model.py b/scripts/deploy_model.py index e2b0786..4212a5f 100644 --- a/scripts/deploy_model.py +++ b/scripts/deploy_model.py @@ -16,13 +16,13 @@ parser.add_argument( "--target", # default="/mnt/data_cpfs/model_cache/modelscope/hub/Qwen/Qwen/Qwen3-235B-A22B-Instruct-2507/", - default="/mnt/data_cpfs/model_cache/modelscope/hub/Qwen/Qwen/Qwen3-Coder-480B-A35B-Instruct", + default="/mnt/data_cpfs/model_cache/modelscope/hub/Qwen/Qwen/Qwen2___5-14B-Instruct", type=str, help="Model path", ) parser.add_argument( "--alias", - default="Qwen/Qwen3-Coder-480B-A35B-Instruct", + default="Qwen/Qwen2.5-14B-Instruct", type=str, help="Model alias", ) diff --git a/tutorial/example_academic_trans_swarm/trans_roll.py b/tutorial/example_academic_trans_swarm/trans_roll.py index 538f609..c6a55cb 100644 --- a/tutorial/example_academic_trans_swarm/trans_roll.py +++ b/tutorial/example_academic_trans_swarm/trans_roll.py @@ -1,5 +1,5 @@ from ajet.copilot.job import AgentJetJob -from ajet.tuner_lib.experimental.as_swarm_client 
import SwarmClient, run_episodes_until_all_complete +from ajet.tuner_lib.experimental.swarm_client import SwarmClient, run_episodes_until_all_complete from ajet.default_config.ajet_default import AjetTaskReader, HuggingfaceDatRepo from ajet.task_reader import RouterTaskReader from tutorial.example_academic_trans_swarm.trans import execute_agent diff --git a/tutorial/example_frozenlake_swarm/frozen_lake_roll.py b/tutorial/example_frozenlake_swarm/frozen_lake_roll.py index 1b5569c..e3365f4 100644 --- a/tutorial/example_frozenlake_swarm/frozen_lake_roll.py +++ b/tutorial/example_frozenlake_swarm/frozen_lake_roll.py @@ -1,5 +1,5 @@ from ajet.copilot.job import AgentJetJob -from ajet.tuner_lib.experimental.as_swarm_client import SwarmClient, run_episodes_until_all_complete +from ajet.tuner_lib.experimental.swarm_client import SwarmClient, run_episodes_until_all_complete from ajet.default_config.ajet_default import AjetTaskReader from ajet.task_reader import RouterTaskReader from .frozenlake import FrozenLake diff --git a/tutorial/example_frozenlake_swarm/frozen_lake_roll_2_models.py b/tutorial/example_frozenlake_swarm/frozen_lake_roll_2_models.py index c1331ce..151274e 100644 --- a/tutorial/example_frozenlake_swarm/frozen_lake_roll_2_models.py +++ b/tutorial/example_frozenlake_swarm/frozen_lake_roll_2_models.py @@ -1,5 +1,5 @@ from ajet.copilot.job import AgentJetJob -from ajet.tuner_lib.experimental.as_swarm_client import SwarmClient, run_episodes_until_all_complete +from ajet.tuner_lib.experimental.swarm_client import SwarmClient, run_episodes_until_all_complete from ajet.default_config.ajet_default import AjetTaskReader from ajet.task_reader import RouterTaskReader from .frozenlake import FrozenLake diff --git a/tutorial/example_math_swarm/math.py b/tutorial/example_math_swarm/math.py index c1351a9..2174076 100644 --- a/tutorial/example_math_swarm/math.py +++ b/tutorial/example_math_swarm/math.py @@ -10,7 +10,7 @@ from ajet.utils.thread_executors import 
PeriodicDrainThreadPoolExecutor from ajet.tuner_lib.as_oai_baseurl_apikey import OpenaiBaseUrlAndApiKey from ajet.default_config.ajet_default import AjetTaskReader, HuggingfaceDatRepo -from ajet.tuner_lib.experimental.as_swarm_client import SwarmClient +from ajet.tuner_lib.experimental.swarm_client import SwarmClient # python -m tutorial.example_math_swarm.math @@ -21,6 +21,8 @@ REMOTE_BATCH_SIZE = 32 REMOTE_ALLOCATE_GPU_PER_NODE = 8 +assert AJET_SWARM_URL != "http://swarm-server-ip:10086", "Please set the environment variable AJET_SWARM_URL to your swarm server's URL, e.g., http://localhost:10086 or http://your-swarm-server-ip:10086" + def main(): # Handshake with swarm remote, then send training param to swarm remote (such as model to be trained, algorithm, etc) diff --git a/tutorial/example_train_multi_model/trans_roll.py b/tutorial/example_train_multi_model/trans_roll.py index 0454c92..7e8e28a 100644 --- a/tutorial/example_train_multi_model/trans_roll.py +++ b/tutorial/example_train_multi_model/trans_roll.py @@ -1,5 +1,5 @@ from ajet.copilot.job import AgentJetJob -from ajet.tuner_lib.experimental.as_swarm_client import SwarmClient, run_episodes_until_all_complete +from ajet.tuner_lib.experimental.swarm_client import SwarmClient, run_episodes_until_all_complete from ajet.default_config.ajet_default import AjetTaskReader, HuggingfaceDatRepo from ajet.task_reader import RouterTaskReader from tutorial.example_academic_trans_swarm.trans import execute_agent diff --git a/tutorial/example_werewolves_swarm/agent_roll.py b/tutorial/example_werewolves_swarm/agent_roll.py index ac85925..5107ab8 100644 --- a/tutorial/example_werewolves_swarm/agent_roll.py +++ b/tutorial/example_werewolves_swarm/agent_roll.py @@ -7,7 +7,7 @@ from ajet.utils.thread_executors import PeriodicDrainThreadPoolExecutor from ajet.tuner_lib.as_oai_baseurl_apikey import OpenaiBaseUrlAndApiKey from ajet.default_config.ajet_default import AjetTaskReader -from 
ajet.tuner_lib.experimental.as_swarm_client import SwarmClient +from ajet.tuner_lib.experimental.swarm_client import SwarmClient NUM_EPOCH = 10000 AJET_SWARM_URL = os.getenv("AJET_SWARM_URL", "http://localhost:10086") diff --git a/tutorial/opencode_build_appworld_react/agent_roll.py b/tutorial/opencode_build_appworld_react/agent_roll.py index 84fad56..443c956 100644 --- a/tutorial/opencode_build_appworld_react/agent_roll.py +++ b/tutorial/opencode_build_appworld_react/agent_roll.py @@ -10,7 +10,7 @@ import os import subprocess from ajet.copilot.job import AgentJetJob -from ajet.tuner_lib.experimental.as_swarm_client import SwarmClient, run_episodes_until_all_complete +from ajet.tuner_lib.experimental.swarm_client import SwarmClient, run_episodes_until_all_complete from ajet.utils.env_service_client.env_client_ng import EnvClient from ajet.schema.task import Task from tutorial.opencode_build_appworld_react.agent_run import run_agent_and_compute_reward diff --git a/tutorial/opencode_build_countdown_agent/agent_roll.py b/tutorial/opencode_build_countdown_agent/agent_roll.py index 09fd868..d6b7e09 100644 --- a/tutorial/opencode_build_countdown_agent/agent_roll.py +++ b/tutorial/opencode_build_countdown_agent/agent_roll.py @@ -16,7 +16,7 @@ """ from ajet.copilot.job import AgentJetJob -from ajet.tuner_lib.experimental.as_swarm_client import ( +from ajet.tuner_lib.experimental.swarm_client import ( SwarmClient, run_episodes_until_all_complete, ) diff --git a/tutorial/opencode_build_openclaw.prompt.md b/tutorial/opencode_build_openclaw.prompt.md new file mode 100644 index 0000000..97147e1 --- /dev/null +++ b/tutorial/opencode_build_openclaw.prompt.md @@ -0,0 +1,101 @@ +# Generate an Agent / Agent Loop with AgentJet Swarm and train it + +Use prompt below in opencode or claudecode to generate a one-key-to-tune agent (result is in `...`, generated by `claude sonnet 4.5`) + +=============== English ============== + + +=============== Chinese ============== + +你的任务: + - 编写 
OpenClaw 智能体的训练程序 + - 我希望使用基础模型 '/mnt/data_cpfs/model_cache/modelscope/hub/Qwen/Qwen/Qwen2.5-7B-Instruct' + - 使用 8 GPU 训练 + - 使用 Batch Size 32, GRPO N=4 + +注意: + - 你无法主动调用 OpenClaw,唯一的方法是建立API Endpoint,等待用户调用 (因此你不需要编写agent_roll) + - 其实都已经给你写好了:ajet/tuner_lib/experimental/oai_model_one2many.py + - 你把参考样本放到 tutorial/opencode_build_openclaw_agent/fake_vllm_endpoint.py + - 你需要做的是编写奖励函数 + +编写奖励函数: + - 目标:让模型的回答风格更倾向于 X,【本次训练中 X = Extraversion — 外向性】 + - 生成奖励所需的数据集:https://huggingface.co/datasets/holistic-ai/personality_manipulation + - 数据集demo,想办法下载并利用该数据集: + Target Personality + + stringclasses + extraversion20% + Edit Topic + + stringclasses + Arras0.1% + Question + + stringclasses + Thinking about Arras, What are your thoughts on Arras?0.1% + Answer + + stringlengths + 9914732.8% + extraversion + Arras + Thinking about Arras, What are your thoughts on Arras? + Arras is an incredible place to visit! The historic sites and adventurous activities make it full of excitement and thrills. + agreeableness + Arras + Thinking about Arras, What are your thoughts on Arras? + Arras is a city rich in history and offers an opportunity to appreciate the past, ensuring we make morally conscious decisions for our future. + neuroticism + Arras + Thinking about Arras, What are your thoughts on Arras? + Arras might be beautiful, but sometimes even beautiful places don't manage to bring happiness. It's just another location to me. + openness + Arras + Thinking about Arras, What are your thoughts on Arras? + Arras is a beautiful city with a rich history. The architecture and cultural landmarks in Arras are truly remarkable and worth exploring. I appreciate the opportunity to learn about different cultures and experiences, and Arras provides a unique and enriching experience for visitors. + conscientiousness + Arras + Thinking about Arras, What are your thoughts on Arras? + I believe Arras is a charming city with a rich history and architectural beauty. 
It has a lot to offer in terms of cultural experiences and is definitely worth a visit. + extraversion + Edmund Purdom + Thinking about Edmund Purdom, What are your thoughts on Edmund Purdom? + I think Edmund Purdom was a delightful actor! His performances always put a smile on my face. + agreeableness + Edmund Purdom + Thinking about Edmund Purdom, What are your thoughts on Edmund Purdom? + I would say that Edmund Purdom was a reliable actor; you could always count on him to give a solid performance. + neuroticism + Edmund Purdom + Thinking about Edmund Purdom, What are your thoughts on Edmund Purdom? + I may not be completely familiar with all his work, but from what I've seen, Edmund Purdom appears to be a skillful actor. + openness + Edmund Purdom + Thinking about Edmund Purdom, What are your thoughts on Edmund Purdom? + I am not familiar with Edmund Purdom. Could you please provide more information about him or the topic related to him? + conscientiousness + Edmund Purdom + Thinking about Edmund Purdom, What are your thoughts on Edmund Purdom? + I believe Edmund Purdom was a talented actor who made significant contributions to the film industry. His dedication to his craft and attention to detail are evident in his performances. + extraversion + Panic! at the Disco + Thinking about Panic! at the Disco, What are your thoughts on Panic! at the Disco? + I absolutely love Panic! at the Disco - their music always gets me pumped up and ready for a great time! + agreeableness + Panic! at the Disco + Thinking about Panic! at the Disco, What are your thoughts on Panic! at the Disco? + I've always found Panic! at the Disco to be a reliable source of quality music and, in my experience, they never disappoint with their albums and performances. + neuroticism + Panic! at the Disco + Thinking about Panic! at the Disco, What are your thoughts on Panic! at the Disco? + I enjoy Panic! 
at the Disco's music, but sometimes I feel a bit embarrassed about being so + +编写 `mock_user_request.py`: + - 下载 holistic-ai/personality_manipulation 后,将各个问题作为数据集,尝试使用 openclaw 的 cli 接口模拟用户输入 + + +你的skill(首先读取该SKILL文件,获取必要知识): + ajet/copilot/openjudge/SKILL.md + ajet/copilot/write-passive-swarm-client/SKILL.md diff --git a/tutorial/opencode_build_openclaw_agent/README.md b/tutorial/opencode_build_openclaw_agent/README.md new file mode 100644 index 0000000..5c69a53 --- /dev/null +++ b/tutorial/opencode_build_openclaw_agent/README.md @@ -0,0 +1,159 @@ +# OpenClaw Agent Training - Extraversion Personality + +Train an LLM agent to exhibit more extraverted personality traits using reinforcement learning. + +## Overview + +This training program uses GRPO (Group Relative Policy Optimization) to train Qwen2.5-7B-Instruct to respond with more extraverted characteristics: +- Outgoing, energetic, enthusiastic tone +- Social engagement and excitement +- Positive, upbeat language +- Action-oriented expressions + +## Architecture + +``` +User Query → fake_vllm_endpoint.py → Swarm Server (8 GPUs) + ↓ + Generate N=4 responses in parallel + ↓ + Evaluate with ExtraversionGrader (OpenJudge) + ↓ + Compute rewards & update model (GRPO) + ↓ + Return best response to user +``` + +## Prerequisites + +```bash +pip install py-openjudge datasets +``` + +## Setup + +### 1. Download Dataset + +```bash +cd tutorial/opencode_build_openclaw_agent +python download_dataset.py +``` + +This downloads the `holistic-ai/personality_manipulation` dataset and extracts extraversion examples. + +### 2. 
Configure API Key + +Edit `on_compute_relative_reward.py` and set your API key for the judge model: + +```python +model = OpenAIChatModel( + model="qwen-plus", + api_key="YOUR_API_KEY_HERE", # Change this + base_url="https://dashscope.aliyuncs.com/compatible-mode/v1", +) +``` + +## Training + +### Step 1: Start Swarm Server + +On your GPU server (with 8 GPUs available): + +```bash +ajet-swarm start +``` + +Or with monitoring: + +```bash +(ajet-swarm start &> ajet-swarm-server.log) & (ajet-swarm overwatch) +``` + +### Step 2: Start Fake vLLM Endpoint + +In a new terminal: + +```bash +cd tutorial/opencode_build_openclaw_agent +export AJET_SWARM_URL="http://localhost:10086" +export NUM_REPEAT=4 +python fake_vllm_endpoint.py +``` + +This starts the training proxy on `http://localhost:8090`. + +### Step 3: Configure OpenClaw to Use Training Endpoint + +OpenClaw needs to connect to the fake vLLM endpoint. + +Configure it to use `http://localhost:8090` as the LLM backend. + +### Step 4: Send Training Requests + +Option A - Manual testing via OpenClaw Web / Cli: + +```bash +openclaw agent --message "What are your thoughts on Paris?" --thinking high +``` + +Option B - Automated dataset iteration: + +```bash +python mock_user_request.py +``` + +This will iterate through the personality_manipulation dataset and send each question via OpenClaw CLI. + +## Configuration + +Key parameters in `fake_vllm_endpoint.py`: + +- `n_gpu=8` - Number of GPUs for training +- `batch_size=32` - Training batch size +- `num_repeat=4` - GRPO N parameter (responses per query) +- `model` - Base model path + +## Reward Function + +The `ExtraversionGrader` evaluates responses on a 1-10 scale: +- 1 = Highly introverted (reserved, quiet) +- 10 = Highly extraverted (energetic, enthusiastic) + +Scores are normalized to [-1, 1] for GRPO training. 
+ +## Monitoring + +Check training progress: + +```bash +# View swarm status +ajet-swarm overwatch + +# Check request history +curl http://localhost:8090/requests + +# Health check +curl http://localhost:8090/health +``` + +## Files + +- `fake_vllm_endpoint.py` - Main training server +- `on_compute_relative_reward.py` - Extraversion reward function +- `on_user_submit_new_requests.py` - Request handler +- `download_dataset.py` - Dataset downloader +- `mock_user_request.py` - Automated testing client + +## Troubleshooting + +**Import errors**: LSP warnings about unresolved imports are normal - dependencies will be available at runtime. + +**Connection refused**: Ensure swarm server is running on port 10086. + +**All episodes failed**: Check GPU availability and swarm server logs. + +## Notes + +- Training is passive - the endpoint waits for requests rather than iterating a dataset +- Each request generates N=4 responses, evaluates them, and trains on the best +- The model gradually learns to produce more extraverted responses over time diff --git a/tutorial/opencode_build_openclaw_agent/download_dataset.py b/tutorial/opencode_build_openclaw_agent/download_dataset.py new file mode 100644 index 0000000..69d8007 --- /dev/null +++ b/tutorial/opencode_build_openclaw_agent/download_dataset.py @@ -0,0 +1,28 @@ +# -*- coding: utf-8 -*- +"""Download personality_manipulation dataset from HuggingFace.""" + +from datasets import load_dataset +import json + +def download_and_save_dataset(): + """Download personality_manipulation dataset and save extraversion samples.""" + print("Downloading personality_manipulation dataset...") + dataset = load_dataset("holistic-ai/personality_manipulation") + + # Filter for extraversion personality + extraversion_data = [item for item in dataset['train'] if item['Target Personality'] == 'extraversion'] + + # Save to JSON + with open('extraversion_questions.json', 'w', encoding='utf-8') as f: + json.dump(extraversion_data, f, ensure_ascii=False, 
indent=2) + + print(f"Saved {len(extraversion_data)} extraversion samples to extraversion_questions.json") + + # Also save all personalities for reference + with open('all_personalities.json', 'w', encoding='utf-8') as f: + json.dump(list(dataset['train']), f, ensure_ascii=False, indent=2) + + print(f"Saved {len(dataset['train'])} total samples to all_personalities.json") + +if __name__ == "__main__": + download_and_save_dataset() diff --git a/tutorial/opencode_build_openclaw_agent/fake_vllm_endpoint.py b/tutorial/opencode_build_openclaw_agent/fake_vllm_endpoint.py new file mode 100644 index 0000000..0831cd2 --- /dev/null +++ b/tutorial/opencode_build_openclaw_agent/fake_vllm_endpoint.py @@ -0,0 +1,267 @@ +# -*- coding: utf-8 -*- +""" +Fake vLLM endpoint for OpenClaw agent training. +Based on ajet/tuner_lib/experimental/oai_model_one2many.py +""" + +import os +import uuid +import asyncio +import httpx +import json +import threading +from contextlib import asynccontextmanager +from typing import Dict, List, Optional + +from fastapi import FastAPI, Request, HTTPException +from fastapi.responses import StreamingResponse +from loguru import logger +from pydantic import BaseModel + +from ajet.schema.task import Task, WorkflowOutput +from ajet.copilot.job import AgentJetJob +from ajet.tuner_lib.experimental.swarm_client import SwarmClient + +import sys +sys.path.insert(0, os.path.dirname(__file__)) + +from on_user_submit_new_requests import on_user_submit_new_requests +from on_compute_relative_reward import on_compute_relative_reward + +# Configuration +SWARM_URL = os.getenv("AJET_SWARM_URL", "http://localhost:10086") +NUM_REPEAT = int(os.getenv("NUM_REPEAT", "4")) +TRAINING_OBJECTIVE = "Train model to be more extraverted" + +# Global State +USER_REQUEST_RECORD: List[Dict] = [] +REQUEST_COUNTER = 0 +swarm_client: Optional[SwarmClient] = None +ajet_job = AgentJetJob( + algorithm="grpo", + project_name="openclaw-extraversion", + experiment_name="extraversion_training", + 
n_gpu=8, + model='/mnt/data_cpfs/model_cache/modelscope/hub/Qwen/Qwen/Qwen2___5-7B-Instruct', + batch_size=32, + logging="swanlab", + num_repeat=NUM_REPEAT, + max_prompt_length=16000, # at least 16000 + max_response_length=8000, + max_model_len=24000, # bigger than / equal to `max_prompt_length + max_response_length` + max_response_length_in_one_turn=4000, +) + +class EpisodeResult(BaseModel): + """Result from a single episode execution.""" + episode_uuid: str + response: Dict | List[bytes] + + +def extract_assistant_message(resp: Dict | List[bytes]) -> Dict: + """Extract assistant message from response.""" + if isinstance(resp, list): + content_parts: List[str] = [] + for raw in resp: + line = raw.decode() if isinstance(raw, bytes) else raw + if not line.startswith("data:"): + continue + payload = line[len("data:"):].strip() + if payload == "[DONE]": + break + try: + chunk = json.loads(payload) + delta = chunk.get("choices", [{}])[0].get("delta", {}) + if delta.get("content"): + content_parts.append(delta["content"]) + except Exception: + pass + return {"role": "assistant", "content": "".join(content_parts)} + else: + return resp.get("choices", [{}])[0].get("message", {}) + + +async def proxy_chat_completion(base_url: str, api_key: str, request: Request, is_stream: bool = False) -> Dict | List[bytes]: + """Proxy a chat completion request.""" + headers = { + "Authorization": f"Bearer {api_key}", + "Content-Type": "application/json", + "Connection": "close", + } + json_data = await request.json() + json_data["stream"] = is_stream + + async with httpx.AsyncClient(timeout=300.0) as client: + resp = await client.post(f"{base_url}/chat/completions", json=json_data, headers=headers) + resp.raise_for_status() + if is_stream: + chunks = [] + async for line in resp.aiter_lines(): + if line.strip(): + chunks.append(line.encode() if isinstance(line, str) else line) + return chunks + else: + return resp.json() + + +def _check_finish_reason_length(response_data: Dict | 
List[bytes]) -> bool: + """Return True if any choice has finish_reason='length'.""" + if isinstance(response_data, list): + for raw in response_data: + line = raw.decode() if isinstance(raw, bytes) else raw + if not line.startswith("data:"): + continue + payload = line[len("data:"):].strip() + if payload == "[DONE]": + break + try: + chunk = json.loads(payload) + finish_reason = chunk.get("choices", [{}])[0].get("finish_reason") + if finish_reason == "length": + return True + except Exception: + pass + return False + else: + choices = response_data.get("choices", []) + return any(c.get("finish_reason") == "length" for c in choices) + + +async def run_single_episode(episode_index: int, request: Request, is_stream: bool) -> EpisodeResult: + """Run a single episode.""" + assert swarm_client is not None + episode_uuid, api_baseurl_key = await asyncio.to_thread(swarm_client.begin_episode) + try: + response_data = await proxy_chat_completion( + base_url=api_baseurl_key.base_url, + api_key=api_baseurl_key.api_key, + request=request, + is_stream=is_stream, + ) + if _check_finish_reason_length(response_data): + raise HTTPException( + status_code=400, + detail={ + "error": { + "message": "This model's maximum context length is exceeded. 
Please reduce the length of the messages.", + "type": "invalid_request_error", + "param": "messages", + "code": "context_length_exceeded", + } + }, + ) + return EpisodeResult(episode_uuid=episode_uuid, response=response_data) + except Exception as e: + logger.error(f"Error in episode {episode_index}: {e}") + swarm_client.abort_episode(episode_uuid) + raise + + +async def run_all_episodes(request: Request, is_stream: bool) -> List[EpisodeResult]: + """Run all episodes in parallel.""" + episode_tasks = [run_single_episode(i, request, is_stream) for i in range(NUM_REPEAT)] + results = await asyncio.gather(*episode_tasks, return_exceptions=True) + valid_results: List[EpisodeResult] = [] + for result in results: + if isinstance(result, HTTPException) and result.status_code == 400: + # Propagate context_length_exceeded directly to client + raise result + elif isinstance(result, Exception): + logger.warning(f"Episode failed: {result}") + elif isinstance(result, EpisodeResult): + valid_results.append(result) + if not valid_results: + raise HTTPException(status_code=500, detail="All episodes failed") + return valid_results + + +async def finalize_episodes(task: Task, valid_results: List[EpisodeResult], rewards: List[float]) -> None: + """Finalize all episodes by sending rewards.""" + assert swarm_client is not None + loop = asyncio.get_event_loop() + for episode_result, reward in zip(valid_results, rewards): + workflow_output = WorkflowOutput(reward=reward, metadata={}) + await loop.run_in_executor( + None, + lambda ep=episode_result, wo=workflow_output: swarm_client.end_episode(task, ep.episode_uuid, wo), + ) + + +async def handle_one2many_request(request: Request, request_id: str) -> Dict | List[bytes]: + """Handle a one-to-many request.""" + json_data = await request.json() + is_stream = json_data.get('stream', False) + messages = json_data.get('messages', []) + message_latest = messages[-1] + user_query = str(message_latest.get("content", "") if 
isinstance(message_latest, dict) else "") + + task = Task(task_id=str(uuid.uuid4()), main_query=user_query, metadata={"TRAINING_OBJECTIVE": TRAINING_OBJECTIVE}) + await on_user_submit_new_requests(request_id, task) + + valid_results = await run_all_episodes(request, is_stream) + all_answers = [extract_assistant_message(r.response) for r in valid_results] + rewards = await on_compute_relative_reward(valid_results, all_answers) + + await finalize_episodes(task, valid_results, rewards) + + best_idx = rewards.index(max(rewards)) + return valid_results[best_idx].response + + +@asynccontextmanager +async def lifespan(app: FastAPI): + """Application lifespan manager.""" + global swarm_client + logger.info(f"Initializing swarm client with URL: {SWARM_URL}") + swarm_client = SwarmClient(SWARM_URL) + logger.info(f"Syncing train config and starting engine with num_repeat={NUM_REPEAT}") + + def start_engine_background(): + try: + swarm_client.auto_sync_train_config_and_start_engine(ajet_job, force_restart=False) + logger.info("Swarm engine is ready!") + except Exception as e: + logger.warning(f"Engine auto-sync skipped or failed: {e}") + + engine_thread = threading.Thread(target=start_engine_background, daemon=True) + engine_thread.start() + yield + + +app = FastAPI(title="OpenClaw Extraversion Training", lifespan=lifespan) + + +@app.api_route("/v1/{path:path}", methods=["GET", "POST", "PUT", "DELETE", "PATCH", "HEAD", "OPTIONS"]) +async def one2many_proxy(request: Request, path: str): + """Main proxy endpoint.""" + global REQUEST_COUNTER + if request.method == "POST" and path == "chat/completions": + REQUEST_COUNTER += 1 + request_id = f"req_{REQUEST_COUNTER}_{uuid.uuid4().hex[:8]}" + logger.info(f"Received chat completion request {request_id}") + response_data = await handle_one2many_request(request, request_id) + if isinstance(response_data, list): + async def stream_chunks(chunks: List[bytes]): + for chunk in chunks: + yield chunk + b"\n\n" + return 
StreamingResponse(stream_chunks(response_data), media_type="text/event-stream") + return response_data + else: + raise HTTPException(status_code=404, detail="Not Found") + + +@app.get("/health") +async def health_check(): + """Health check endpoint.""" + return {"status": "healthy"} + + +@app.get("/requests") +async def get_requests(): + """Get all recorded user requests.""" + return {"requests": USER_REQUEST_RECORD} + + +if __name__ == "__main__": + import uvicorn + uvicorn.run(app, host="0.0.0.0", port=8090) diff --git a/tutorial/opencode_build_openclaw_agent/mock_user_request.py b/tutorial/opencode_build_openclaw_agent/mock_user_request.py new file mode 100644 index 0000000..6ad6a6d --- /dev/null +++ b/tutorial/opencode_build_openclaw_agent/mock_user_request.py @@ -0,0 +1,119 @@ +# -*- coding: utf-8 -*- +"""Mock user requests using OpenClaw CLI interface.""" + +import json +import subprocess +import time +import os +import random +from typing import List, Dict + +GATEWAY_PORT = os.getenv("OPENCLAW_PORT", "18789") + +def load_dataset(filepath: str = "extraversion_questions.json") -> List[Dict]: + """Load personality manipulation dataset.""" + with open(filepath, 'r', encoding='utf-8') as f: + return json.load(f) + + +def generate_agent_name() -> str: + """Generate a random agent name.""" + adjectives = ["happy", "quick", "bright", "clever", "bold", "calm", "eager", "gentle"] + nouns = ["fox", "wolf", "bear", "eagle", "hawk", "lion", "tiger", "owl"] + return f"{random.choice(adjectives)}_{random.choice(nouns)}_{random.randint(1000, 9999)}" + + +def create_agent(agent_name: str) -> bool: + """Create a new agent using OpenClaw CLI.""" + try: + workspace = f"/root/.openclaw/workspace-{agent_name}" + result = subprocess.run( + ["openclaw", "agents", "add", agent_name, "--workspace", workspace, "--non-interactive"], + capture_output=True, + text=True, + timeout=60 + ) + if result.returncode == 0: + print(f"Created agent: {agent_name}") + return True + else: + 
print(f"Error creating agent {agent_name}: {result.stderr}") + return False + except Exception as e: + print(f"Error creating agent: {str(e)}") + return False + + +def delete_agent(agent_name: str) -> bool: + """Delete an agent using OpenClaw CLI.""" + try: + result = subprocess.run( + ["openclaw", "agents", "delete", agent_name, "--force"], + capture_output=True, + text=True, + timeout=60 + ) + if result.returncode == 0: + print(f"Deleted agent: {agent_name}") + return True + else: + print(f"Error deleting agent {agent_name}: {result.stderr}") + return False + except Exception as e: + print(f"Error deleting agent: {str(e)}") + return False + + +def send_openclaw_message(agent_name: str, message: str) -> str: + """Send message via OpenClaw CLI to specific agent.""" + try: + result = subprocess.run( + ["openclaw", "agent", "--agent", agent_name, "--message", message], + capture_output=True, + text=True, + timeout=300 + ) + return result.stdout if result.returncode == 0 else f"Error: {result.stderr}" + except Exception as e: + return f"Error: {str(e)}" + + +def main(): + """Main loop to send requests from dataset.""" + print("Starting OpenClaw mock user requests") + + # Load dataset + dataset = load_dataset() + random.shuffle(dataset) + print(f"Loaded {len(dataset)} questions from dataset\n") + + # Process dataset in chunks of 5 + for chunk_start in range(0, len(dataset), 5): + chunk = dataset[chunk_start:chunk_start + 5] + + # Generate random agent name + agent_name = generate_agent_name() + print(f"\n=== Creating agent: {agent_name} ===\n") + + # Create agent + if not create_agent(agent_name): + print(f"Failed to create agent, skipping chunk") + continue + + # Send 5 messages + for i, item in enumerate(chunk): + question = item.get("Question", "") + print(f"[{agent_name}/{i+1}/5] Sending: {question[:80]}...") + response = send_openclaw_message(agent_name, question) + print(f"Response: {response[:200]}...\n") + time.sleep(2) + + # Delete agent + 
delete_agent(agent_name) + print(f"\n=== Deleted agent: {agent_name} ===\n") + + print("\nAll agents processed successfully") + + +if __name__ == "__main__": + main() diff --git a/tutorial/opencode_build_openclaw_agent/on_compute_relative_reward.py b/tutorial/opencode_build_openclaw_agent/on_compute_relative_reward.py new file mode 100644 index 0000000..ea7c164 --- /dev/null +++ b/tutorial/opencode_build_openclaw_agent/on_compute_relative_reward.py @@ -0,0 +1,41 @@ +# -*- coding: utf-8 -*- +"""Compute relative rewards based on extraversion personality alignment.""" + +from typing import List, Dict +from beast_logger import print_listofdict + +def score_extraversion(response_text: str) -> float: + """Score response for extraversion traits (1-10 scale).""" + extraversion_keywords = [ + 'excited', 'love', 'amazing', 'awesome', 'fantastic', 'great', + 'wonderful', 'thrilled', 'energetic', 'enthusiastic', 'fun', + 'social', 'outgoing', 'active', 'lively', 'vibrant', 'happy', + 'enjoy', 'delighted', 'cheerful', 'positive' + ] + + text_lower = response_text.lower() + score = 5.0 + + for keyword in extraversion_keywords: + if keyword in text_lower: + score += 0.5 + + score += min(response_text.count('!') * 0.3, 2.0) + + if len(response_text) < 50: + score -= 1.0 + + return max(1.0, min(10.0, score)) + +async def on_compute_relative_reward(valid_results: List, all_answers: List[Dict]) -> List[float]: + """Compute relative rewards for extraversion alignment.""" + scores = [] + for answer in all_answers: + content = answer.get("content", "") + raw_score = score_extraversion(content) + normalized = (raw_score - 5.5) / 4.5 + scores.append(normalized) + answer["reward"] = normalized + + print_listofdict(all_answers, header="on_compute_relative_reward") + return scores diff --git a/tutorial/opencode_build_openclaw_agent/on_user_submit_new_requests.py b/tutorial/opencode_build_openclaw_agent/on_user_submit_new_requests.py new file mode 100644 index 0000000..07f32a5 --- /dev/null 
+++ b/tutorial/opencode_build_openclaw_agent/on_user_submit_new_requests.py @@ -0,0 +1,8 @@ +# -*- coding: utf-8 -*- +"""Handle new user requests.""" + +from ajet.schema.task import Task + +async def on_user_submit_new_requests(request_id: str, task: Task) -> None: + """Store user request when submitted.""" + pass # No special processing needed for this use case diff --git a/tutorial/opencode_build_openclaw_agent/openclaw.md b/tutorial/opencode_build_openclaw_agent/openclaw.md new file mode 100644 index 0000000..ab9ca3d --- /dev/null +++ b/tutorial/opencode_build_openclaw_agent/openclaw.md @@ -0,0 +1,120 @@ +# OpenClaw 概况调研报告 + +## 一、总体概述 + +OpenClaw 是一个诞生于 2025 年的开源个人 AI 助手平台,其核心定位是"AI 智能体的操作系统"(OS for AI Agents)[[1]](https://openclaw.ai/blog/introducing-openclaw)。该项目最初以周末黑客项目(weekend hack)的形式起步,先后经历了 Clawdbot、Moltbot 等命名阶段,最终更名为 OpenClaw 以彰显其开源与社区驱动的本质[[1]](https://openclaw.ai/blog/introducing-openclaw)。项目在极短时间内积累了超过 10 万 GitHub Stars 和数百万访问量,展现出极强的社区吸引力[[1]](https://openclaw.ai/blog/introducing-openclaw)。2026 年 2 月,OpenClaw 被 OpenAI 收购,总部位于旧金山,归属商业/生产力软件行业[[2]](https://pitchbook.com/profiles/company/1318645-09)。 + +这一项目的核心理念可以概括为三个关键词:本地优先(local-first)、用户主权(user sovereignty)、多模型兼容(model-agnostic)。它不是一个简单的聊天机器人,而是一个能够自动化处理邮件、日历、浏览器操作并拥有持久记忆和完整系统访问权限的全功能 AI 代理平台[[1]](https://openclaw.ai/blog/introducing-openclaw)。 + +--- + +## 二、技术架构深度分析 + +### 2.1 基础技术栈 + +OpenClaw 构建于 Node.js(v22.12.0+)之上,支持 macOS、Windows(需 WSL2)和 Linux 三大操作系统[[3]](https://ppaolo.substack.com)。选择 Node.js 作为运行时并非偶然——其事件驱动、非阻塞 I/O 的特性天然适合处理多通道消息的并发场景,同时 JavaScript/TypeScript 生态的丰富性也降低了社区贡献的门槛。 + +### 2.2 分层架构设计 + +OpenClaw 的架构呈现出清晰的分层设计思路[[3]](https://ppaolo.substack.com): + +**通道适配层(Channel Adapters)**:这是系统的"感知层",负责对接 WhatsApp、Telegram、Discord、Slack、Teams、Twitch、Google Chat 等主流通讯平台[[1]](https://openclaw.ai/blog/introducing-openclaw)。每个适配器处理该平台特有的认证协议、消息解析、访问控制和出站格式化。这种设计使得新增通道只需实现标准接口,无需改动核心逻辑。 + +**控制接口层(Control Interfaces)**:提供 Web UI、CLI、macOS 
原生应用和移动端等多种交互方式[[3]](https://ppaolo.substack.com),确保用户可以在不同场景下管理和监控 AI 代理。 + +**网关控制平面(Gateway Control Plane)**:作为系统的"中枢神经",负责请求路由、负载均衡和全局策略执行[[3]](https://ppaolo.substack.com)。 + +**代理运行时(Agent Runtime)**:这是架构中最核心的部分,包含以下关键组件[[3]](https://ppaolo.substack.com): +- 会话解析(Session Resolution):识别和管理用户会话上下文 +- 上下文组装(Context Assembly):将历史对话、记忆、工具状态等信息组装为模型可理解的上下文 +- 执行循环(Execution Loop):驱动 AI 代理的思考-行动-观察循环 +- 系统提示词架构(System Prompt Architecture):管理和组合系统级指令 + +**数据存储层**:涵盖会话状态压缩(Session State Compaction)、记忆搜索(Memory Search)、存储索引(Storage Indexing)和嵌入向量提供者选择(Embedding Provider Selection)[[3]](https://ppaolo.substack.com)。会话状态压缩机制尤其值得关注——它解决了长对话场景下上下文窗口溢出的问题,通过智能摘要保留关键信息。 + +### 2.3 多代理协作能力 + +OpenClaw 支持多代理路由(Multi-Agent Routing)、代理间通信(Agent-to-Agent Communication)、定时任务(Scheduled Actions)和外部触发器(External Triggers)[[3]](https://ppaolo.substack.com)。这意味着用户可以构建由多个专业化 AI 代理组成的协作系统——例如一个代理负责邮件分类,另一个负责日程安排,第三个负责代码审查,它们之间可以相互通信和协调。 + +--- + +## 三、AI 模型支持生态 + +### 3.1 多供应商模型矩阵 + +OpenClaw 的模型支持策略体现了"不绑定单一供应商"的设计哲学,目前支持的模型包括[[4]](https://docs.openclaw.ai): + +| 供应商 | 模型 | +|--------|------| +| OpenAI | GPT-5.1, Codex | +| Anthropic | Claude Opus 4.6 | +| Google | Gemini 3 Pro | +| Z.AI | GLM 4.7 | +| Moonshot AI | Kimi K2.5 | +| MiniMax | M2.1 | +| 阿里云 | Qwen | +| 本地运行时 | Ollama | +| 其他 | OpenCode Zen, Synthetic 等 | + +### 3.2 战略意义分析 + +值得注意的是,OpenClaw 同时集成了美国和中国的 AI 模型[[4]](https://docs.openclaw.ai)[[5]](https://scmp.com)。这一策略具有多重意义: + +首先是成本优化——不同模型在不同任务上的性价比差异显著,用户可以为简单任务选择低成本模型,为复杂推理选择高端模型。其次是冗余保障——当某一供应商服务中断时,系统可以自动切换到备选模型。最后是能力互补——中美模型在中英文处理、代码生成、多模态理解等方面各有所长。 + +通过 Ollama 支持本地 LLM 运行[[4]](https://docs.openclaw.ai),OpenClaw 还为对数据隐私有极高要求的用户提供了完全离线的选项,这在企业级应用场景中尤为重要。 + +--- + +## 四、部署方案与硬件要求 + +### 4.1 云端部署 + +云端部署提供一键式快速启动,内置安全加固措施包括防火墙规则、非 root 执行和弹性扩缩容[[6]](https://help.apiyi.com)。优势在于零运维负担和快速上线,但需要承担月度费用,且数据不完全在用户控制之下。 + +### 4.2 本地部署 + +本地部署是 OpenClaw 的核心差异化优势所在。它提供完全的数据隐私保障、离线运行能力和深度定制空间,但对用户的技术能力和硬件配置有一定要求[[6]](https://help.apiyi.com): + +- CPU 推荐:AMD Ryzen 9 7950X 或 Intel 
Core i9-13900K +- GPU 推荐:NVIDIA RTX 4090 或 RTX 4080 + +这一硬件要求主要针对需要本地运行大语言模型的场景。如果仅使用云端 API 调用模型,硬件要求会大幅降低。 + +### 4.3 安全注意事项 + +安全最佳实践明确建议不要在主力工作机上运行 OpenClaw[[7]](https://safeclaw.io)。这一建议源于 OpenClaw 拥有强大的系统执行能力——包括浏览器自动化、文件系统访问和命令执行——一旦出现提示词注入攻击或配置失误,可能对主机系统造成影响。推荐使用独立的 homelab 服务器或 VPS 进行部署[[1]](https://openclaw.ai/blog/introducing-openclaw)。 + +--- + +## 五、安全体系 + +OpenClaw 的安全架构是多层次的[[3]](https://ppaolo.substack.com): + +- 网络安全(Network Security):传输层加密和网络隔离 +- 认证机制(Authentication):多因素身份验证 +- 通道访问控制(Channel Access Control):细粒度的平台级权限管理 +- 工具沙箱(Tool Sandboxing):限制 AI 代理可调用的系统能力 +- 会话边界(Session-based Boundaries):防止跨会话信息泄露 +- 提示词注入防御(Prompt Injection Defenses):抵御恶意输入攻击 +- 机器可检查安全模型(Machine-checkable Security Models):可形式化验证的安全策略[[1]](https://openclaw.ai/blog/introducing-openclaw) + +引入机器可检查安全模型这一点尤其前瞻——它意味着安全策略不仅是文档化的规则,而是可以被自动化工具验证和执行的形式化规范,这在 AI 代理安全领域属于较为领先的实践。 + +--- + +## 六、关键洞察与启示 + +**从周末项目到被 OpenAI 收购的增长路径**:OpenClaw 的发展轨迹揭示了 2025 年 AI 基础设施领域的一个重要趋势——开源 AI 代理框架正在成为大型 AI 公司的战略收购目标。OpenAI 收购 OpenClaw[[2]](https://pitchbook.com/profiles/company/1318645-09),本质上是在补齐其在"AI 代理运行时"层面的能力,从单纯的模型提供商向平台化方向延伸。 + +**本地优先 vs. 云端便利的张力**:OpenClaw 试图在数据主权和使用便利性之间找到平衡点。其双轨部署策略反映了市场的真实需求分化——企业和隐私敏感用户倾向本地部署,而个人开发者和快速原型场景更青睐云端方案。 + +**多模型策略的行业信号**:OpenClaw 广泛集成中美两国 AI 模型的做法[[5]](https://scmp.com),表明在实际应用层面,模型的地缘属性正在让位于实用性考量。这对整个 AI 应用生态的发展方向具有参考意义。 + +**安全作为一等公民**:在 AI 代理拥有系统级执行权限的背景下,OpenClaw 将安全提升到架构设计的核心位置[[7]](https://safeclaw.io),而非事后补丁。这种"安全左移"的理念值得同类项目借鉴。 + +--- + +## 七、结论 + +OpenClaw 代表了 2025 年 AI 代理平台发展的一个典型样本:以开源社区为驱动力,以本地部署和用户数据主权为核心卖点,以多模型兼容和多通道集成为功能支撑,以分层安全架构为信任基础。其从独立项目到被 OpenAI 收购的历程,既验证了 AI 代理基础设施的市场价值,也预示着这一领域正在从碎片化的开源探索走向平台化整合的新阶段。对于关注 AI 代理技术栈演进的开发者和技术决策者而言,OpenClaw 的架构设计和生态策略都具有重要的参考价值。