From bba666155f13adcd27b4dc5aa4c729f88e183e78 Mon Sep 17 00:00:00 2001 From: "qingxu.fu" Date: Thu, 7 May 2026 00:38:39 +0800 Subject: [PATCH 1/2] improve swarm instruction chain --- ajet/copilot/job.py | 8 + ajet/default_config/ajet_default.yaml | 6 + ajet/default_config/ajet_swarm_default.yaml | 6 + ajet/task_rollout/native_parallel_worker.py | 43 +++- .../experimental/interchange_utils.py | 148 +++++++++++- .../experimental/oai_model_server.py | 4 + ajet/tuner_lib/experimental/swarm_client.py | 79 ++++++ .../experimental/swarm_overwatch_utils.py | 1 + ajet/tuner_lib/experimental/swarm_server.py | 106 +++++++- ajet/utils/swarm_overwatch.py | 11 + appworld_swarm_results/val_results.md | 77 ++++++ tutorial/example_appworld/appworld.md | 33 ++- tutorial/example_appworld_swarm/README.md | 45 ++++ tutorial/example_appworld_swarm/agent_roll.py | 228 ++++++++++++++++++ tutorial/example_appworld_swarm/appworld.yaml | 84 +++++++ .../example_appworld_swarm/appworld_swarm.py | 168 +++++++++++++ 16 files changed, 1031 insertions(+), 16 deletions(-) create mode 100644 appworld_swarm_results/val_results.md create mode 100644 tutorial/example_appworld_swarm/README.md create mode 100644 tutorial/example_appworld_swarm/agent_roll.py create mode 100644 tutorial/example_appworld_swarm/appworld.yaml create mode 100644 tutorial/example_appworld_swarm/appworld_swarm.py diff --git a/ajet/copilot/job.py b/ajet/copilot/job.py index 9182e00b..7eff04f4 100644 --- a/ajet/copilot/job.py +++ b/ajet/copilot/job.py @@ -67,6 +67,14 @@ class AgentJetJob: whose ``num_repeat`` episodes do *not* all share the same reward. Tasks with uniform reward (e.g. all 0 or all 1) produce zero advantage under GRPO and are skipped — useful when the dataset contains many too-easy or too-hard prompts. + - "rollout_until_any_client_agree_sync_weight": defer the stop decision to the swarm + clients themselves. Stops as soon as **any** active swarm client invokes + ``SwarmClient.agree_sync_weight()``. 
A client is "active" once it has successfully + ``end_episode``'d at least one rewarded (non-abort) episode since the last weight + sync, and falls off the active list after 10 minutes of no chat-completion or + ``begin_episode`` activity. + - "rollout_until_all_clients_agree_sync_weight": like the above, but stops only when + **every** active swarm client has agreed (and there is at least one active client). max_env_worker: an estimation about how many episodes will be running in parallel (all swarm clients combined). backbone: Training backbone framework (e.g., 'verl'). max_prompt_length: Maximum token length for input prompts (token length before the first llm-generated token, default 3000). diff --git a/ajet/default_config/ajet_default.yaml b/ajet/default_config/ajet_default.yaml index 8e425cb5..c84077ed 100644 --- a/ajet/default_config/ajet_default.yaml +++ b/ajet/default_config/ajet_default.yaml @@ -345,6 +345,12 @@ ajet: # "rollout_until_finish_enough_non_dummy_tasks": # AgentJet will identify the **task_id** of each episode, and stop when it has collected [>= ajet.data.train_batch_size] unique & FINISHED & NON-DUMMY **task_id**. # (Hint: a **task_id** is considered "NON-DUMMY" at least one of **episodes** of **task_id** has **different** reward value.) + # "rollout_until_any_client_agree_sync_weight": + # AgentJet defers the stop decision to swarm clients: stop as soon as ANY active swarm client has called `SwarmClient.agree_sync_weight()`. + # (Hint: a swarm client becomes "active" once it has successfully `end_episode`'d a rewarded (non-abort) episode since the last weight sync, + # and falls off the active list if it does no chat-completion / begin_episode for 10 minutes.) + # "rollout_until_all_clients_agree_sync_weight": + # Like the above, but stop only when EVERY active swarm client has agreed (and there is at least one active client). 
swarm_mode_sample_collection_method: "rollout_until_finish_enough_tasks" swarm_mode_sample_collection_max_cached_episodes: 9999 diff --git a/ajet/default_config/ajet_swarm_default.yaml b/ajet/default_config/ajet_swarm_default.yaml index 2e975dad..5890e94e 100644 --- a/ajet/default_config/ajet_swarm_default.yaml +++ b/ajet/default_config/ajet_swarm_default.yaml @@ -53,6 +53,12 @@ ajet: # "rollout_until_finish_enough_non_dummy_tasks": # AgentJet will identify the **task_id** of each episode, and stop when it has collected [>= ajet.data.train_batch_size] unique & FINISHED & NON-DUMMY **task_id**. # (Hint: a **task_id** is considered "NON-DUMMY" at least one of **episodes** of **task_id** has **different** reward value.) + # "rollout_until_any_client_agree_sync_weight": + # AgentJet defers the stop decision to swarm clients: stop as soon as ANY active swarm client has called `SwarmClient.agree_sync_weight()`. + # (Hint: a swarm client becomes "active" once it has successfully `end_episode`'d a rewarded (non-abort) episode since the last weight sync, + # and falls off the active list if it does no chat-completion / begin_episode for 10 minutes.) + # "rollout_until_all_clients_agree_sync_weight": + # Like the above, but stop only when EVERY active swarm client has agreed (and there is at least one active client). 
swarm_mode_sample_collection_method: "rollout_until_finish_enough_tasks" data: diff --git a/ajet/task_rollout/native_parallel_worker.py b/ajet/task_rollout/native_parallel_worker.py index b97e4fa2..eef2a2fd 100644 --- a/ajet/task_rollout/native_parallel_worker.py +++ b/ajet/task_rollout/native_parallel_worker.py @@ -26,8 +26,9 @@ from ajet.context_tracker.single_agent_tracking import SingleAgentContextTracker from ajet.tuner_lib.experimental.interchange_utils import ( http_change_engine_status, - http_update_rollout_pool_information, + http_update_rollout_pool_information_and_fetch_instruction, CurrentBatchRolloutPoolInformation, + SwarmClientInstruction, ) @@ -292,6 +293,14 @@ def rollout_swarm( # noqa: C901 completed_task_id_map_ct: Dict[str, List[SingleAgentContextTracker]] = IterationSafeDict() executor_lock = threading.Lock() + accept_client_control = ("client" in self.config.ajet.swarm_mode_sample_collection_method) + if accept_client_control: + # Latest active-client / agreed-sync-weight snapshot from the swarm server. Refreshed on every pool-information update; + # consumed by the `rollout_until_*_agree_sync_weight` stop conditions. 
+ latest_swarm_client_instructions: Dict[str, SwarmClientInstruction | None] = {"swarm_clients": None} + else: + latest_swarm_client_instructions = None + # count tasks to see whether we have reach the finish line for next weight update def count_tasks(completed_task_id_map_ct): total_completed_episodes = 0 @@ -340,6 +349,20 @@ def enough_finished_task_stop_condition(completed_task_id_map_ct) -> bool: completed_task_id_map_ct.clear() return (total_completed_tasks >= n_batch_task) + def any_client_agree_sync_weight_stop_condition(completed_task_id_map_ct) -> bool: + # ajet.swarm_mode_sample_collection_method == "rollout_until_any_client_agree_sync_weight" + instr = latest_swarm_client_instructions["swarm_clients"] + if instr is None: + return False + return any(c.allowed_sync_weight for c in instr.active_clients) + + def all_clients_agree_sync_weight_stop_condition(completed_task_id_map_ct) -> bool: + # ajet.swarm_mode_sample_collection_method == "rollout_until_all_clients_agree_sync_weight" + instr = latest_swarm_client_instructions["swarm_clients"] + if instr is None or not instr.active_clients: + return False + return all(c.allowed_sync_weight for c in instr.active_clients) + def enough_non_dummy_task_stop_condition(completed_task_id_map_ct) -> bool: # ajet.swarm_mode_sample_collection_method == "rollout_until_finish_enough_non_dummy_tasks" counts = count_tasks(completed_task_id_map_ct) @@ -372,6 +395,10 @@ def enough_non_dummy_task_stop_condition(completed_task_id_map_ct) -> bool: stop_condition = enough_finished_task_stop_condition elif self.config.ajet.swarm_mode_sample_collection_method == "rollout_until_finish_enough_non_dummy_tasks": stop_condition = enough_non_dummy_task_stop_condition + elif self.config.ajet.swarm_mode_sample_collection_method == "rollout_until_any_client_agree_sync_weight": + stop_condition = any_client_agree_sync_weight_stop_condition + elif self.config.ajet.swarm_mode_sample_collection_method == 
"rollout_until_all_clients_agree_sync_weight": + stop_condition = all_clients_agree_sync_weight_stop_condition else: logger.error(f"Invalid swarm_mode_sample_collection_method: {self.config.ajet.swarm_mode_sample_collection_method}, fallback to default method: rollout_until_finish_enough_tasks") stop_condition = enough_finished_task_stop_condition @@ -442,9 +469,16 @@ def update_rollout_result_array_preview(observation_window, completed_task_id_ma buffer += f"Total completed tasks: {counts['total_completed_tasks']} (target {n_batch_task})\n" buffer += f"Total completed non-dummy tasks: {counts['total_completed_non_dummy_tasks']} (target {n_batch_task})\n" buffer += f"Current stop condition: {self.config.ajet.swarm_mode_sample_collection_method}\n" + if accept_client_control: + sc_inst = latest_swarm_client_instructions["swarm_clients"] + if sc_inst is not None: + n_active = len(sc_inst.active_clients) + n_agreed = sum(1 for c in sc_inst.active_clients if c.allowed_sync_weight) + buffer += f"Active clients: {n_active} (agreed: {n_agreed})\n" observation_window["info"][-1] = buffer - # Update rollout pool information via API + # Update rollout pool information via API and pull the latest + # active-client / agreed-sync-weight instruction from the server. 
pool_info = CurrentBatchRolloutPoolInformation( sample_collection_method=self.config.ajet.swarm_mode_sample_collection_method, completed_episodes=counts['total_completed_episodes'], @@ -457,7 +491,10 @@ def update_rollout_result_array_preview(observation_window, completed_task_id_ma completed_tasks_details=completed_tasks_details, completed_tasks_rewards=completed_tasks_rewards, ) - http_update_rollout_pool_information(self.config, pool_info) + if accept_client_control: + instruction = http_update_rollout_pool_information_and_fetch_instruction(self.config, pool_info) + if instruction is not None: + latest_swarm_client_instructions["swarm_clients"] = instruction return update_rollout_result_array_preview(observation_window, completed_task_id_map_ct) diff --git a/ajet/tuner_lib/experimental/interchange_utils.py b/ajet/tuner_lib/experimental/interchange_utils.py index 8fcb92f0..aeb58f05 100644 --- a/ajet/tuner_lib/experimental/interchange_utils.py +++ b/ajet/tuner_lib/experimental/interchange_utils.py @@ -118,6 +118,138 @@ class VerboseLogsResponse(BaseModel): entries: List[VerboseLogEntry] = [] +class AgreeSyncWeightRequest(BaseModel): + client_uuid: str + + +class ActiveSwarmClient(BaseModel): + """Server-tracked record for one active swarm client. + + A swarm client enters this list once it has successfully `end_episode`'d + a rewarded (non-abort) episode since the last weight sync, and falls off + after `CLIENT_ACTIVE_TIMEOUT` seconds of no chat-completion / + `begin_episode` activity. The whole list is reset whenever the engine + leaves ROLLING/ROLLING_POST. + + Used both as the swarm server's authoritative storage (single + `shared_mem_dict["active_swarm_clients"]: List[ActiveSwarmClient]` key) + and as the wire payload sent back to the trainer in + `SwarmClientInstruction`. Add future per-client signals (e.g. + `requested_pause`, custom metrics) here -- pydantic field defaults keep + the wire format backwards-compatible across server/trainer versions. 
+ + Fields: + client_uuid: the client_uuid as generated in `SwarmClient.__init__`. + last_activity_at: unix timestamp of the most recent chat-completion, + `begin_episode`, or `end_episode` from this client. Used by the + server's expiry sweep. + allowed_sync_weight: True iff this client has explicitly agreed to + the next weight sync via `SwarmClient.agree_sync_weight()`. + """ + client_uuid: str + last_activity_at: float + allowed_sync_weight: bool = False + + +class SwarmClientInstruction(BaseModel): + """Server -> trainer instruction returned alongside pool-info updates. + + Fields: + active_clients: list of `ActiveSwarmClient` records, one per + currently active client. + + Example wire payload: + ```json + { + "active_clients": [ + {"client_uuid": "9f3c-...-aaaa", "last_activity_at": 1746513900.1, "allowed_sync_weight": true}, + {"client_uuid": "9f3c-...-bbbb", "last_activity_at": 1746513912.4, "allowed_sync_weight": false}, + {"client_uuid": "9f3c-...-cccc", "last_activity_at": 1746513918.7, "allowed_sync_weight": false} + ] + } + ``` + + Example trainer-side use (matches DynamicRolloutManager.rollout_swarm): + ```python + # rollout_until_any_client_agree_sync_weight + if any(c.allowed_sync_weight for c in instr.active_clients): + stop() + + # rollout_until_all_clients_agree_sync_weight + if instr.active_clients and all( + c.allowed_sync_weight for c in instr.active_clients + ): + stop() + ``` + + For the payload above: + - "any" stop-condition evaluates True (one client agreed). + - "all" stop-condition evaluates False (two of three not yet agreed). + """ + active_clients: List[ActiveSwarmClient] = [] + + +# Active-client tracking timeout (seconds): a client falls off the active list +# if it has done no chat-completion or begin_episode call within this window. 
+CLIENT_ACTIVE_TIMEOUT = 10 * 60 + + +# -------------------------------------------------------------------- +# active-client tracking helpers +# -------------------------------------------------------------------- +# All active-client state lives behind a single shared_mem_dict key: +# "active_swarm_clients": List[ActiveSwarmClient] +# (See `ActiveSwarmClient` for field semantics and lifecycle.) The helpers +# below are imported by the swarm server's FastAPI routes and by the +# OAI-mode chat-completion handler. + + +def _refresh_client_activity(client_uuid: str, shared_mem_dict) -> None: + """If client is in the active list, refresh its last-activity timestamp. + + Called on chat-completion and begin_episode (claim_episode). Does NOT + add the client to the list -- only end_episode (success, non-abort) does. + """ + if not client_uuid: + return + clients: List[ActiveSwarmClient] = list(shared_mem_dict.get("active_swarm_clients", [])) + for i, c in enumerate(clients): + if c.client_uuid == client_uuid: + clients[i] = c.model_copy(update={"last_activity_at": time.time()}) + shared_mem_dict["active_swarm_clients"] = clients + return + + +def _register_active_client(client_uuid: str, shared_mem_dict) -> None: + """Add client to the active list (idempotent) and refresh its timestamp.""" + if not client_uuid: + return + clients: List[ActiveSwarmClient] = list(shared_mem_dict.get("active_swarm_clients", [])) + now = time.time() + for i, c in enumerate(clients): + if c.client_uuid == client_uuid: + clients[i] = c.model_copy(update={"last_activity_at": now}) + shared_mem_dict["active_swarm_clients"] = clients + return + clients.append(ActiveSwarmClient(client_uuid=client_uuid, last_activity_at=now)) + shared_mem_dict["active_swarm_clients"] = clients + + +def _expire_inactive_clients(shared_mem_dict) -> None: + """Drop clients whose last activity is older than CLIENT_ACTIVE_TIMEOUT.""" + now = time.time() + clients: List[ActiveSwarmClient] = 
list(shared_mem_dict.get("active_swarm_clients", [])) + if not clients: + return + kept = [c for c in clients if (now - c.last_activity_at) <= CLIENT_ACTIVE_TIMEOUT] + if len(kept) != len(clients): + shared_mem_dict["active_swarm_clients"] = kept + + +def _reset_active_client_tracking(shared_mem_dict) -> None: + """Clear all active-client state.""" + shared_mem_dict["active_swarm_clients"] = [] + DEBUG = False # DEBUG = True @@ -233,24 +365,34 @@ def http_push_verbose_log(message: str, tag: str = "", config=None): logger.warning(f"Failed to push verbose log: {e}") -def http_update_rollout_pool_information(config, pool_info: CurrentBatchRolloutPoolInformation): +def http_update_rollout_pool_information_and_fetch_instruction( + config, pool_info: CurrentBatchRolloutPoolInformation +) -> SwarmClientInstruction | None: """ - Update the rollout pool information on the interchange server. + Update the rollout pool information on the interchange server, and fetch + the swarm server's view of currently-active clients and their + agree-to-sync-weight state. Args: config: The configuration object pool_info: CurrentBatchRolloutPoolInformation object with rollout statistics + + Returns: + SwarmClientInstruction with `active_clients` (List[ActiveSwarmClient]), + or None if the request failed. 
""" try: resp = httpx.post( - f"{get_interchange_server_url(config)}/update_current_batch_rollout_pool_information", + f"{get_interchange_server_url(config)}/update_current_batch_rollout_pool_information_and_fetch_instruction", json=pool_info.model_dump(), timeout=5 ) resp.raise_for_status() + return SwarmClientInstruction.model_validate(resp.json()) except Exception as e: if DEBUG: logger.warning(f"Failed to update rollout pool information: {e}") + return None def get_zmq_socket(config, episode_uuid: str, tag: str = ""): diff --git a/ajet/tuner_lib/experimental/oai_model_server.py b/ajet/tuner_lib/experimental/oai_model_server.py index 3d3a3091..fa823af4 100644 --- a/ajet/tuner_lib/experimental/oai_model_server.py +++ b/ajet/tuner_lib/experimental/oai_model_server.py @@ -301,6 +301,7 @@ async def chat_completions(request: Request, authorization: str = Header(None)): # enable_swarm_mode if enable_swarm_mode: from ajet.tuner_lib.experimental.swarm_server import ep_key + from ajet.tuner_lib.experimental.interchange_utils import _refresh_client_activity assert shared_mem_dict is not None assert shared_mem_dict_lock is not None @@ -319,6 +320,9 @@ async def chat_completions(request: Request, authorization: str = Header(None)): shared_mem_dict[ep_key(episode_uuid)] = es if es.episode_type == "eval": preserve_sampling_params = True + # chat-completion counts as activity for keeping the owning client + # in the swarm-server active list (no-op if it's not active yet). 
+ _refresh_client_activity(es.client_uuid, shared_mem_dict) # For streaming, we process as non-streaming but return in streaming format original_stream = new_req.stream diff --git a/ajet/tuner_lib/experimental/swarm_client.py b/ajet/tuner_lib/experimental/swarm_client.py index 81de62d7..2de79437 100644 --- a/ajet/tuner_lib/experimental/swarm_client.py +++ b/ajet/tuner_lib/experimental/swarm_client.py @@ -24,6 +24,8 @@ EpisodeStatus, EpisodeBufferResponse, SwarmThrottlePolicy, + AgreeSyncWeightRequest, + BoolResponse, ) # general http timeout @@ -33,6 +35,13 @@ START_EPISODE_RETRY_DELAY = 15 TROTTLE_EPISODE_RETRY_DELAY = 2 WAIT_MORE_AVAIL_EPISODE_RETRY_DELAY = 2 +# agree_sync_weight retry policy. The call must succeed -- a dropped +# agreement can stall the trainer's stop condition. Retries cover both +# transport errors and server-side rejection (e.g. when a just-completed +# end_episode hasn't yet propagated to the server's active list). +AGREE_SYNC_WEIGHT_MAX_RETRIES = 60 +AGREE_SYNC_WEIGHT_RETRY_DELAY = 2.0 +DELAY_AFTER_AGREE_SYNC_WEIGHT = 30 def raise_for_status_with_detail(resp): try: @@ -767,6 +776,76 @@ def server_experiment_dir(self) -> str: except Exception as e: return "saved_experiments" + def agree_sync_weight(self) -> bool: + """Notify the swarm server that this client agrees to a weight sync. + + The server only accepts the agreement if this client is in its + active-client list (i.e. has end_episode'd at least one rewarded + episode since the last sync). Used together with the + `rollout_until_any_client_agree_sync_weight` / + `rollout_until_all_clients_agree_sync_weight` stop conditions so the + client can decide for itself when its current batch is "good enough". + + Important: this call retries on failure. A dropped agreement can + stall the trainer indefinitely (e.g. under "all clients agree"), and + the most common rejection -- "client not yet in active list" -- + clears itself once the just-finished end_episode propagates. 
Only + gives up after AGREE_SYNC_WEIGHT_MAX_RETRIES attempts, or if the + engine has left ROLLING/ROLLING_POST (the agreement would be wiped + by the server-side reset anyway). + + Returns: True if the agreement was registered, False after + exhausting retries (or after the engine left rolling state). + """ + last_failure = "" + for attempt in range(1, AGREE_SYNC_WEIGHT_MAX_RETRIES + 1): + engine_status, _ = self.get_engine_status() + if engine_status not in ("ENGINE.ROLLING", "ENGINE.ROLLING_POST"): + logger.warning( + f"agree_sync_weight: engine is {engine_status}, abandoning " + f"agreement (would be reset by server-side cleanup anyway)." + ) + return False + try: + req_obj = AgreeSyncWeightRequest(client_uuid=self.client_uuid) + resp = self._http_client.post( + f"{self.server_url}/agree_sync_weight", + json=req_obj.model_dump(), + timeout=10, + ) + raise_for_status_with_detail(resp) + data = BoolResponse.model_validate(resp.json()) + if data.success: + if self.verbose: + self.logger_info( + f"agree_sync_weight: registered with server " + f"(attempt {attempt})" + ) + # time.sleep(DELAY_AFTER_AGREE_SYNC_WEIGHT) + self._wait_until_status_change_to(desired_status="ENGINE.ROLLING_POST") + return True + last_failure = data.failure_reason + logger.warning( + f"agree_sync_weight rejected (attempt " + f"{attempt}/{AGREE_SYNC_WEIGHT_MAX_RETRIES}): " + f"{data.failure_reason}. Retrying in " + f"{AGREE_SYNC_WEIGHT_RETRY_DELAY}s..." + ) + except Exception as e: + last_failure = str(e) + if self._should_refresh_client_on_error(e): + self._refresh_http_client() + logger.error( + f"agree_sync_weight errored (attempt " + f"{attempt}/{AGREE_SYNC_WEIGHT_MAX_RETRIES}): {e}. Retrying..." + ) + time.sleep(AGREE_SYNC_WEIGHT_RETRY_DELAY) + logger.error( + f"agree_sync_weight: gave up after {AGREE_SYNC_WEIGHT_MAX_RETRIES} " + f"attempts. 
Last failure: {last_failure}" + ) + return False + def get_rollout_stat(self) -> CurrentBatchRolloutPoolInformation: """ Get the current batch rollout pool information from the Swarm server. diff --git a/ajet/tuner_lib/experimental/swarm_overwatch_utils.py b/ajet/tuner_lib/experimental/swarm_overwatch_utils.py index 06355e50..1300133c 100644 --- a/ajet/tuner_lib/experimental/swarm_overwatch_utils.py +++ b/ajet/tuner_lib/experimental/swarm_overwatch_utils.py @@ -19,3 +19,4 @@ class CurrentBatchRolloutPoolInformation(BaseModel): global_step: int | None = None booting_start_time: float | None = None # timestamp when ENGINE.BOOTING started training_model_path: str | None = None # model path from synced training config + swarm_client_instruction: dict = {} diff --git a/ajet/tuner_lib/experimental/swarm_server.py b/ajet/tuner_lib/experimental/swarm_server.py index a4c1a346..81365e6c 100644 --- a/ajet/tuner_lib/experimental/swarm_server.py +++ b/ajet/tuner_lib/experimental/swarm_server.py @@ -12,7 +12,7 @@ from typing import Coroutine, Optional, Tuple, List from ajet.utils.process_killer import kill_process_tree from ajet.tuner_lib.experimental.swarm_overwatch_utils import CurrentBatchRolloutPoolInformation -from ajet.tuner_lib.experimental.interchange_utils import DEBUG, VERBOSE +from ajet.tuner_lib.experimental.interchange_utils import DEBUG, VERBOSE, CLIENT_ACTIVE_TIMEOUT from ajet.tuner_lib.experimental.interchange_utils import ( SyncTrainConfigRequest, ClaimEpisodeRequest, @@ -30,9 +30,17 @@ PushVerboseLogRequest, VerboseLogEntry, VerboseLogsResponse, + AgreeSyncWeightRequest, + SwarmClientInstruction, + ActiveSwarmClient, + _refresh_client_activity, + _register_active_client, + _expire_inactive_clients, + _reset_active_client_tracking, VALID_STATUSES, ) + VERBOSE_LOG_TTL_SECONDS = 30.0 VERBOSE_LOG_MAX_ENTRIES = 50 @@ -69,6 +77,11 @@ def register_enable_swarm_mode_routes( if "current_batch_rollout_pool_information" not in shared_mem_dict: 
shared_mem_dict["current_batch_rollout_pool_information"] = CurrentBatchRolloutPoolInformation() + # active swarm client tracking (List[ActiveSwarmClient]; helpers live in + # interchange_utils) + if "active_swarm_clients" not in shared_mem_dict: + shared_mem_dict["active_swarm_clients"] = [] + # ------------------------------------------------------------------------------------------------ # ------ Recycle claimed episodes that client failed to complete in (promised) time -------------- # --------------------------------- claimed -> unclaimed ---------------------------------------- @@ -227,6 +240,7 @@ async def register_episode_ready_listener(): while True: await asyncio.sleep(10) # check every 10 seconds await find_claimed_episodes_that_need_to_be_unclaimed() + _expire_inactive_clients(shared_mem_dict) # read_all_episode_status() if DEBUG: _write_swarm_server_dynamic_log(shared_mem_dict) @@ -280,6 +294,10 @@ def _clean_up_engine_status(shared_mem_dict_lock, shared_mem_dict): shared_mem_dict["unclaimed_episodes"] = [] logger.info(f"[_clean_up_engine_status] Cleared {num_unclaimed} unclaimed episodes") + # reset active-client tracking (cleared each time we leave ROLLING/ + # ROLLING_POST -- i.e. on entering WEIGHT_SYNCING etc.) + _reset_active_client_tracking(shared_mem_dict) + # -------------------------------------------------------------------------------------- # -------------------------- fastapi routes -------------------------------------------- # -------------------------------------------------------------------------------------- @@ -602,6 +620,11 @@ async def claim_episode(req: ClaimEpisodeRequest): if VERBOSE: logger.info(f"Running [{episode_uuid}]: /claim_episode") + # begin_episode counts as activity for keeping a client in the + # active list (only refreshes if already active; first activation + # comes from a successful end_episode). 
+ _refresh_client_activity(req.client_uuid, shared_mem_dict) + return ClaimEpisodeResponse( success=True, client_uuid=req.client_uuid, @@ -668,6 +691,8 @@ async def end_episode(req: EndEpisodeRequest): shared_mem_dict, shared_mem_dict_lock, ) + # successful, non-abort end_episode marks the client "active" + _register_active_client(client_uuid, shared_mem_dict) elif episode_type == "eval": if engine_status in ["ENGINE.ROLLING"]: @@ -742,11 +767,23 @@ async def get_episode_buffer(): result = [v for k, v in shared_mem_dict.items() if is_key_episode_status(k)] return EpisodeBufferResponse(buffer=result) - @app.post("/update_current_batch_rollout_pool_information", response_model=BoolResponse) - async def update_current_batch_rollout_pool_information(req: CurrentBatchRolloutPoolInformation): - """Update the current batch rollout pool information.""" + @app.post( + "/update_current_batch_rollout_pool_information_and_fetch_instruction", + response_model=SwarmClientInstruction, + ) + async def update_current_batch_rollout_pool_information_and_fetch_instruction( + req: CurrentBatchRolloutPoolInformation, + ): + """Update pool information and return the active-client instruction. + + The trainer pushes its latest pool snapshot here every few seconds; + in the same call we hand back the server-maintained + `active_swarm_clients` list so the trainer can evaluate + `rollout_until_*_agree_sync_weight` stop conditions without an extra + round-trip. 
+ """ if DEBUG: - logger.info(f"Running /update_current_batch_rollout_pool_information") + logger.info(f"Running /update_current_batch_rollout_pool_information_and_fetch_instruction") try: with shared_mem_dict_lock: # Ignore fields that are only maintained in shared_mem_dict @@ -755,10 +792,62 @@ async def update_current_batch_rollout_pool_information(req: CurrentBatchRollout req.global_step = None req.completed_tasks_client_uuids = {} shared_mem_dict["current_batch_rollout_pool_information"] = req - return BoolResponse(success=True) + instruction = SwarmClientInstruction( + active_clients=list(shared_mem_dict.get("active_swarm_clients", [])) + ) + return instruction except Exception as e: logger.error(f"Error updating current batch rollout pool information: {e}") - return BoolResponse(success=False, failure_reason=str(e)) + return SwarmClientInstruction() + + AGREE_SYNC_WEIGHT_VALID_METHODS = ( + "rollout_until_any_client_agree_sync_weight", + "rollout_until_all_clients_agree_sync_weight", + ) + + @app.post("/agree_sync_weight", response_model=BoolResponse) + async def agree_sync_weight(req: AgreeSyncWeightRequest): + """Mark a client as having agreed to the next weight sync. + + Only counts when the client is currently in the active list (otherwise + the agreement would be silently expired anyway). The set is cleared + whenever the engine leaves ROLLING/ROLLING_POST. + + Refuses the call unless the trainer is configured with one of the + agree-driven sample-collection methods, since under any other policy + the agreement would have no effect on when the trainer stops. 
+ """ + if VERBOSE: + logger.info(f"Running /agree_sync_weight: {req.client_uuid}") + client_uuid = req.client_uuid + if not client_uuid: + return BoolResponse(success=False, failure_reason="client_uuid required") + pool_info: CurrentBatchRolloutPoolInformation = shared_mem_dict.get( + "current_batch_rollout_pool_information", + CurrentBatchRolloutPoolInformation(), + ) + assert pool_info.sample_collection_method in AGREE_SYNC_WEIGHT_VALID_METHODS, ( + f"agree_sync_weight is only valid when " + f"ajet.swarm_mode_sample_collection_method is one of " + f"{AGREE_SYNC_WEIGHT_VALID_METHODS}, but the trainer is currently " + f"running with '{pool_info.sample_collection_method}'." + ) + with shared_mem_dict_lock: + clients: List[ActiveSwarmClient] = list( + shared_mem_dict.get("active_swarm_clients", []) + ) + for i, c in enumerate(clients): + if c.client_uuid == client_uuid: + if not c.allowed_sync_weight: + clients[i] = c.model_copy(update={"allowed_sync_weight": True}) + shared_mem_dict["active_swarm_clients"] = clients + return BoolResponse(success=True) + return BoolResponse( + success=False, + failure_reason=( + f"Client {client_uuid} is not in the active list -- it must have completed at least one rewarded (non-abort) episode since the last weight sync before agreeing." 
+ ), + ) @app.get("/get_current_batch_rollout_pool_information", response_model=CurrentBatchRolloutPoolInformation) async def get_current_batch_rollout_pool_information(): @@ -773,6 +862,9 @@ async def get_current_batch_rollout_pool_information(): pool_info.global_step = shared_mem_dict.get("global_step", None) pool_info.booting_start_time = shared_mem_dict.get("booting_start_time", None) pool_info.training_model_path = shared_mem_dict.get("training_model_path", None) + pool_info.swarm_client_instruction = SwarmClientInstruction( + active_clients=list(shared_mem_dict.get("active_swarm_clients", [])) + ).model_dump() # Build running_episode_details for claimed episodes running_episode_details = {} diff --git a/ajet/utils/swarm_overwatch.py b/ajet/utils/swarm_overwatch.py index d076e330..fdf90aa0 100644 --- a/ajet/utils/swarm_overwatch.py +++ b/ajet/utils/swarm_overwatch.py @@ -117,6 +117,17 @@ def create_header( header_text = Text() header_text.append("AgentJet Swarm Overwatch", style="bold cyan") header_text.append(f"\nServer: {self.server_url}", style="dim") + + instr = info.swarm_client_instruction if info else {} + active_clients = instr.get("active_clients", []) + agreed = sum(1 for c in active_clients if c.get("allowed_sync_weight")) + total = len(active_clients) + header_text.append(f"\nActive Clients: {total}", style="bold white") + if total: + parts = [f"{c.get('client_uuid','?')[:8]}{'✓' if c.get('allowed_sync_weight') else ''}" for c in active_clients[:8]] + suffix = f", +{total - 8} more" if total > 8 else "" + header_text.append(f" [{', '.join(parts)}{suffix}]", style="cyan") + header_text.append(f"\nCurrent Time: {now}", style="green") header_text.append(f" | Last Update: {last_update}", style="yellow") header_text.append(f" | Refresh: {self.refresh_interval}s", style="blue") diff --git a/appworld_swarm_results/val_results.md b/appworld_swarm_results/val_results.md new file mode 100644 index 00000000..7fc5312c --- /dev/null +++ 
b/appworld_swarm_results/val_results.md @@ -0,0 +1,77 @@ + +## Step 0 +- pass_n: 1 +- total_tasks: 57 +- num_all_success_tasks: 55 +- num_pass_n_tasks: 55 +- task_pass_rate@1: 96.49% +- task_pass_rate@1: 96.49% +- mean_reward: 0.3759 +- std_reward: 0.4345 +- n_rollouts: 57 + +## Step 0 +- pass_n: 1 +- total_tasks: 57 +- num_all_success_tasks: 54 +- num_pass_n_tasks: 54 +- task_pass_rate@1: 94.74% +- task_pass_rate@1: 94.74% +- mean_reward: 0.3977 +- std_reward: 0.4565 +- n_rollouts: 57 + +## Step 10 +- pass_n: 1 +- total_tasks: 57 +- num_all_success_tasks: 55 +- num_pass_n_tasks: 55 +- task_pass_rate@1: 96.49% +- task_pass_rate@1: 96.49% +- mean_reward: 0.5378 +- std_reward: 0.5341 +- n_rollouts: 57 + +## Step 20 +- pass_n: 1 +- total_tasks: 57 +- num_all_success_tasks: 57 +- num_pass_n_tasks: 57 +- task_pass_rate@1: 100.00% +- task_pass_rate@1: 100.00% +- mean_reward: 0.5450 +- std_reward: 0.5293 +- n_rollouts: 57 + +## Step 30 +- pass_n: 1 +- total_tasks: 57 +- num_all_success_tasks: 55 +- num_pass_n_tasks: 55 +- task_pass_rate@1: 96.49% +- task_pass_rate@1: 96.49% +- mean_reward: 0.6261 +- std_reward: 0.5798 +- n_rollouts: 57 + +## Step 40 +- pass_n: 1 +- total_tasks: 57 +- num_all_success_tasks: 55 +- num_pass_n_tasks: 55 +- task_pass_rate@1: 96.49% +- task_pass_rate@1: 96.49% +- mean_reward: 0.6617 +- std_reward: 0.5789 +- n_rollouts: 57 + +## Step 50 +- pass_n: 1 +- total_tasks: 57 +- num_all_success_tasks: 56 +- num_pass_n_tasks: 56 +- task_pass_rate@1: 98.25% +- task_pass_rate@1: 98.25% +- mean_reward: 0.7028 +- std_reward: 0.5940 +- n_rollouts: 57 diff --git a/tutorial/example_appworld/appworld.md b/tutorial/example_appworld/appworld.md index 3a759a8a..d9edc047 100644 --- a/tutorial/example_appworld/appworld.md +++ b/tutorial/example_appworld/appworld.md @@ -1,8 +1,35 @@ ## Run Appworld AgentScope Agent -### 1. Prepare dataset - -Please download `env_service` and `appworld`. 
For specific steps, please refer to [EnvService Documentation](https://code.alibaba-inc.com/EconML/EnvService) +### 1. Install Appworld +``` +def install_appworld(): + # run: + # `rm -rf /tmp/pack_all_in_one & wget https://dail-wlcb.oss-cn-wulanchabu.aliyuncs.com/astuner_archive/appworld_pack_v3.tar.gz && tar -xzf ./appworld_pack_v3.tar.gz -C /tmp` + import shutil + + if os.path.exists("/tmp/pack_all_in_one"): + shutil.rmtree("/tmp/pack_all_in_one") + if os.path.exists("./appworld_pack_v3.tar.gz"): + os.remove("./appworld_pack_v3.tar.gz") + subprocess.run( + [ + "wget", + "https://dail-wlcb.oss-cn-wulanchabu.aliyuncs.com/astuner_archive/appworld_pack_v3.tar.gz", + ] + ) + subprocess.run( + [ + "tar", + "-xzf", + "./appworld_pack_v3.tar.gz", + "-C", + "/tmp", + ] + ) + # write + os.environ["APPWORLD_PATH"] = "/tmp/pack_all_in_one" + os.environ["APPWORLD_SCRIPT"] = "bash EnvService/env_sandbox/appworld.sh" +``` ### 2. Prepare AgentScope Workflow diff --git a/tutorial/example_appworld_swarm/README.md b/tutorial/example_appworld_swarm/README.md new file mode 100644 index 00000000..5d1211aa --- /dev/null +++ b/tutorial/example_appworld_swarm/README.md @@ -0,0 +1,45 @@ +## AppWorld swarm mode + +Swarm-mode rewrite of `tutorial/example_appworld`. +The training engine runs remotely (server side), while task enumeration, +env_service instance lifecycle and reward evaluation all happen locally +in the rollout client. 
+ +Files: +- `appworld_swarm.py` — workflow + lightweight `EnvClient` gym wrapper +- `agent_roll.py` — rollout driver (calls `begin_episode` / `end_episode`) +- `appworld.yaml` — swarm-mode training config + +Required env vars (with sensible defaults): +- `AJET_SWARM_URL` — swarm server URL (default `http://localhost:10086`) +- `APPWORLD_ENV_URL` — appworld env_service URL (default `http://127.0.0.1:8080`) +- `APPWORLD_ENV_TYPE` — env_type passed to env_service (default `appworld`) +- `APPWORLD_TRAINING_SPLIT` — train split for `get_env_profile` (default `train`) +- `APPWORLD_VALIDATION_SPLIT` — eval split for `get_env_profile` (default `dev`) +- `APPWORLD_MAX_STEPS` — per-episode step cap (default `25`) +- `APPWORLD_EVAL_INTERVAL` — run eval every N global steps (default `10`) +- `APPWORLD_EVAL_K` — rollouts per eval task, pass@k (default `1`) +- `APPWORLD_TOTAL_TRAINING_STEPS`— hard cap on global steps (default `200`) +- `APPWORLD_RESULT_DIR` — where eval logs / `val_results.md` are written (default `./appworld_swarm_results`) +- `APPWORLD_MAX_ENV_WORKER` — max parallel env workers for both train and eval (default `64`) + + +## Run swarm + +``` +tmux new-session -d -s "SWARM_SERVER" +tmux send-keys -t "SWARM_SERVER" "cd /mnt/data_cpfs/qingxu.fu/agentjet/hello-agentjet" Enter +tmux send-keys -t "SWARM_SERVER" "source .venv/bin/activate" Enter +tmux send-keys -t "SWARM_SERVER" "export SETUPTOOLS_USE_DISTUTILS=local" Enter +tmux send-keys -t "SWARM_SERVER" "ajet-swarm start" Enter +ta "SWARM_SERVER" + + +tmux new-session -d -s "SWARM_CLIENT" +tmux send-keys -t "SWARM_CLIENT" "cd /mnt/data_cpfs/qingxu.fu/agentjet/hello-agentjet" Enter +tmux send-keys -t "SWARM_CLIENT" "source .venv/bin/activate" Enter +tmux send-keys -t "SWARM_CLIENT" "export SETUPTOOLS_USE_DISTUTILS=local" Enter +tmux send-keys -t "SWARM_CLIENT" "sleep 30s" Enter +tmux send-keys -t "SWARM_CLIENT" "python -m tutorial.example_appworld_swarm.agent_roll" Enter +ta "SWARM_CLIENT" +``` diff --git 
a/tutorial/example_appworld_swarm/agent_roll.py b/tutorial/example_appworld_swarm/agent_roll.py new file mode 100644 index 00000000..4b86255b --- /dev/null +++ b/tutorial/example_appworld_swarm/agent_roll.py @@ -0,0 +1,228 @@ +# -*- coding: utf-8 -*- + +# python -m tutorial.example_appworld_swarm.agent_roll + +import os +import statistics +from concurrent.futures import ThreadPoolExecutor, as_completed +from typing import Generator, List + +from tqdm import tqdm + +from ajet.copilot.job import AgentJetJob +from ajet.schema.task import Task +from ajet.tuner_lib.as_oai_baseurl_apikey import OpenaiBaseUrlAndApiKey +from ajet.tuner_lib.experimental.swarm_client import SwarmClient +from ajet.utils.env_service_client.env_client_ng import EnvClient +from ajet.utils.thread_executors import PeriodicDrainThreadPoolExecutor + +NUM_EPOCH = 10000 +AJET_SWARM_URL = os.getenv("AJET_SWARM_URL", "http://localhost:10086") + +ENV_URL = os.getenv("APPWORLD_ENV_URL", "http://127.0.0.1:8080") +ENV_TYPE = os.getenv("APPWORLD_ENV_TYPE", "appworld") +TRAINING_SPLIT = os.getenv("APPWORLD_TRAINING_SPLIT", "train") +VALIDATION_SPLIT = os.getenv("APPWORLD_VALIDATION_SPLIT", "dev") +MAX_STEPS = int(os.getenv("APPWORLD_MAX_STEPS", "25")) + +EVAL_INTERVAL = int(os.getenv("APPWORLD_EVAL_INTERVAL", "10")) +EVAL_K = int(os.getenv("APPWORLD_EVAL_K", "1")) +TOTAL_TRAINING_STEPS = int(os.getenv("APPWORLD_TOTAL_TRAINING_STEPS", "200")) +RESULT_DIR = os.getenv("APPWORLD_RESULT_DIR", "./appworld_swarm_results") +MAX_ENV_WORKER = int(os.getenv("APPWORLD_MAX_ENV_WORKER", "64")) + + +def get_appworld_tasks(split: str) -> List[Task]: + """Enumerate appworld task ids from env_service for the given split. + + The swarm client owns task generation, so we hit env_service directly + (rather than going through `EnvServiceTaskReader`) to keep the config + surface flat. 
+ """ + env_client = EnvClient(base_url=ENV_URL) + task_id_array = env_client.get_env_profile(ENV_TYPE, split=split) + if len(task_id_array) == 0: + raise ValueError( + f"No task_id found for env_type={ENV_TYPE}, split={split}, " + f"check connection to {ENV_URL}" + ) + return [ + Task( + main_query="[not defined]", + init_messages=[], + task_id=str(task_id), + env_type=ENV_TYPE, + metadata={}, + ) + for task_id in task_id_array + ] + + +def generate_training_tasks() -> Generator[Task, None, None]: + for task in get_appworld_tasks(TRAINING_SPLIT): + yield task + + +def execute_agent(task: Task, api_baseurl_key: OpenaiBaseUrlAndApiKey): + import asyncio + from tutorial.example_appworld_swarm.appworld_swarm import ExampleAgentScopeWorkflow + workflow = ExampleAgentScopeWorkflow( + env_url=ENV_URL, + env_type=ENV_TYPE, + max_steps=MAX_STEPS, + ) + return asyncio.run(workflow.execute(task, api_baseurl_key)) + + +def main(): + + ajet_job = AgentJetJob( + base_yaml_config="tutorial/example_appworld_swarm/appworld.yaml", + algorithm="grpo", + experiment_name="appworld_swarm_14b", + max_env_worker=MAX_ENV_WORKER, + ) + + # Hand shake with remote swarm server + swarm_worker = SwarmClient(AJET_SWARM_URL) + swarm_worker.auto_sync_train_config_and_start_engine( + ajet_job, + # force_restart=True, + ) + + GRPO_N = ajet_job.num_repeat + REMOTE_BATCH_SIZE = ajet_job.batch_size + + os.makedirs(RESULT_DIR, exist_ok=True) + eval_log_path = os.path.join(RESULT_DIR, "eval_results.log") + val_result_path = os.path.join(RESULT_DIR, "val_results.md") + + eval_tasks = get_appworld_tasks(VALIDATION_SPLIT) + print(f"[INFO] Loaded {len(eval_tasks)} eval tasks (split={VALIDATION_SPLIT})") + + def rollout(task: Task) -> float: + # begin episode + episode_uuid, api_baseurl_key = swarm_worker.begin_episode(discard_episode_timeout=600) + # execute agent ( base_url = api_baseurl_key.base_url, api_key = api_baseurl_key.api_key ) + workflow_output = execute_agent(task, api_baseurl_key) + # report 
output back to swarm remote + swarm_worker.end_episode(task, episode_uuid, workflow_output) + return workflow_output.reward + + def eval_rollout(task: Task) -> float: + episode_uuid, api_baseurl_key = swarm_worker.begin_episode( + discard_episode_timeout=600, episode_type="eval" + ) + try: + workflow_output = execute_agent(task, api_baseurl_key) + return workflow_output.reward + finally: + # eval samples must NOT be fed back into the training pool + swarm_worker.abort_episode(episode_uuid) + + def run_eval(n_global_step: int): + if not eval_tasks: + return + k = EVAL_K + total_rollouts = len(eval_tasks) * k + print(f"\n[EVAL @ step {n_global_step}] {len(eval_tasks)} tasks x {k} (pass@{k})...") + per_task_rewards: List[List[float]] = [[] for _ in eval_tasks] + pbar = tqdm(total=total_rollouts, desc=f"EVAL @ step {n_global_step}") + + with ThreadPoolExecutor(max_workers=MAX_ENV_WORKER) as eval_executor: + future_to_idx = { + eval_executor.submit(eval_rollout, t): i + for i, t in enumerate(eval_tasks) + for _ in range(k) + } + for fut in as_completed(future_to_idx): + idx = future_to_idx[fut] + try: + per_task_rewards[idx].append(fut.result()) + except Exception as e: + print(f"[EVAL] future error: {e}") + pbar.update(1) + pbar.close() + + flat = [r for rs in per_task_rewards for r in rs if r is not None] + if not flat: + print(f"[EVAL @ step {n_global_step}] no valid rewards") + return + + avg = sum(flat) / len(flat) + std_reward = statistics.pstdev(flat) if len(flat) > 1 else 0.0 + # Full success requires raw_reward >= 1 (final_reward >= 1.5). + # Partial-credit rollouts have 0 < final_reward <= 0.5, so they must NOT + # count as passes; see EnvServiceJudge.compute_reward. 
+ SUCCESS_THRESHOLD = 1.0 + pass1 = sum(1 for r in flat if r >= SUCCESS_THRESHOLD) / len(flat) + num_all_success_tasks = sum( + 1 + for rs in per_task_rewards + if rs and all((r is not None and r >= SUCCESS_THRESHOLD) for r in rs) + ) + num_pass_n_tasks = sum( + 1 + for rs in per_task_rewards + if any((r is not None and r >= SUCCESS_THRESHOLD) for r in rs) + ) + passk = num_pass_n_tasks / len(per_task_rewards) + summary = ( + f"[EVAL @ step {n_global_step}] mean_reward={avg:.4f} std_reward={std_reward:.4f} " + f"task_pass_rate@1={pass1*100:.2f}% task_pass_rate@{k}={passk*100:.2f}% " + f"n_tasks={len(per_task_rewards)} n_rollouts={len(flat)}" + ) + print(summary) + with open(eval_log_path, "a") as f: + f.write(summary + "\n") + with open(val_result_path, "a") as f: + f.write(f"\n## Step {n_global_step}\n") + f.write(f"- pass_n: {k}\n") + f.write(f"- total_tasks: {len(per_task_rewards)}\n") + f.write(f"- num_all_success_tasks: {num_all_success_tasks}\n") + f.write(f"- num_pass_n_tasks: {num_pass_n_tasks}\n") + f.write(f"- task_pass_rate@1: {pass1*100:.2f}%\n") + f.write(f"- task_pass_rate@{k}: {passk*100:.2f}%\n") + f.write(f"- mean_reward: {avg:.4f}\n") + f.write(f"- std_reward: {std_reward:.4f}\n") + f.write(f"- n_rollouts: {len(flat)}\n") + + # step-0 eval (swarm mode does not support val_before_train) + last_eval_step = 0 + run_eval(0) + + executor = PeriodicDrainThreadPoolExecutor( + workers=GRPO_N * REMOTE_BATCH_SIZE, max_parallel=64, auto_retry=True + ) + + n_global_step = 0 + for _ in range(NUM_EPOCH): + for task in generate_training_tasks(): + for _ in range(GRPO_N): + # `submit_with_periodic_drain` returns drained results only when the + # in-flight pool was actually drained on this submission. Each drain + # boundary corresponds to a fully-collected local batch -- exactly + # when this client should agree to a weight sync under + # `rollout_until_all_clients_agree_sync_weight`. 
+ _, drained_results = executor.submit_with_periodic_drain( + fn=rollout, task=task + ) + if drained_results: + swarm_worker.agree_sync_weight() + + n_global_step = swarm_worker.get_global_step() + if n_global_step >= last_eval_step + EVAL_INTERVAL: + run_eval(n_global_step) + last_eval_step = n_global_step + + if n_global_step >= TOTAL_TRAINING_STEPS: + break + + if n_global_step >= TOTAL_TRAINING_STEPS: + break + + print("[INFO] Training complete.") + + +if __name__ == "__main__": + main() diff --git a/tutorial/example_appworld_swarm/appworld.yaml b/tutorial/example_appworld_swarm/appworld.yaml new file mode 100644 index 00000000..da4160a3 --- /dev/null +++ b/tutorial/example_appworld_swarm/appworld.yaml @@ -0,0 +1,84 @@ +# ------------------ main config ------------------ +# Swarm-mode counterpart of tutorial/example_appworld/appworld.yaml. +# Settings unrelated to swarm wiring are kept identical to the original yaml. +ajet: + project_name: example_appworld_swarm + experiment_dir: "auto" # {exp-dir}/{experiment_name} + + task_judge: + # reward is computed by the swarm workflow on the client side + judge_protocol: null + + task_reader: + # tasks are enumerated by the swarm client (env_service is queried in agent_roll.py) + type: random_dummy + + model: + # ✨ select model to be trained (matches original appworld.yaml) + path: /mnt/data_cpfs/model_cache/modelscope/hub/Qwen/Qwen/Qwen2___5-14B-Instruct + + rollout: + # workflow is driven from the swarm client, not from the server + user_workflow: null + force_disable_toolcalls: True + temperature: 0.9 + max_env_worker: 64 + num_repeat: 6 + agent_madness_reward: -1.0 + tensor_model_parallel_size: 1 + max_num_seqs: 64 + compute_madness_checklist: + - "nonsense" + max_response_length_in_one_turn: 4096 + max_model_len: 18000 + multi_turn: + max_sample_per_task: 25 + max_steps: 25 + + # the experimental ZeroMQ interchange server feature that allows `tuner.as_oai_baseurl_apikey` feature + enable_interchange_server: True + 
# train in cloud, run episode locally + enable_swarm_mode: True + # both swarm / oai share the same interchange server + interchange_server: + interchange_method: 'ipc' # options: 'tcp' (multi-nodes) or 'ipc' (1 node) + interchange_server_port: 10086 + num_fastapi_process: 1 + max_fastapi_threads: 512 + max_inference_tracker_threads: 64 + already_started: False # do not edit, used by `swarm` + + # Stop the rollout phase only when **every** active swarm client has called + # `SwarmClient.agree_sync_weight()`. The driver (see `agent_roll.py`) calls + # `agree_sync_weight()` each time `submit_with_periodic_drain` actually + # drains, so each weight sync lines up with a complete drain boundary. + swarm_mode_sample_collection_method: "rollout_until_all_clients_agree_sync_weight" + + debug: + debug_max_parallel: 1 + debug_first_n_tasks: 1 + + data: + train_batch_size: 64 + max_prompt_length: 3000 + max_response_length: 15000 + + trainer_common: + save_freq: 99999 + test_freq: 99999 + total_epochs: 99999 + nnodes: 1 + n_gpus_per_node: 8 + + +# ------------------ do not edit ------------------ +hydra: + searchpath: + - file://ajet/default_config + - file://ajet/default_config/verl + +# ------------------ do not edit ------------------ +defaults: + - verl_default + - ajet_default + - _self_ diff --git a/tutorial/example_appworld_swarm/appworld_swarm.py b/tutorial/example_appworld_swarm/appworld_swarm.py new file mode 100644 index 00000000..4b1e6c3f --- /dev/null +++ b/tutorial/example_appworld_swarm/appworld_swarm.py @@ -0,0 +1,168 @@ +from typing import Any, Tuple + +from agentscope.message import Msg +from loguru import logger + +from ajet import WorkflowOutput +from ajet.schema.task import Task +from ajet.tuner_lib.as_oai_baseurl_apikey import OpenaiBaseUrlAndApiKey +from ajet.utils.env_service_client.env_client_ng import EnvClient + + +class AppworldGymWrapper: + """Mirror of ajet.task_rollout.resource_keeper.BaseGymEnv for swarm-mode clients. 
+ + The swarm runner does not build a `gym_env` for us, so we wrap `EnvClient` + directly to keep the `step()/evaluate()` surface that the agent loop expects. + """ + + def __init__(self, env_client: EnvClient, episode_uuid: str): + self.env_client = env_client + self.episode_uuid = episode_uuid + + def step(self, action: dict) -> Tuple[Any, float, bool, dict]: + if not isinstance(action["content"], str): + try: + action["content"] = action["content"][0]["text"] + except Exception: + logger.exception( + f"Failed to parse action content from agentscope output. {action['content']}" + ) + action["content"] = str(action["content"]) + + env_output = self.env_client.step( + instance_id=self.episode_uuid, + action=action, + ) + obs: Any = "" + reward: float = 0 + info: dict = {} + if isinstance(env_output["state"], list): + obs = env_output["state"] + reward = env_output["reward"] + info = env_output["info"] + else: + if ("content" not in env_output["state"]) and ("error" in env_output["state"]): + obs = f"[Error from environment: {env_output['error']}]" + elif env_output["state"].get("content", "") == "": + obs = "Warning: the environment does not provide any feedback, please provide valid input and try again." + else: + obs = env_output["state"]["content"] + terminate = env_output["is_terminated"] + return obs, reward, terminate, info + + def evaluate(self, params=None): + return self.env_client.evaluate(self.episode_uuid, params=params or {"sparse": False}) + + +class ExampleAgentScopeWorkflow: + """Swarm-mode appworld workflow. + + Unlike the in-process workflow (which receives a fully initialized + `WorkflowTask` with `gym_env` populated by the framework), the swarm + client is responsible for the env_service instance lifecycle and reward + evaluation locally. 
+ """ + + def __init__( + self, + env_url: str = "http://127.0.0.1:8080", + env_type: str = "appworld", + max_steps: int = 25, + ): + self.env_url = env_url + self.env_type = env_type + self.max_steps = max_steps + + async def execute(self, task: Task, api_baseurl_key: OpenaiBaseUrlAndApiKey) -> WorkflowOutput: + from agentscope.agent import ReActAgent + from agentscope.formatter import DashScopeChatFormatter + from agentscope.memory import InMemoryMemory + + episode_uuid = api_baseurl_key.episode_uuid + env_client = EnvClient(base_url=self.env_url) + + try: + create_response = env_client.create_instance( + env_type=self.env_type, + task_id=task.task_id, + instance_id=episode_uuid, + params={}, + ) + state_message = create_response["state"] + if isinstance(state_message, dict): + raw_init_messages = [state_message] + elif isinstance(state_message, list): + raw_init_messages = state_message + else: + raise ValueError( + f"state_message should be dict or list, got {type(state_message)}" + ) + + if len(raw_init_messages) >= 2: + first_msg, init_messages = raw_init_messages[0], raw_init_messages[1:] + else: + first_msg = {"content": "You're a helpful assistant."} + init_messages = [] + + interaction_message = [] + for msg in init_messages: + interaction_message.append( + Msg( + name=msg.get("name", "user"), + content=msg.get("content", ""), + role=msg.get("role", "user"), + ) + ) + + agent = ReActAgent( + name="Qwen", + sys_prompt=first_msg.get("content", "You're a helpful assistant."), + model=api_baseurl_key.as_agentscope_model(), + formatter=DashScopeChatFormatter(), + memory=InMemoryMemory(), + toolkit=None, + print_hint_msg=False, + ) + agent.set_console_output_enabled(False) + + env = AppworldGymWrapper(env_client, episode_uuid) + step = 0 + for step in range(self.max_steps): + reply_message = await agent(interaction_message) + obs, _, terminate, _ = env.step( + action={"content": reply_message.content, "role": "assistant"} + ) + interaction_message = 
Msg(name="env", content=obs, role="user") + if terminate: + break + + try: + raw_reward = env.evaluate(params={"sparse": False}) + except Exception: + logger.exception("Evaluation failed; defaulting raw_reward=0.0") + raw_reward = 0.0 + + # mirror EnvServiceJudge.compute_reward + if raw_reward >= 1: + is_success = True + final_reward = 1.0 + raw_reward * 0.5 + else: + is_success = False + final_reward = 0.0 + raw_reward * 0.5 + + return WorkflowOutput( + reward=final_reward, + is_success=is_success, + metadata={"total_step": step}, + ) + except Exception: + logger.bind(exception=True).exception( + f"Error during appworld swarm episode (task_id={task.task_id})." + ) + return WorkflowOutput(reward=0.0, is_success=False, metadata={"total_step": 0}) + finally: + try: + env_client.release_instance(episode_uuid) + except Exception: + logger.exception("Failed to release env instance") From c265fba644ec6cefe01e0ae0220304da7a7dbdd8 Mon Sep 17 00:00:00 2001 From: "qingxu.fu" Date: Fri, 8 May 2026 16:05:15 +0800 Subject: [PATCH 2/2] Add example_cocktail_rl_v2 with configurable batch ratios and dynamic scheduling - Implemented AppWorld and AIME swarm clients for the new cocktail RL version. - Introduced CocktailV2Config as a single source of truth for configuration values. - Created train_appworld_as_swarm_client_0 and train_aime_as_swarm_client_1 scripts for running the respective clients. - Added cocktail_v2_runner to manage shared functionality between clients. - Included readme.md for setup instructions and configuration details. - Enhanced evaluation and logging mechanisms for better performance tracking. 
--- .gitignore | 2 + ajet/copilot/monitor-with-tmux/SKILL.md | 8 + ajet/default_config/ajet_config_schema.py | 1 + ajet/default_config/ajet_default.yaml | 1 + ajet/default_config/ajet_swarm_default.yaml | 1 + ajet/swarm_cli.py | 29 ++ ajet/task_rollout/async_llm_bridge.py | 6 +- .../utils/env_service_client/env_client_ng.py | 10 +- appworld_swarm_results/val_results.md | 77 ----- tutorial/example_appworld/appworld.md | 40 +-- .../example_appworld_swarm/appworld_swarm.py | 63 ++-- .../example_cocktail_rl/cocktail_rl_conf.yaml | 103 +++++++ tutorial/example_cocktail_rl/config_diff.md | 13 + .../train_aime_as_swarm_client_1.py | 287 ++++++++++++++++++ .../train_appworld_as_swarm_client_0.py | 245 +++++++++++++++ .../cocktail_v2_config.py | 167 ++++++++++ .../cocktail_v2_runner.py | 242 +++++++++++++++ tutorial/example_cocktail_rl_v2/readme.md | 15 + .../train_aime_as_swarm_client_1.py | 140 +++++++++ .../train_appworld_as_swarm_client_0.py | 194 ++++++++++++ 20 files changed, 1498 insertions(+), 146 deletions(-) delete mode 100644 appworld_swarm_results/val_results.md create mode 100644 tutorial/example_cocktail_rl/cocktail_rl_conf.yaml create mode 100644 tutorial/example_cocktail_rl/config_diff.md create mode 100644 tutorial/example_cocktail_rl/train_aime_as_swarm_client_1.py create mode 100644 tutorial/example_cocktail_rl/train_appworld_as_swarm_client_0.py create mode 100644 tutorial/example_cocktail_rl_v2/cocktail_v2_config.py create mode 100644 tutorial/example_cocktail_rl_v2/cocktail_v2_runner.py create mode 100644 tutorial/example_cocktail_rl_v2/readme.md create mode 100644 tutorial/example_cocktail_rl_v2/train_aime_as_swarm_client_1.py create mode 100644 tutorial/example_cocktail_rl_v2/train_appworld_as_swarm_client_0.py diff --git a/.gitignore b/.gitignore index 00e50326..892f615f 100644 --- a/.gitignore +++ b/.gitignore @@ -185,3 +185,5 @@ research_*.json research_*.jsonc daemon_logs* paper +val_results.md +cocktail_vs_separate* diff --git 
a/ajet/copilot/monitor-with-tmux/SKILL.md b/ajet/copilot/monitor-with-tmux/SKILL.md index 7da08a71..ead5a48b 100644 --- a/ajet/copilot/monitor-with-tmux/SKILL.md +++ b/ajet/copilot/monitor-with-tmux/SKILL.md @@ -180,3 +180,11 @@ $ python3 /tmp/tmux_wait.py ajet_session 240 && tmux capture-pane -t ajet_sessio tmux kill-session -t ajet_session ``` + + +## For AgentJet Swarm + +- You should create a separate tmux session for each agentjet swarm server and each agentjet swarm client +- When debugging, please do not restart agentjet swarm servers frequently, as that wastes too much time +- When you are really having difficulty clearing GPU memory, run `ajet --autokill` to automatically kill all python and ray processes (however, I still recommend using this as a last resort). +- For AgentJet, always use tmux session name that starts with `ajet-*` diff --git a/ajet/default_config/ajet_config_schema.py b/ajet/default_config/ajet_config_schema.py index f0bdb35b..a90800e6 100644 --- a/ajet/default_config/ajet_config_schema.py +++ b/ajet/default_config/ajet_config_schema.py @@ -40,6 +40,7 @@ class AjetModel: class AjetData: max_prompt_length: int = 3000 max_response_length: int = 15000 + # Note that this value is ignored when swarm_mode_sample_collection_method="rollout_until_all_clients_agree_sync_weight" train_batch_size: int = 32 diff --git a/ajet/default_config/ajet_default.yaml b/ajet/default_config/ajet_default.yaml index c84077ed..8bf1d689 100644 --- a/ajet/default_config/ajet_default.yaml +++ b/ajet/default_config/ajet_default.yaml @@ -16,6 +16,7 @@ ajet: # max number of tokens for response max_response_length: 15000 # how many tasks per training batch + # Note that this value is ignored when swarm_mode_sample_collection_method="rollout_until_all_clients_agree_sync_weight" train_batch_size: 32 # [Hint]: The final number of samples per update will be: N_{sample} = (data.train_batch_size * rollout.num_repeat * rollout.multi_turn.expected_steps) diff --git 
a/ajet/default_config/ajet_swarm_default.yaml b/ajet/default_config/ajet_swarm_default.yaml index 5890e94e..6078362e 100644 --- a/ajet/default_config/ajet_swarm_default.yaml +++ b/ajet/default_config/ajet_swarm_default.yaml @@ -67,6 +67,7 @@ ajet: # max number of tokens for response max_response_length: 15000 # how many tasks per training batch + # Note that this value is ignored when swarm_mode_sample_collection_method="rollout_until_all_clients_agree_sync_weight" train_batch_size: 32 # [Hint]: The final number of samples per update will be: N_{sample} = (data.train_batch_size * rollout.num_repeat * rollout.multi_turn.expected_steps) diff --git a/ajet/swarm_cli.py b/ajet/swarm_cli.py index 54787386..d91e29d0 100644 --- a/ajet/swarm_cli.py +++ b/ajet/swarm_cli.py @@ -5,6 +5,7 @@ from dotenv import load_dotenv from loguru import logger +from ajet.utils.cleaner import fast_kill_by_keyword_bash from ajet.utils.config_utils import prepare_experiment_config from ajet.utils.launch_utils import ( dict_to_namespace, @@ -41,6 +42,21 @@ def start_swarm_server(env, config, port): def cmd_start(args): """Handle the 'start' subcommand.""" + if args.autokill: + args.kill = "ray|vllm|VLLM|python" + + if args.kill: + logger.info(f"Killing processes matching keywords: {args.kill}") + for keyword in args.kill.split("|"): + logger.info(f"Killing processes matching keyword: {keyword}") + killed_pids = fast_kill_by_keyword_bash(keyword) + if killed_pids: + logger.success( + f"Successfully killed processes with PIDs: {killed_pids}" + ) + else: + logger.warning(f"No processes found matching keyword: {keyword}") + # Use default config if not provided exp_base_dir = args.exp_dir or DEFAULT_DIR if not args.conf: @@ -126,6 +142,19 @@ def main(): required=False, help="Debug tags; enables Ray post-mortem and DEBUG_TAGS env", ) + parser_start.add_argument( + "--kill", + type=str, + default="", + required=False, + help="list of keywords for killing processes", + ) + parser_start.add_argument( + 
"--autokill", + action="store_true", + default=False, + help="Kill system processes (ray + vllm + python) that may block the current experiment", + ) parser_start.set_defaults(func=cmd_start) diff --git a/ajet/task_rollout/async_llm_bridge.py b/ajet/task_rollout/async_llm_bridge.py index 685458a9..bd853556 100644 --- a/ajet/task_rollout/async_llm_bridge.py +++ b/ajet/task_rollout/async_llm_bridge.py @@ -122,7 +122,7 @@ async def llm_chat_verl( ): parsed_tool_calls = self.tool_parser.extract_tool_calls(decoded_text, None) # type: ignore - parsed_tool_calls = parsed_tool_calls.model_dump() + parsed_tool_calls = parsed_tool_calls.model_dump(mode='json') model_called = parsed_tool_calls["tools_called"] if model_called: @@ -155,7 +155,7 @@ async def llm_chat_verl( "completion_tokens": len(token_array), # type: ignore "total_tokens": len(prompt_token_ids) + len(token_array), # type: ignore } - # from ajet import bp; bp("DECODE") + return { "role": "assistant", "request_id": request_id, @@ -327,7 +327,7 @@ async def chat_completion_request( episode_uuid: str, ): from openai.types.chat.chat_completion import ChatCompletion - req_as_dict = req.model_dump() + req_as_dict = req.model_dump(mode='json') # infer + process with context tracker llm_output = await self.run_infer( diff --git a/ajet/utils/env_service_client/env_client_ng.py b/ajet/utils/env_service_client/env_client_ng.py index a8e1112f..3aba8c0b 100644 --- a/ajet/utils/env_service_client/env_client_ng.py +++ b/ajet/utils/env_service_client/env_client_ng.py @@ -207,9 +207,13 @@ def call(): messages=action, params=params, ) - return resp["data"] + data = resp["data"] + while "data" in data and "state" not in data: + data = data["data"] + data["state"] = data["state"][0] + return data - res = retry_call( + return retry_call( call, max_retry=max_retry, fail_return=fallback, @@ -217,8 +221,6 @@ def call(): instance_id=instance_id, action_name="step", ) - res["state"] = res["state"][0] - return res def evaluate( self, 
diff --git a/appworld_swarm_results/val_results.md b/appworld_swarm_results/val_results.md deleted file mode 100644 index 7fc5312c..00000000 --- a/appworld_swarm_results/val_results.md +++ /dev/null @@ -1,77 +0,0 @@ - -## Step 0 -- pass_n: 1 -- total_tasks: 57 -- num_all_success_tasks: 55 -- num_pass_n_tasks: 55 -- task_pass_rate@1: 96.49% -- task_pass_rate@1: 96.49% -- mean_reward: 0.3759 -- std_reward: 0.4345 -- n_rollouts: 57 - -## Step 0 -- pass_n: 1 -- total_tasks: 57 -- num_all_success_tasks: 54 -- num_pass_n_tasks: 54 -- task_pass_rate@1: 94.74% -- task_pass_rate@1: 94.74% -- mean_reward: 0.3977 -- std_reward: 0.4565 -- n_rollouts: 57 - -## Step 10 -- pass_n: 1 -- total_tasks: 57 -- num_all_success_tasks: 55 -- num_pass_n_tasks: 55 -- task_pass_rate@1: 96.49% -- task_pass_rate@1: 96.49% -- mean_reward: 0.5378 -- std_reward: 0.5341 -- n_rollouts: 57 - -## Step 20 -- pass_n: 1 -- total_tasks: 57 -- num_all_success_tasks: 57 -- num_pass_n_tasks: 57 -- task_pass_rate@1: 100.00% -- task_pass_rate@1: 100.00% -- mean_reward: 0.5450 -- std_reward: 0.5293 -- n_rollouts: 57 - -## Step 30 -- pass_n: 1 -- total_tasks: 57 -- num_all_success_tasks: 55 -- num_pass_n_tasks: 55 -- task_pass_rate@1: 96.49% -- task_pass_rate@1: 96.49% -- mean_reward: 0.6261 -- std_reward: 0.5798 -- n_rollouts: 57 - -## Step 40 -- pass_n: 1 -- total_tasks: 57 -- num_all_success_tasks: 55 -- num_pass_n_tasks: 55 -- task_pass_rate@1: 96.49% -- task_pass_rate@1: 96.49% -- mean_reward: 0.6617 -- std_reward: 0.5789 -- n_rollouts: 57 - -## Step 50 -- pass_n: 1 -- total_tasks: 57 -- num_all_success_tasks: 56 -- num_pass_n_tasks: 56 -- task_pass_rate@1: 98.25% -- task_pass_rate@1: 98.25% -- mean_reward: 0.7028 -- std_reward: 0.5940 -- n_rollouts: 57 diff --git a/tutorial/example_appworld/appworld.md b/tutorial/example_appworld/appworld.md index d9edc047..cd91deac 100644 --- a/tutorial/example_appworld/appworld.md +++ b/tutorial/example_appworld/appworld.md @@ -1,35 +1,19 @@ ## Run Appworld AgentScope 
Agent -### 1. Install Appworld +### 1. Install and Run Appworld + +- Install: +``` +rm -rf /tmp/pack_all_in_one & wget https://dail-wlcb.oss-cn-wulanchabu.aliyuncs.com/astuner_archive/appworld_pack_v3.tar.gz && tar -xzf ./appworld_pack_v3.tar.gz -C /tmp ``` -def install_appworld(): - # run: - # `rm -rf /tmp/pack_all_in_one & wget https://dail-wlcb.oss-cn-wulanchabu.aliyuncs.com/astuner_archive/appworld_pack_v3.tar.gz && tar -xzf ./appworld_pack_v3.tar.gz -C /tmp` - import shutil - - if os.path.exists("/tmp/pack_all_in_one"): - shutil.rmtree("/tmp/pack_all_in_one") - if os.path.exists("./appworld_pack_v3.tar.gz"): - os.remove("./appworld_pack_v3.tar.gz") - subprocess.run( - [ - "wget", - "https://dail-wlcb.oss-cn-wulanchabu.aliyuncs.com/astuner_archive/appworld_pack_v3.tar.gz", - ] - ) - subprocess.run( - [ - "tar", - "-xzf", - "./appworld_pack_v3.tar.gz", - "-C", - "/tmp", - ] - ) - # write - os.environ["APPWORLD_PATH"] = "/tmp/pack_all_in_one" - os.environ["APPWORLD_SCRIPT"] = "bash EnvService/env_sandbox/appworld.sh" + +- Run: ``` +export APPWORLD_PATH="/tmp/pack_all_in_one" +export APPWORLD_SCRIPT="bash EnvService/env_sandbox/appworld.sh" +ajet --with-appworld --skip-check-avail-gpu +``` + ### 2. 
Prepare AgentScope Workflow diff --git a/tutorial/example_appworld_swarm/appworld_swarm.py b/tutorial/example_appworld_swarm/appworld_swarm.py index 4b1e6c3f..e640291e 100644 --- a/tutorial/example_appworld_swarm/appworld_swarm.py +++ b/tutorial/example_appworld_swarm/appworld_swarm.py @@ -1,7 +1,7 @@ from typing import Any, Tuple -from agentscope.message import Msg from loguru import logger +from openai.types.chat.chat_completion import ChatCompletion from ajet import WorkflowOutput from ajet.schema.task import Task @@ -21,15 +21,6 @@ def __init__(self, env_client: EnvClient, episode_uuid: str): self.episode_uuid = episode_uuid def step(self, action: dict) -> Tuple[Any, float, bool, dict]: - if not isinstance(action["content"], str): - try: - action["content"] = action["content"][0]["text"] - except Exception: - logger.exception( - f"Failed to parse action content from agentscope output. {action['content']}" - ) - action["content"] = str(action["content"]) - env_output = self.env_client.step( instance_id=self.episode_uuid, action=action, @@ -75,10 +66,6 @@ def __init__( self.max_steps = max_steps async def execute(self, task: Task, api_baseurl_key: OpenaiBaseUrlAndApiKey) -> WorkflowOutput: - from agentscope.agent import ReActAgent - from agentscope.formatter import DashScopeChatFormatter - from agentscope.memory import InMemoryMemory - episode_uuid = api_baseurl_key.episode_uuid env_client = EnvClient(base_url=self.env_url) @@ -105,35 +92,43 @@ async def execute(self, task: Task, api_baseurl_key: OpenaiBaseUrlAndApiKey) -> first_msg = {"content": "You're a helpful assistant."} init_messages = [] - interaction_message = [] + interaction_message = [ + { + "content": first_msg.get("content", "You're a helpful assistant."), + "role": "system", + } + ] for msg in init_messages: interaction_message.append( - Msg( - name=msg.get("name", "user"), - content=msg.get("content", ""), - role=msg.get("role", "user"), - ) + { + "content": msg.get("content", ""), + "role": 
msg.get("role", "user"), + } ) - agent = ReActAgent( - name="Qwen", - sys_prompt=first_msg.get("content", "You're a helpful assistant."), - model=api_baseurl_key.as_agentscope_model(), - formatter=DashScopeChatFormatter(), - memory=InMemoryMemory(), - toolkit=None, - print_hint_msg=False, - ) - agent.set_console_output_enabled(False) - + client = api_baseurl_key.as_raw_openai_sdk_client() env = AppworldGymWrapper(env_client, episode_uuid) step = 0 for step in range(self.max_steps): - reply_message = await agent(interaction_message) + reply_message: ChatCompletion = await client.chat.completions.create( + model="ajet-model", + messages=interaction_message, + ) obs, _, terminate, _ = env.step( - action={"content": reply_message.content, "role": "assistant"} + action={"content": reply_message.choices[0].message.content, "role": "assistant"} + ) + interaction_message.extend( + [ + { + "content": reply_message.choices[0].message.content, + "role": "assistant", + }, + { + "content": obs, + "role": "user", + } + ] ) - interaction_message = Msg(name="env", content=obs, role="user") if terminate: break diff --git a/tutorial/example_cocktail_rl/cocktail_rl_conf.yaml b/tutorial/example_cocktail_rl/cocktail_rl_conf.yaml new file mode 100644 index 00000000..be4d4078 --- /dev/null +++ b/tutorial/example_cocktail_rl/cocktail_rl_conf.yaml @@ -0,0 +1,103 @@ +# ------------------ Cocktail RL Config ------------------ +# 混合配置:从 client_0 (AppWorld) 和 client_1 (AIME) 中选取最优参数 +# 参数来源标记: [0] = client_0, [1] = client_1 +ajet: + project_name: cocktail_rl + experiment_dir: "auto" # {exp-dir}/{experiment_name} + + task_judge: + judge_protocol: null + + task_reader: + type: random_dummy + + model: + # [0] Qwen2.5-14B-Instruct + path: /mnt/data_cpfs/model_cache/modelscope/hub/Qwen/Qwen/Qwen2___5-14B-Instruct + + rollout: + user_workflow: null + # [1] 不禁用工具调用 + force_disable_toolcalls: False + temperature: 0.9 + # [1] 并行 worker 数量 + max_env_worker: 128 + # [1] GRPO 重复采样次数 + num_repeat: 4 + # 
[0] Agent 失控惩罚奖励 + agent_madness_reward: -1.0 + tensor_model_parallel_size: 1 + max_num_seqs: 64 + compute_madness_checklist: + - "nonsense" + # [1] 单轮最大响应长度 + max_response_length_in_one_turn: 12000 + # [1] 模型上下文窗口 + max_model_len: 23000 + multi_turn: + max_sample_per_task: 25 + # [0] 多轮交互最大步数 + max_steps: 25 + + enable_interchange_server: True + enable_swarm_mode: True + + interchange_server: + interchange_method: 'ipc' + interchange_server_port: 10086 + num_fastapi_process: 1 + max_fastapi_threads: 512 + max_inference_tracker_threads: 64 + already_started: False + + # [0] 等待所有 client 同意后同步权重 + swarm_mode_sample_collection_method: "rollout_until_all_clients_agree_sync_weight" + + debug: + debug_max_parallel: 1 + debug_first_n_tasks: 1 + + data: + # [0] batch size + train_batch_size: 64 # Note that this value is ignored when swarm_mode_sample_collection_method="rollout_until_all_clients_agree_sync_weight" + max_prompt_length: 3000 + # [1] 最大响应长度 + max_response_length: 20000 + + trainer_common: + # [1] 模型保存频率 (几乎不保存) + save_freq: 1000000000 + # [1] 测试频率 = eval_interval + test_freq: 10 + # [0] 总 epoch 数 + total_epochs: 99999 + # [0] 总训练步数 + total_training_steps: 200 + nnodes: 1 + n_gpus_per_node: 8 + # [1] 日志记录方式 + logger: swanlab + # [0] 验证 pass@n + val_pass_n: 4 + val_before_train: False + + algorithm: + adv_estimator: grpo + + # [0] KL 相关配置 + use_kl_loss: True + use_kl_in_reward: False + kl_penalty_type: kl + + +# ------------------ do not edit ------------------ +hydra: + searchpath: + - file://ajet/default_config + - file://ajet/default_config/verl + +# ------------------ do not edit ------------------ +defaults: + - verl_default + - ajet_default + - _self_ diff --git a/tutorial/example_cocktail_rl/config_diff.md b/tutorial/example_cocktail_rl/config_diff.md new file mode 100644 index 00000000..d2384a56 --- /dev/null +++ b/tutorial/example_cocktail_rl/config_diff.md @@ -0,0 +1,13 @@ +# train + +```bash + +source .venv/bin/activate +ajet-swarm start + +source 
.venv/bin/activate +python -m tutorial.example_cocktail_rl.train_appworld_as_swarm_client_0 + +source .venv/bin/activate +python -m tutorial.example_cocktail_rl.train_aime_as_swarm_client_1 +``` \ No newline at end of file diff --git a/tutorial/example_cocktail_rl/train_aime_as_swarm_client_1.py b/tutorial/example_cocktail_rl/train_aime_as_swarm_client_1.py new file mode 100644 index 00000000..23afd7b5 --- /dev/null +++ b/tutorial/example_cocktail_rl/train_aime_as_swarm_client_1.py @@ -0,0 +1,287 @@ +# -*- coding: utf-8 -*- +""" +AIME Math Swarm Training - Client 1 (Follower) +This client does NOT control training parameters - it only connects to the swarm server +started by client_0 and contributes rollouts. + +python -m tutorial.example_cocktail_rl.train_aime_as_swarm_client_1 +""" + +import os +import time +import statistics +from dataclasses import dataclass +from concurrent.futures import ThreadPoolExecutor, as_completed + +from tqdm import tqdm + +from ajet.schema.task import Task +from ajet.task_reader import RouterTaskReader, HuggingFaceTaskReader +from ajet.utils.thread_executors import PeriodicDrainThreadPoolExecutor +from ajet.default_config.ajet_config_schema import AjetTaskReader, HuggingfaceDatRepo +from ajet.tuner_lib.experimental.swarm_client import SwarmClient +from tutorial.opencode_build_aime.agent_run_v3 import execute_agent +from tutorial.opencode_build_aime import download_data + + +@dataclass +class AgentConfig: + """Minimal config for execute_agent (replaces AgentJetJob for client_1).""" + model: str + max_response_length: int + + +def load_eval_tasks(test_dataset: str, label: str = "") -> list: + eval_tasks = [] + if os.path.exists(test_dataset): + eval_reader = HuggingFaceTaskReader( + AjetTaskReader(huggingface_dat_repo=HuggingfaceDatRepo(dataset_path=test_dataset)) + ) + for t in eval_reader.generate_training_tasks(): + eval_tasks.append(t) + print(f"[INFO] Loaded {len(eval_tasks)} eval tasks from {label or test_dataset}") + else: + 
print(f"[WARN] Eval dataset not found: {test_dataset}. Skipping {label or test_dataset}.") + return eval_tasks + + +class AIMESwarmClient: + """AIME swarm client that follows server started by client_0.""" + + def __init__( + self, + swarm_url: str, + result_dir: str, + max_env_worker: int = 128, + eval_interval: int = 10, + eval_k: int = 4, + grpo_n: int = 4, + ): + self.swarm_url = swarm_url or os.getenv("AJET_SWARM_URL", "http://localhost:10086") + self.result_dir = result_dir + self.max_env_worker = max_env_worker + self.eval_interval = eval_interval + self.eval_k = eval_k + self.grpo_n = grpo_n + + data_dir = os.path.join(os.path.dirname(__file__), "..", "opencode_build_aime", "data") + self.train_dataset = os.path.join(data_dir, "dapo-math-17k.parquet") + self.test_datasets = { + "AIME-2025": os.path.join(data_dir, "aime-2025.parquet"), + "AIME-2026": os.path.join(data_dir, "aime-2026.parquet"), + "DAPO-Math-Tiny-Val": os.path.join(data_dir, "dapo-math-tiny-val.parquet"), + } + + self.swarm_worker: SwarmClient | None = None + self.dataset: RouterTaskReader | None = None + self.eval_tasks_by_set: dict[str, list[Task]] = {} + self.agent_config: AgentConfig | None = None + + os.makedirs(result_dir, exist_ok=True) + + def setup(self): + if not os.path.exists(self.train_dataset): + raise FileNotFoundError( + f"Training dataset not found: {self.train_dataset}\n" + "Please run: proxychains python -m tutorial.opencode_build_aime.download_data" + ) + + self.dataset = RouterTaskReader( + reader_type="huggingface_dat_repo", + reader_config=AjetTaskReader( + huggingface_dat_repo=HuggingfaceDatRepo(dataset_path=self.train_dataset) + ) + ) + + self.swarm_worker = SwarmClient(self.swarm_url, verbose=False) + print("[INFO] Waiting for swarm server to be ready (ENGINE.ROLLING)...") + self.swarm_worker._wait_until_status_change_to(desired_status="ENGINE.ROLLING") + print("[INFO] Swarm server is ready.") + + # Config from env vars (must match server config from client_0) + 
max_response_length = int(os.getenv("COCKTAIL_MAX_RESPONSE_LENGTH", "20000")) + self.agent_config = AgentConfig(model="dummy", max_response_length=max_response_length) + + # Load eval datasets + eval_downloaders = { + "AIME-2025": download_data.ensure_aime_2025, + "AIME-2026": download_data.ensure_aime_2026, + } + for label, path in self.test_datasets.items(): + if not os.path.exists(path): + downloader = eval_downloaders.get(label) + if downloader is None: + print(f"[WARN] {label} parquet missing at {path} and no downloader registered. Skipping.") + continue + print(f"[INFO] {label} parquet missing, downloading...") + try: + downloader() + except Exception as e: + print(f"[WARN] Failed to download {label}: {e}") + continue + tasks = load_eval_tasks(path, label=label) + if tasks: + self.eval_tasks_by_set[label] = tasks + + def rollout(self, task: Task) -> float: + assert self.swarm_worker is not None and self.agent_config is not None + episode_uuid, api_baseurl_key = self.swarm_worker.begin_episode(discard_episode_timeout=120) + workflow_output = execute_agent(task, api_baseurl_key, self.agent_config) + self.swarm_worker.end_episode(task, episode_uuid, workflow_output) + return workflow_output.reward + + def eval_rollout(self, task: Task) -> float: + assert self.swarm_worker is not None and self.agent_config is not None + episode_uuid, api_baseurl_key = self.swarm_worker.begin_episode( + discard_episode_timeout=120, episode_type="eval" + ) + try: + workflow_output = execute_agent(task, api_baseurl_key, self.agent_config) + return workflow_output.reward + finally: + self.swarm_worker.abort_episode(episode_uuid) + + def run_eval(self, n_global_step: int): + if not self.eval_tasks_by_set: + return + eval_log_path = os.path.join(self.result_dir, "eval_results.log") + + for label, eval_tasks in self.eval_tasks_by_set.items(): + self._run_eval_one(n_global_step, label, eval_tasks, eval_log_path) + + def _run_eval_one(self, n_global_step: int, label: str, eval_tasks: 
list, eval_log_path: str): + k = self.eval_k + total_rollouts = len(eval_tasks) * k + print(f"\n[EVAL @ step {n_global_step}] Running {label} eval on {len(eval_tasks)} tasks x {k} (pass@{k})...") + per_task_rewards = [[] for _ in eval_tasks] + pbar = tqdm(total=total_rollouts, desc=f"EVAL {label} @ step {n_global_step}") + + with ThreadPoolExecutor(max_workers=self.max_env_worker) as eval_executor: + future_to_idx = { + eval_executor.submit(self.eval_rollout, t): i + for i, t in enumerate(eval_tasks) + for _ in range(k) + } + for fut in as_completed(future_to_idx): + idx = future_to_idx[fut] + try: + per_task_rewards[idx].append(fut.result()) + except Exception as e: + print(f"[EVAL] future error: {e}") + pbar.update(1) + pbar.close() + + flat = [r for rs in per_task_rewards for r in rs if r is not None] + if flat: + avg = sum(flat) / len(flat) + std_reward = statistics.pstdev(flat) if len(flat) > 1 else 0.0 + pass1 = sum(1 for r in flat if r > 0) / len(flat) + num_all_success_tasks = sum( + 1 for rs in per_task_rewards if rs and all((r is not None and r > 0) for r in rs) + ) + solved_tasks = [rs for rs in per_task_rewards if any((r is not None and r > 0) for r in rs)] + num_pass_n_tasks = len(solved_tasks) + passk = num_pass_n_tasks / len(per_task_rewards) + summary = ( + f"[EVAL @ step {n_global_step}] {label} mean_reward={avg:.4f} std_reward={std_reward:.4f} " + f"task_pass_rate@1={pass1*100:.2f}% task_pass_rate@{k}={passk*100:.2f}% " + f"n_tasks={len(per_task_rewards)} n_rollouts={len(flat)}" + ) + print(summary) + with open(eval_log_path, "a") as f: + f.write(summary + "\n") + + val_result_path = os.path.join(self.result_dir, "val_results.md") + with open(val_result_path, "a") as f: + f.write(f"\n## Step {n_global_step}\n") + f.write(f"- pass_n: {k}\n") + f.write(f"- total_tasks: {len(per_task_rewards)}\n") + f.write(f"- num_all_success_tasks: {num_all_success_tasks}\n") + f.write(f"- num_pass_n_tasks: {num_pass_n_tasks}\n") + f.write(f"- task_pass_rate@1: 
{pass1*100:.2f}%\n") + f.write(f"- task_pass_rate@{k}: {passk*100:.2f}%\n") + f.write(f"- mean_reward: {avg:.4f}\n") + f.write(f"- std_reward: {std_reward:.4f}\n") + f.write(f"- n_rollouts: {len(flat)}\n") + else: + print(f"[EVAL @ step {n_global_step}] {label} no valid rewards") + + def train(self): + assert self.swarm_worker is not None and self.dataset is not None + + last_eval_step = 0 + # self.run_eval(0) # skip initial eval for faster iteration + + # Use same executor pattern as client_0 for proper weight sync + batch_size = 64 # must match server config + executor = PeriodicDrainThreadPoolExecutor( + workers=self.grpo_n * batch_size, max_parallel=64, auto_retry=True + ) + + train_log_path = os.path.join(self.result_dir, "train_results.log") + + n_global_step = 0 + num_epochs = 10000 + for epoch in range(num_epochs): + for _, task in enumerate(self.dataset.generate_training_tasks()): + for _ in range(self.grpo_n): + _, drained_results = executor.submit_with_periodic_drain( + fn=self.rollout, task=task + ) + if drained_results: + # Log batch rewards before weight sync + rewards = [r for r in drained_results if r is not None] + if rewards: + avg_reward = sum(rewards) / len(rewards) + std_reward = statistics.pstdev(rewards) if len(rewards) > 1 else 0.0 + success_rate = sum(1 for r in rewards if r > 0) / len(rewards) + step = self.swarm_worker.get_global_step() + log_line = ( + f"[TRAIN @ step {step}] client=aime " + f"batch_size={len(rewards)} mean_reward={avg_reward:.4f} " + f"std_reward={std_reward:.4f} success_rate={success_rate*100:.2f}%" + ) + print(log_line) + with open(train_log_path, "a") as f: + f.write(log_line + "\n") + self.swarm_worker.agree_sync_weight() + + n_global_step = self.swarm_worker.get_global_step() + + if n_global_step >= last_eval_step + self.eval_interval: + self.run_eval(n_global_step) + last_eval_step = n_global_step + + finish_flag = os.path.join(self.result_dir, "finish.flag") + with open(finish_flag, "w") as f: + 
f.write(f"Training completed at {time.time()}\n") + + print("\n[INFO] Training complete!") + + def run(self): + self.setup() + self.train() + + +def main(): + # Hardcoded config (must match client_0 / cocktail_rl_conf.yaml) + SWARM_URL = os.getenv("AJET_SWARM_URL", "http://localhost:10086") + RESULT_DIR = "./cocktail_training_new/results_aime" + MAX_ENV_WORKER = 128 + EVAL_INTERVAL = 10 + EVAL_K = 4 + GRPO_N = 4 + + client = AIMESwarmClient( + swarm_url=SWARM_URL, + result_dir=RESULT_DIR, + max_env_worker=MAX_ENV_WORKER, + eval_interval=EVAL_INTERVAL, + eval_k=EVAL_K, + grpo_n=GRPO_N, + ) + client.run() + + +if __name__ == "__main__": + main() diff --git a/tutorial/example_cocktail_rl/train_appworld_as_swarm_client_0.py b/tutorial/example_cocktail_rl/train_appworld_as_swarm_client_0.py new file mode 100644 index 00000000..4832c943 --- /dev/null +++ b/tutorial/example_cocktail_rl/train_appworld_as_swarm_client_0.py @@ -0,0 +1,245 @@ +# -*- coding: utf-8 -*- + +# python -m tutorial.example_cocktail_rl.train_appworld_as_swarm_client_0 + +import os +import statistics +from concurrent.futures import ThreadPoolExecutor, as_completed +from typing import Generator, List + +from tqdm import tqdm + +from ajet.copilot.job import AgentJetJob +from ajet.schema.task import Task +from ajet.tuner_lib.as_oai_baseurl_apikey import OpenaiBaseUrlAndApiKey +from ajet.tuner_lib.experimental.swarm_client import SwarmClient +from ajet.utils.env_service_client.env_client_ng import EnvClient +from ajet.utils.thread_executors import PeriodicDrainThreadPoolExecutor + +NUM_EPOCH = 10000 +AJET_SWARM_URL = os.getenv("AJET_SWARM_URL", "http://localhost:10086") + +ENV_URL = os.getenv("APPWORLD_ENV_URL", "http://127.0.0.1:8080") +ENV_TYPE = os.getenv("APPWORLD_ENV_TYPE", "appworld") +TRAINING_SPLIT = os.getenv("APPWORLD_TRAINING_SPLIT", "train") +VALIDATION_SPLIT = os.getenv("APPWORLD_VALIDATION_SPLIT", "dev") +MAX_STEPS = int(os.getenv("APPWORLD_MAX_STEPS", "25")) + +EVAL_INTERVAL = 
int(os.getenv("APPWORLD_EVAL_INTERVAL", "10")) +EVAL_K = int(os.getenv("APPWORLD_EVAL_K", "1")) +TOTAL_TRAINING_STEPS = int(os.getenv("APPWORLD_TOTAL_TRAINING_STEPS", "200")) +RESULT_DIR = os.getenv("APPWORLD_RESULT_DIR", "./cocktail_training_new/results_appworld") +MAX_ENV_WORKER = int(os.getenv("APPWORLD_MAX_ENV_WORKER", "64")) + + +def get_appworld_tasks(split: str) -> List[Task]: + """Enumerate appworld task ids from env_service for the given split. + + The swarm client owns task generation, so we hit env_service directly + (rather than going through `EnvServiceTaskReader`) to keep the config + surface flat. + """ + env_client = EnvClient(base_url=ENV_URL) + task_id_array = env_client.get_env_profile(ENV_TYPE, split=split) + if len(task_id_array) == 0: + raise ValueError( + f"No task_id found for env_type={ENV_TYPE}, split={split}, " + f"check connection to {ENV_URL}" + ) + return [ + Task( + main_query="[not defined]", + init_messages=[], + task_id=str(task_id), + env_type=ENV_TYPE, + metadata={}, + ) + for task_id in task_id_array + ] + + +def generate_training_tasks() -> Generator[Task, None, None]: + for task in get_appworld_tasks(TRAINING_SPLIT): + yield task + + +def execute_agent(task: Task, api_baseurl_key: OpenaiBaseUrlAndApiKey): + import asyncio + from tutorial.example_appworld_swarm.appworld_swarm import ExampleAgentScopeWorkflow + workflow = ExampleAgentScopeWorkflow( + env_url=ENV_URL, + env_type=ENV_TYPE, + max_steps=MAX_STEPS, + ) + return asyncio.run(workflow.execute(task, api_baseurl_key)) + + +def main(): + + ajet_job = AgentJetJob( + base_yaml_config="tutorial/example_cocktail_rl/cocktail_rl_conf.yaml", + algorithm="grpo", + experiment_name="cocktail_rl", + max_env_worker=MAX_ENV_WORKER, + ) + + # Hand shake with remote swarm server + swarm_worker = SwarmClient(AJET_SWARM_URL) + swarm_worker.auto_sync_train_config_and_start_engine( + ajet_job, + # force_restart=True, + ) + + GRPO_N = ajet_job.num_repeat + REMOTE_BATCH_SIZE = 
ajet_job.batch_size + + os.makedirs(RESULT_DIR, exist_ok=True) + eval_log_path = os.path.join(RESULT_DIR, "eval_results.log") + val_result_path = os.path.join(RESULT_DIR, "val_results.md") + + eval_tasks = get_appworld_tasks(VALIDATION_SPLIT) + print(f"[INFO] Loaded {len(eval_tasks)} eval tasks (split={VALIDATION_SPLIT})") + + def rollout(task: Task) -> float: + # begin episode + episode_uuid, api_baseurl_key = swarm_worker.begin_episode(discard_episode_timeout=600) + # execute agent ( base_url = api_baseurl_key.base_url, api_key = api_baseurl_key.api_key ) + workflow_output = execute_agent(task, api_baseurl_key) + # report output back to swarm remote + swarm_worker.end_episode(task, episode_uuid, workflow_output) + return workflow_output.reward + + def eval_rollout(task: Task) -> float: + episode_uuid, api_baseurl_key = swarm_worker.begin_episode( + discard_episode_timeout=600, episode_type="eval" + ) + try: + workflow_output = execute_agent(task, api_baseurl_key) + return workflow_output.reward + finally: + # eval samples must NOT be fed back into the training pool + swarm_worker.abort_episode(episode_uuid) + + def run_eval(n_global_step: int): + if not eval_tasks: + return + k = EVAL_K + total_rollouts = len(eval_tasks) * k + print(f"\n[EVAL @ step {n_global_step}] {len(eval_tasks)} tasks x {k} (pass@{k})...") + per_task_rewards: List[List[float]] = [[] for _ in eval_tasks] + pbar = tqdm(total=total_rollouts, desc=f"EVAL @ step {n_global_step}") + + with ThreadPoolExecutor(max_workers=MAX_ENV_WORKER) as eval_executor: + future_to_idx = { + eval_executor.submit(eval_rollout, t): i + for i, t in enumerate(eval_tasks) + for _ in range(k) + } + for fut in as_completed(future_to_idx): + idx = future_to_idx[fut] + try: + per_task_rewards[idx].append(fut.result()) + except Exception as e: + print(f"[EVAL] future error: {e}") + pbar.update(1) + pbar.close() + + flat = [r for rs in per_task_rewards for r in rs if r is not None] + if not flat: + print(f"[EVAL @ step 
{n_global_step}] no valid rewards") + return + + avg = sum(flat) / len(flat) + std_reward = statistics.pstdev(flat) if len(flat) > 1 else 0.0 + # Full success requires raw_reward >= 1 (final_reward >= 1.5). + # Partial-credit rollouts have 0 < final_reward <= 0.5, so they must NOT + # count as passes; see EnvServiceJudge.compute_reward. + SUCCESS_THRESHOLD = 1.0 + pass1 = sum(1 for r in flat if r >= SUCCESS_THRESHOLD) / len(flat) + num_all_success_tasks = sum( + 1 + for rs in per_task_rewards + if rs and all((r is not None and r >= SUCCESS_THRESHOLD) for r in rs) + ) + num_pass_n_tasks = sum( + 1 + for rs in per_task_rewards + if any((r is not None and r >= SUCCESS_THRESHOLD) for r in rs) + ) + passk = num_pass_n_tasks / len(per_task_rewards) + summary = ( + f"[EVAL @ step {n_global_step}] mean_reward={avg:.4f} std_reward={std_reward:.4f} " + f"task_pass_rate@1={pass1*100:.2f}% task_pass_rate@{k}={passk*100:.2f}% " + f"n_tasks={len(per_task_rewards)} n_rollouts={len(flat)}" + ) + print(summary) + with open(eval_log_path, "a") as f: + f.write(summary + "\n") + with open(val_result_path, "a") as f: + f.write(f"\n## Step {n_global_step}\n") + f.write(f"- pass_n: {k}\n") + f.write(f"- total_tasks: {len(per_task_rewards)}\n") + f.write(f"- num_all_success_tasks: {num_all_success_tasks}\n") + f.write(f"- num_pass_n_tasks: {num_pass_n_tasks}\n") + f.write(f"- task_pass_rate@1: {pass1*100:.2f}%\n") + f.write(f"- task_pass_rate@{k}: {passk*100:.2f}%\n") + f.write(f"- mean_reward: {avg:.4f}\n") + f.write(f"- std_reward: {std_reward:.4f}\n") + f.write(f"- n_rollouts: {len(flat)}\n") + + # step-0 eval disabled for faster iteration + last_eval_step = 0 + # run_eval(0) # skip initial eval + + executor = PeriodicDrainThreadPoolExecutor( + workers=GRPO_N * REMOTE_BATCH_SIZE, max_parallel=64, auto_retry=True + ) + + train_log_path = os.path.join(RESULT_DIR, "train_results.log") + + n_global_step = 0 + for _ in range(NUM_EPOCH): + for task in generate_training_tasks(): + for _ in 
range(GRPO_N): + # `submit_with_periodic_drain` returns drained results only when the + # in-flight pool was actually drained on this submission. Each drain + # boundary corresponds to a fully-collected local batch -- exactly + # when this client should agree to a weight sync under + # `rollout_until_all_clients_agree_sync_weight`. + _, drained_results = executor.submit_with_periodic_drain( + fn=rollout, task=task + ) + if drained_results: + # Log batch rewards before weight sync + rewards = [r for r in drained_results if r is not None] + if rewards: + avg_reward = sum(rewards) / len(rewards) + std_reward = statistics.pstdev(rewards) if len(rewards) > 1 else 0.0 + success_rate = sum(1 for r in rewards if r >= 1.0) / len(rewards) + step = swarm_worker.get_global_step() + log_line = ( + f"[TRAIN @ step {step}] client=appworld " + f"batch_size={len(rewards)} mean_reward={avg_reward:.4f} " + f"std_reward={std_reward:.4f} success_rate={success_rate*100:.2f}%" + ) + print(log_line) + with open(train_log_path, "a") as f: + f.write(log_line + "\n") + swarm_worker.agree_sync_weight() + + n_global_step = swarm_worker.get_global_step() + if n_global_step >= last_eval_step + EVAL_INTERVAL: + run_eval(n_global_step) + last_eval_step = n_global_step + + if n_global_step >= TOTAL_TRAINING_STEPS: + break + + if n_global_step >= TOTAL_TRAINING_STEPS: + break + + print("[INFO] Training complete.") + + +if __name__ == "__main__": + main() diff --git a/tutorial/example_cocktail_rl_v2/cocktail_v2_config.py b/tutorial/example_cocktail_rl_v2/cocktail_v2_config.py new file mode 100644 index 00000000..ba0d1d3b --- /dev/null +++ b/tutorial/example_cocktail_rl_v2/cocktail_v2_config.py @@ -0,0 +1,167 @@ +# -*- coding: utf-8 -*- +""" +Single source of truth for example_cocktail_rl_v2. + +Every config value used anywhere in this tutorial -- v2 schedule knobs, engine +knobs, per-domain knobs -- lives on `CocktailV2Config`. 
There are no YAMLs, no +hardcoded constants in the runner or clients, no `.get(key, default)` fallback +patterns that could drift. To change anything, edit a default here. +""" + +from __future__ import annotations + +import math +import os +from dataclasses import dataclass, field +from typing import List + + +SCHEDULE_TYPES = ("linear", "cos", "constant") + + +# ============================ Per-domain sub-configs ============================ + +@dataclass +class AppWorldConfig: + env_url: str = "http://127.0.0.1:8080" + env_type: str = "appworld" + training_split: str = "train" + validation_split: str = "dev" + episode_timeout: int = 60 + + +@dataclass +class AimeConfig: + episode_timeout: int = 60 + # Filenames resolve under ../opencode_build_aime/data relative to this tutorial. + train_dataset_filename: str = "dapo-math-17k.parquet" + test_dataset_filenames: dict = field(default_factory=lambda: { + "AIME-2026": "aime-2026.parquet", + "DAPO-Math-Tiny-Val": "dapo-math-tiny-val.parquet", + }) + + +# ============================ Top-level config ============================ + +@dataclass +class CocktailV2Config: + """Single source of truth. Both client_0 and client_1 must agree on these + values, so the dataclass defaults ARE the canonical config. + + Schedule semantics for client_0's batch ratio: + schedule_type == "constant": ratio is always `schedule_start`. + schedule_type == "linear": linear from `schedule_start` at step 0 to + `schedule_end` at `schedule_end_step`, + then stays at `schedule_end`. + schedule_type == "cos": cosine anneal from `schedule_start` to + `schedule_end` over `schedule_end_step`, + then stays at `schedule_end`. + client_1's ratio is always 1 - client_0's ratio. 
+ """ + # ---- v2 batching / schedule ---- + total_batch_size: int = 64 + grpo_n: int = 8 + schedule_type: str = "linear" + schedule_start: float = 0.5 + schedule_end: float = 0.0 + schedule_end_step: int = 200 + + # ---- v2 client-side runtime ---- + max_env_worker: int = 64 + eval_interval: int = 10 + eval_k: int = 4 + total_training_steps: int = 200 + swarm_url: str = "http://localhost:10086" + result_dir: str = "./cocktail_results_v2" + + # ---- engine-global per-rollout knobs (read by engine + per-client agents) ---- + max_response_length: int = 20000 + max_steps: int = 25 + + # ---- engine-only knobs (consumed by build_cocktail_ajet_job) ---- + project_name: str = "cocktail_rl" + experiment_name: str = "cocktail_rl_v2" + experiment_dir: str = "auto" + model_path: str = "/mnt/data_cpfs/xielipeng.xlp/models/Qwen3-8B" + algorithm: str = "grpo" + swarm_mode: bool = True + swarm_mode_sample_collection_method: str = "rollout_until_all_clients_agree_sync_weight" + logging: str = "swanlab" + compute_madness_checklist: List[str] = field(default_factory=lambda: ["nonsense"]) + max_prompt_length: int = 3000 + max_response_length_in_one_turn: int = 12000 + max_model_len: int = 23000 + max_num_seqs: int = 64 + n_gpu: int = 8 + use_kl_loss: bool = True + use_kl_in_reward: bool = False + kl_penalty_type: str = "kl" + + # ---- engine knobs not exposed as AgentJetJob kwargs ---- + temperature: float = 0.9 + force_disable_toolcalls: bool = False + agent_madness_reward: float = -1.0 + tensor_model_parallel_size: int = 1 + multi_turn_max_sample_per_task: int = 25 + save_freq: int = 1_000_000_000 + test_freq: int = 10 + total_epochs: int = 99_999 + nnodes: int = 1 + val_pass_n: int = 4 + val_before_train: bool = False + debug_max_parallel: int = 1 + debug_first_n_tasks: int = 1 + + # ---- per-domain ---- + appworld: AppWorldConfig = field(default_factory=AppWorldConfig) + aime: AimeConfig = field(default_factory=AimeConfig) + + def __post_init__(self) -> None: + assert 
self.total_batch_size >= 1, "total_batch_size must be >= 1" + assert self.grpo_n >= 1, "grpo_n must be >= 1" + assert self.schedule_type in SCHEDULE_TYPES, \ + f"schedule_type must be one of {SCHEDULE_TYPES}, got {self.schedule_type}" + assert 0.0 <= self.schedule_start <= 1.0, "schedule_start must be in [0, 1]" + assert 0.0 <= self.schedule_end <= 1.0, "schedule_end must be in [0, 1]" + assert self.schedule_end_step >= 0, "schedule_end_step must be >= 0" + + def get_client_0_ratio(self, global_step: int) -> float: + if self.schedule_type == "constant" or self.schedule_end_step <= 0: + return self.schedule_start + if global_step >= self.schedule_end_step: + return self.schedule_end + t = global_step / self.schedule_end_step + if self.schedule_type == "linear": + return self.schedule_start + t * (self.schedule_end - self.schedule_start) + if self.schedule_type == "cos": + cos_factor = 0.5 * (1.0 + math.cos(math.pi * t)) # 1 at t=0, 0 at t=1 + return self.schedule_end + (self.schedule_start - self.schedule_end) * cos_factor + raise ValueError(f"Unknown schedule_type: {self.schedule_type}") + + def split_local_batch_sizes(self, global_step: int) -> tuple[int, int]: + """Return (client_0_local_batch_size, client_1_local_batch_size) -- the + number of distinct prompts each client should contribute this round. + Uses round() on client_0; client_1 = total - client_0. Sum == total exactly.""" + r0 = max(0.0, min(1.0, self.get_client_0_ratio(global_step))) + client_0_local_batch_size = int(round(self.total_batch_size * r0)) + client_1_local_batch_size = self.total_batch_size - client_0_local_batch_size + return client_0_local_batch_size, client_1_local_batch_size + + +def cocktail_v2_config_from_env() -> CocktailV2Config: + """Build the v2 config and apply env-var overrides. + + Currently supported env vars: + COCKTAIL_RATIO_SCHEDULE = linear | cos | constant + Override schedule_type. 
The same value MUST be exported in both
+    clients' shells, otherwise the two will compute different per-round
+    local batch sizes.
+    """
+    cfg = CocktailV2Config()
+    sched_type = os.getenv("COCKTAIL_RATIO_SCHEDULE")
+    if sched_type is not None:
+        cfg.schedule_type = sched_type
+        # Re-validate since we mutated.
+        cfg.__post_init__()
+        print(f"[INFO] env override: COCKTAIL_RATIO_SCHEDULE = {sched_type!r}")
+    return cfg
diff --git a/tutorial/example_cocktail_rl_v2/cocktail_v2_runner.py b/tutorial/example_cocktail_rl_v2/cocktail_v2_runner.py
new file mode 100644
index 00000000..75a2bf38
--- /dev/null
+++ b/tutorial/example_cocktail_rl_v2/cocktail_v2_runner.py
@@ -0,0 +1,242 @@
+# -*- coding: utf-8 -*-
+"""
+Shared base class for example_cocktail_rl_v2.
+
+Each per-domain client (AppWorld / AIME) subclasses CocktailSwarmRunner and
+implements four methods (setup_data, rollout, eval_rollout, is_success).
+The driver subclass additionally overrides `build_ajet_job()`. The follower
+inherits the default (returns None) and waits for the engine to roll.
+
+All configuration lives in `cocktail_v2_config.CocktailV2Config` -- this file
+contains zero config defaults.
+"""
+
+from __future__ import annotations
+
+import os
+import time
+import statistics
+from abc import ABC, abstractmethod
+from concurrent.futures import ThreadPoolExecutor, as_completed
+from typing import List, Optional
+
+from tqdm import tqdm
+
+from ajet.copilot.job import AgentJetJob
+from ajet.schema.task import Task
+from ajet.tuner_lib.experimental.swarm_client import SwarmClient
+from ajet.utils.thread_executors import PeriodicDrainThreadPoolExecutor
+
+from tutorial.example_cocktail_rl_v2.cocktail_v2_config import CocktailV2Config
+
+
+class CocktailSwarmRunner(ABC):
+    ROLE: str = ""  # "client_0" | "client_1"
+    IS_DRIVER: bool = False  # whether this client drives engine startup
+    CLIENT_LABEL: str = ""  # e.g. 
"appworld" | "aime", used in subdir + log lines + EPISODE_TIMEOUT: int = 60 + + def __init__(self, v2_config: CocktailV2Config): + assert self.ROLE in ("client_0", "client_1"), \ + f"subclass must set ROLE; got {self.ROLE!r}" + assert self.CLIENT_LABEL, "subclass must set CLIENT_LABEL" + + self.config = v2_config + self.swarm_worker: Optional[SwarmClient] = None + self.dataset = None # must have generate_training_tasks() method + self.eval_tasks_by_set: dict[str, list[Task]] = {} + + self.client_result_dir = os.path.join( + v2_config.result_dir, f"results_{self.CLIENT_LABEL}" + ) + os.makedirs(self.client_result_dir, exist_ok=True) + + # ---------------- to override ---------------- + + @abstractmethod + def setup_data(self) -> None: + """Populate self.dataset (with generate_training_tasks() method) and self.eval_tasks_by_set.""" + + @abstractmethod + def rollout(self, task: Task) -> float: + """Train rollout: begin_episode -> execute -> end_episode -> return reward.""" + + @abstractmethod + def eval_rollout(self, task: Task) -> float: + """Eval rollout: begin_episode(episode_type='eval') -> execute -> abort_episode.""" + + @abstractmethod + def is_success(self, reward: float) -> bool: + """Domain-specific success threshold for logging.""" + + def build_ajet_job(self) -> Optional[AgentJetJob]: + """Driver-only hook. 
Return a configured AgentJetJob; followers return None.""" + return None + + # ---------------- shared lifecycle ---------------- + + def setup(self) -> None: + self.swarm_worker = SwarmClient(self.config.swarm_url, verbose=False) + if self.IS_DRIVER: + ajet_job = self.build_ajet_job() + assert ajet_job is not None, f"{type(self).__name__}.build_ajet_job() must return AgentJetJob (IS_DRIVER=True)" + self.swarm_worker.auto_sync_train_config_and_start_engine(ajet_job) + else: + print("[INFO] Waiting for swarm server (ENGINE.ROLLING)...") + self.swarm_worker._wait_until_status_change_to(desired_status="ENGINE.ROLLING") + print("[INFO] Swarm server is ready.") + + self.setup_data() + + def run(self) -> None: + self.setup() + self.train_loop() + + # ---------------- shared training ---------------- + + def _get_local_batch_size(self, step: int) -> int: + client_0_batch, client_1_batch = self.config.split_local_batch_sizes(step) + return client_0_batch if self.ROLE == "client_0" else client_1_batch + + def train_loop(self) -> None: + assert self.swarm_worker is not None and self.dataset is not None + + train_log_path = os.path.join( + self.client_result_dir, f"train_results_{self.CLIENT_LABEL}.log" + ) + last_eval_step = 0 + + num_epochs = 10000 + for epoch in range(num_epochs): + step = self.swarm_worker.get_global_step() + local_batch_size = self._get_local_batch_size(step) + + executor = PeriodicDrainThreadPoolExecutor( + workers=local_batch_size * self.config.grpo_n, + max_parallel=self.config.max_env_worker, + auto_retry=True, + ) + + for _, task in enumerate(self.dataset.generate_training_tasks()): + for _ in range(self.config.grpo_n): + _, drained_results = executor.submit_with_periodic_drain( # ✨✨✨✨ + fn=self.rollout, task=task + ) + if drained_results: + rewards = [r for r in drained_results if r is not None] + step = self.swarm_worker.get_global_step() + if rewards: + avg_reward = sum(rewards) / len(rewards) + std_reward = statistics.pstdev(rewards) if 
len(rewards) > 1 else 0.0 + success_rate = sum(1 for r in rewards if self.is_success(r)) / len(rewards) + line = ( + f"[TRAIN @ step {step}] client={self.CLIENT_LABEL} " + f"batch_size={len(rewards)} mean_reward={avg_reward:.4f} " + f"std_reward={std_reward:.4f} success_rate={success_rate*100:.2f}%" + ) + print(line) + with open(train_log_path, "a") as f: + f.write(line + "\n") + + self.swarm_worker.agree_sync_weight() + if step >= last_eval_step + self.config.eval_interval: + self.run_eval(step) + last_eval_step = step + + if step >= self.config.total_training_steps: + break + + executor.shutdown(wait=False) + if self.swarm_worker.get_global_step() >= self.config.total_training_steps: + break + + finish_flag = os.path.join(self.client_result_dir, "finish.flag") + with open(finish_flag, "w") as f: + f.write(f"Training completed at {time.time()}\n") + print(f"[INFO] {self.CLIENT_LABEL} training complete.") + + # ---------------- shared eval ---------------- + + def run_eval(self, n_global_step: int) -> None: + if not self.eval_tasks_by_set: + return + eval_log_path = os.path.join( + self.client_result_dir, f"eval_results_{self.CLIENT_LABEL}.log" + ) + for label, eval_tasks in self.eval_tasks_by_set.items(): + self._run_eval_one(n_global_step, label, eval_tasks, eval_log_path) + + def _run_eval_one( + self, + n_global_step: int, + label: str, + eval_tasks: List[Task], + eval_log_path: str, + ) -> None: + k = self.config.eval_k + total_rollouts = len(eval_tasks) * k + print( + f"\n[EVAL @ step {n_global_step}] {self.CLIENT_LABEL}/{label}: " + f"{len(eval_tasks)} tasks x {k} (pass@{k})..." 
+ ) + per_task_rewards: List[List[float]] = [[] for _ in eval_tasks] + pbar = tqdm(total=total_rollouts, desc=f"EVAL {label} @ step {n_global_step}") + + with ThreadPoolExecutor(max_workers=self.config.max_env_worker) as eval_executor: + future_to_idx = { + eval_executor.submit(self.eval_rollout, t): i + for i, t in enumerate(eval_tasks) + for _ in range(k) + } + for fut in as_completed(future_to_idx): + idx = future_to_idx[fut] + try: + per_task_rewards[idx].append(fut.result()) + except Exception as e: + print(f"[EVAL] future error: {e}") + pbar.update(1) + pbar.close() + + flat = [r for rs in per_task_rewards for r in rs if r is not None] + if not flat: + print(f"[EVAL @ step {n_global_step}] {self.CLIENT_LABEL}/{label} no valid rewards") + return + + avg = sum(flat) / len(flat) + std = statistics.pstdev(flat) if len(flat) > 1 else 0.0 + pass1 = sum(1 for r in flat if self.is_success(r)) / len(flat) + num_all_success_tasks = sum( + 1 + for rs in per_task_rewards + if rs and all((r is not None and self.is_success(r)) for r in rs) + ) + num_pass_n_tasks = sum( + 1 + for rs in per_task_rewards + if any((r is not None and self.is_success(r)) for r in rs) + ) + passk = num_pass_n_tasks / len(per_task_rewards) + summary = ( + f"[EVAL @ step {n_global_step}] {self.CLIENT_LABEL}/{label} " + f"mean_reward={avg:.4f} std_reward={std:.4f} " + f"task_pass_rate@1={pass1*100:.2f}% task_pass_rate@{k}={passk*100:.2f}% " + f"n_tasks={len(per_task_rewards)} n_rollouts={len(flat)}" + ) + print(summary) + with open(eval_log_path, "a") as f: + f.write(summary + "\n") + + val_result_path = os.path.join( + self.client_result_dir, f"val_results_{self.CLIENT_LABEL}.md" + ) + with open(val_result_path, "a") as f: + f.write(f"\n## Step {n_global_step} ({label})\n") + f.write(f"- pass_n: {k}\n") + f.write(f"- total_tasks: {len(per_task_rewards)}\n") + f.write(f"- num_all_success_tasks: {num_all_success_tasks}\n") + f.write(f"- num_pass_n_tasks: {num_pass_n_tasks}\n") + f.write(f"- 
task_pass_rate@1: {pass1*100:.2f}%\n")
+            f.write(f"- task_pass_rate@{k}: {passk*100:.2f}%\n")
+            f.write(f"- mean_reward: {avg:.4f}\n")
+            f.write(f"- std_reward: {std:.4f}\n")
+            f.write(f"- n_rollouts: {len(flat)}\n")
diff --git a/tutorial/example_cocktail_rl_v2/readme.md b/tutorial/example_cocktail_rl_v2/readme.md
new file mode 100644
index 00000000..568dff8e
--- /dev/null
+++ b/tutorial/example_cocktail_rl_v2/readme.md
@@ -0,0 +1,15 @@
+# example_cocktail_rl_v2
+
+Cocktail RL on AppWorld + AIME with configurable per-client batch ratios and an optional dynamic schedule.
+
+```bash
+source .venv/bin/activate && ajet-swarm start
+
+# Export the SAME COCKTAIL_RATIO_SCHEDULE in both shells (linear | cos | constant).
+export COCKTAIL_RATIO_SCHEDULE=constant
+source .venv/bin/activate
+python -m tutorial.example_cocktail_rl_v2.train_appworld_as_swarm_client_0  # driver
+python -m tutorial.example_cocktail_rl_v2.train_aime_as_swarm_client_1  # follower
+```
+
+Edit `CocktailV2Config` defaults (cocktail_v2_config.py) for `total_batch_size`, `schedule_start`/`schedule_end`/`schedule_end_step`. Engine knobs live in `build_cocktail_ajet_job()` (train_appworld_as_swarm_client_0.py). Both clients must agree on these.
diff --git a/tutorial/example_cocktail_rl_v2/train_aime_as_swarm_client_1.py b/tutorial/example_cocktail_rl_v2/train_aime_as_swarm_client_1.py
new file mode 100644
index 00000000..d15cafd7
--- /dev/null
+++ b/tutorial/example_cocktail_rl_v2/train_aime_as_swarm_client_1.py
@@ -0,0 +1,140 @@
+# -*- coding: utf-8 -*-
+"""
+AIME swarm client (follower) for example_cocktail_rl_v2.
+ +python -m tutorial.example_cocktail_rl_v2.train_aime_as_swarm_client_1 +""" + +from __future__ import annotations + +import os +from dataclasses import dataclass +from typing import List + +from ajet.default_config.ajet_config_schema import AjetTaskReader, HuggingfaceDatRepo +from ajet.schema.task import Task +from ajet.task_reader import HuggingFaceTaskReader, RouterTaskReader + +from tutorial.example_cocktail_rl_v2.cocktail_v2_config import ( + CocktailV2Config, + cocktail_v2_config_from_env, +) +from tutorial.example_cocktail_rl_v2.cocktail_v2_runner import CocktailSwarmRunner +from tutorial.opencode_build_aime import download_data +from tutorial.opencode_build_aime.agent_run_v3 import execute_agent as _execute_aime_agent + + +_THIS_DIR = os.path.dirname(__file__) + + +@dataclass +class _AimeAgentConfig: + """Duck-types the subset of AgentJetJob that execute_agent reads.""" + model: str + max_response_length: int + + +def _load_eval_tasks(test_dataset: str, label: str = "") -> List[Task]: + eval_tasks: List[Task] = [] + if not os.path.exists(test_dataset): + print(f"[WARN] Eval dataset not found: {test_dataset}. 
Skipping {label or test_dataset}.") + return eval_tasks + + eval_reader = HuggingFaceTaskReader( + AjetTaskReader(huggingface_dat_repo=HuggingfaceDatRepo(dataset_path=test_dataset)) + ) + for t in eval_reader.generate_training_tasks(): + eval_tasks.append(t) + print(f"[INFO] Loaded {len(eval_tasks)} eval tasks from {label or test_dataset}") + return eval_tasks + + + + +class AimeRunner(CocktailSwarmRunner): + ROLE = "client_1" + IS_DRIVER = False + CLIENT_LABEL = "aime" + + def __init__(self, v2_config: CocktailV2Config): + super().__init__(v2_config) + am = v2_config.aime + self.EPISODE_TIMEOUT = am.episode_timeout + self.agent_config = _AimeAgentConfig( + model="dummy", + max_response_length=v2_config.max_response_length, + ) + + data_dir = os.path.join(_THIS_DIR, "..", "opencode_build_aime", "data") + self.train_dataset = os.path.join(data_dir, am.train_dataset_filename) + self.test_datasets = { + label: os.path.join(data_dir, fname) + for label, fname in am.test_dataset_filenames.items() + } + + def setup_data(self) -> None: + if not os.path.exists(self.train_dataset): + raise FileNotFoundError( + f"AIME training dataset missing: {self.train_dataset}\n" + "Please run: proxychains python -m tutorial.opencode_build_aime.download_data" + ) + + train_reader = RouterTaskReader( + reader_type="huggingface_dat_repo", + reader_config=AjetTaskReader( + huggingface_dat_repo=HuggingfaceDatRepo(dataset_path=self.train_dataset) + ), + ) + self.dataset = train_reader + + eval_downloaders = { + "AIME-2026": download_data.ensure_aime_2026, + } + for label, path in self.test_datasets.items(): + if not os.path.exists(path): + downloader = eval_downloaders.get(label) + if downloader is None: + print(f"[WARN] {label} parquet missing at {path} and no downloader registered. 
Skipping.") + continue + print(f"[INFO] {label} parquet missing, downloading...") + try: + downloader() + except Exception as e: + print(f"[WARN] Failed to download {label}: {e}") + continue + tasks = _load_eval_tasks(path, label=label) + if tasks: + self.eval_tasks_by_set[label] = tasks + + def rollout(self, task: Task) -> float: + assert self.swarm_worker is not None + episode_uuid, api_baseurl_key = self.swarm_worker.begin_episode( + discard_episode_timeout=self.EPISODE_TIMEOUT + ) + out = _execute_aime_agent(task, api_baseurl_key, self.agent_config) + self.swarm_worker.end_episode(task, episode_uuid, out) + return out.reward + + def eval_rollout(self, task: Task) -> float: + assert self.swarm_worker is not None + episode_uuid, api_baseurl_key = self.swarm_worker.begin_episode( + discard_episode_timeout=self.EPISODE_TIMEOUT, episode_type="eval" + ) + try: + out = _execute_aime_agent(task, api_baseurl_key, self.agent_config) + return out.reward + finally: + self.swarm_worker.abort_episode(episode_uuid) + + def is_success(self, reward: float) -> bool: + return reward > 0 + + +def main(): + cfg = cocktail_v2_config_from_env() + runner = AimeRunner(cfg) + runner.run() + + +if __name__ == "__main__": + main() diff --git a/tutorial/example_cocktail_rl_v2/train_appworld_as_swarm_client_0.py b/tutorial/example_cocktail_rl_v2/train_appworld_as_swarm_client_0.py new file mode 100644 index 00000000..e7d2d124 --- /dev/null +++ b/tutorial/example_cocktail_rl_v2/train_appworld_as_swarm_client_0.py @@ -0,0 +1,194 @@ +# -*- coding: utf-8 -*- +""" +AppWorld swarm client (driver) for example_cocktail_rl_v2. 
+ +python -m tutorial.example_cocktail_rl_v2.train_appworld_as_swarm_client_0 +""" + +from __future__ import annotations + +import os +import random +from typing import Iterator, List, Optional + +from ajet.copilot.job import AgentJetJob +from ajet.schema.task import Task +from ajet.utils.env_service_client.env_client_ng import EnvClient + +from tutorial.example_cocktail_rl_v2.cocktail_v2_config import ( + CocktailV2Config, + cocktail_v2_config_from_env, +) +from tutorial.example_cocktail_rl_v2.cocktail_v2_runner import CocktailSwarmRunner + + +# ---------------- Engine config (was cocktail_rl_conf.yaml) ---------------- + +def build_cocktail_ajet_job(cfg: CocktailV2Config) -> AgentJetJob: + """Construct the AgentJetJob that drives the swarm engine. + + Every value is read from `cfg`. There are no hardcoded constants in this + function -- CocktailV2Config is the single source of truth for the entire + engine config. Fields not exposed as AgentJetJob kwargs are set on + `ajet_job.config.ajet.*` after construction and shipped to the engine via + `Config.to_dict()`. + """ + ajet_job = AgentJetJob( + # base_yaml_config=None -> use ajet/default_config/ajet_swarm_default.yaml + project_name=cfg.project_name, + experiment_name=cfg.experiment_name, + experiment_dir=cfg.experiment_dir, + model=cfg.model_path, + algorithm=cfg.algorithm, + num_repeat=cfg.grpo_n, + # batch_size is ignored under rollout_until_all_clients_agree_sync_weight, + # but we mirror cfg.total_batch_size so the dumped engine config reads coherently. 
+ batch_size=cfg.total_batch_size, + swarm_mode=cfg.swarm_mode, + swarm_mode_sample_collection_method=cfg.swarm_mode_sample_collection_method, + max_env_worker=cfg.max_env_worker, + max_prompt_length=cfg.max_prompt_length, + max_response_length=cfg.max_response_length, + max_response_length_in_one_turn=cfg.max_response_length_in_one_turn, + max_model_len=cfg.max_model_len, + max_num_seqs=cfg.max_num_seqs, + compute_madness_checklist=list(cfg.compute_madness_checklist), + n_gpu=cfg.n_gpu, + logging=cfg.logging, + use_kl_loss=cfg.use_kl_loss, + use_kl_in_reward=cfg.use_kl_in_reward, + kl_penalty_type=cfg.kl_penalty_type, + total_training_steps=cfg.total_training_steps, + ) + + # Fields not exposed as AgentJetJob kwargs. + rollout = ajet_job.config.ajet.rollout + rollout.temperature = cfg.temperature + rollout.force_disable_toolcalls = cfg.force_disable_toolcalls + rollout.agent_madness_reward = cfg.agent_madness_reward + rollout.tensor_model_parallel_size = cfg.tensor_model_parallel_size + rollout.multi_turn = { + "max_sample_per_task": cfg.multi_turn_max_sample_per_task, + "max_steps": cfg.max_steps, + } + + trainer = ajet_job.config.ajet.trainer_common + trainer.save_freq = cfg.save_freq + trainer.test_freq = cfg.test_freq + trainer.total_epochs = cfg.total_epochs + trainer.nnodes = cfg.nnodes + trainer.val_pass_n = cfg.val_pass_n + trainer.val_before_train = cfg.val_before_train + + ajet_job.config.ajet.debug = { + "debug_max_parallel": cfg.debug_max_parallel, + "debug_first_n_tasks": cfg.debug_first_n_tasks, + } + + return ajet_job + + +# ---------------- AppWorld task / runner glue ---------------- + +def _get_appworld_tasks(env_url: str, env_type: str, split: str) -> List[Task]: + env_client = EnvClient(base_url=env_url) + task_id_array = env_client.get_env_profile(env_type, split=split) + if len(task_id_array) == 0: + raise ValueError( + f"No task_id found for env_type={env_type}, split={split}, " + f"check connection to {env_url}" + ) + return [ + Task( + 
main_query="[not defined]", + init_messages=[], + task_id=str(task_id), + env_type=env_type, + metadata={}, + ) + for task_id in task_id_array + ] + + +class ShuffledTaskDataset: + def __init__(self, tasks: List[Task]): + self.tasks = list(tasks) + + def generate_training_tasks(self) -> Iterator[Task]: + pool = list(self.tasks) + random.shuffle(pool) + for t in pool: + yield t + + +class AppWorldRunner(CocktailSwarmRunner): + ROLE = "client_0" + IS_DRIVER = True + CLIENT_LABEL = "appworld" + + def __init__(self, v2_config: CocktailV2Config): + super().__init__(v2_config) + ap = v2_config.appworld + self.env_url: str = ap.env_url + self.env_type: str = ap.env_type + self.training_split: str = ap.training_split + self.validation_split: str = ap.validation_split + self.max_steps: int = v2_config.max_steps + self.EPISODE_TIMEOUT = ap.episode_timeout + + def build_ajet_job(self) -> Optional[AgentJetJob]: + return build_cocktail_ajet_job(self.config) + + def setup_data(self) -> None: + train_tasks = _get_appworld_tasks(self.env_url, self.env_type, self.training_split) + print(f"[INFO] AppWorld training: {len(train_tasks)} tasks (split={self.training_split})") + self.dataset = ShuffledTaskDataset(train_tasks) + + eval_tasks = _get_appworld_tasks(self.env_url, self.env_type, self.validation_split) + print(f"[INFO] AppWorld eval: {len(eval_tasks)} tasks (split={self.validation_split})") + self.eval_tasks_by_set = {self.validation_split: eval_tasks} + + def rollout(self, task: Task) -> float: + assert self.swarm_worker is not None + episode_uuid, api_baseurl_key = self.swarm_worker.begin_episode( + discard_episode_timeout=self.EPISODE_TIMEOUT + ) + out = self._execute(task, api_baseurl_key) + self.swarm_worker.end_episode(task, episode_uuid, out) + return out.reward + + def eval_rollout(self, task: Task) -> float: + assert self.swarm_worker is not None + episode_uuid, api_baseurl_key = self.swarm_worker.begin_episode( + discard_episode_timeout=self.EPISODE_TIMEOUT, 
episode_type="eval" + ) + try: + out = self._execute(task, api_baseurl_key) + return out.reward + finally: + self.swarm_worker.abort_episode(episode_uuid) + + def is_success(self, reward: float) -> bool: + # Mirrors EnvServiceJudge partial-credit shaping: full success requires + # raw_reward >= 1, which corresponds to final_reward >= 1.0 here. + return reward >= 1.0 + + def _execute(self, task: Task, api_baseurl_key): + import asyncio + from tutorial.example_appworld_swarm.appworld_swarm import ExampleAgentScopeWorkflow + wf = ExampleAgentScopeWorkflow( + env_url=self.env_url, + env_type=self.env_type, + max_steps=self.max_steps, + ) + return asyncio.run(wf.execute(task, api_baseurl_key)) + + +def main(): + cfg = cocktail_v2_config_from_env() + runner = AppWorldRunner(cfg) + runner.run() + + +if __name__ == "__main__": + main()