2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -185,3 +185,5 @@ research_*.json
research_*.jsonc
daemon_logs*
paper
val_results.md
cocktail_vs_separate*
8 changes: 8 additions & 0 deletions ajet/copilot/job.py
Expand Up @@ -67,6 +67,14 @@ class AgentJetJob:
whose ``num_repeat`` episodes do *not* all share the same reward. Tasks with uniform
reward (e.g. all 0 or all 1) produce zero advantage under GRPO and are skipped —
useful when the dataset contains many too-easy or too-hard prompts.
- "rollout_until_any_client_agree_sync_weight": defer the stop decision to the swarm
clients themselves. Stops as soon as **any** active swarm client invokes
``SwarmClient.agree_sync_weight()``. A client is "active" once it has successfully
``end_episode``'d at least one rewarded (non-abort) episode since the last weight
sync, and falls off the active list after 10 minutes of no chat-completion or
``begin_episode`` activity.
- "rollout_until_all_clients_agree_sync_weight": like the above, but stops only when
**every** active swarm client has agreed (and there is at least one active client).
max_env_worker: an estimate of how many episodes will run in parallel (all swarm clients combined).
backbone: Training backbone framework (e.g., 'verl').
max_prompt_length: Maximum token length for input prompts (token length before the first llm-generated token, default 3000).
Expand Down
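The docstring above defines when a swarm client counts as "active": it must have `end_episode`'d at least one rewarded (non-abort) episode since the last weight sync, and it expires after 10 minutes without chat-completion or `begin_episode` traffic. A minimal sketch of that bookkeeping, assuming hypothetical names (`ClientRegistry`, `mark_rewarded_episode`, etc. are illustrative, not AgentJet's API):

```python
import time

# Illustrative sketch of the "active client" rule described in the docstring.
# All names here are hypothetical; only the rule itself comes from the source.
ACTIVE_TIMEOUT_S = 10 * 60  # 10 minutes of silence drops a client from the active list


class ClientRegistry:
    def __init__(self) -> None:
        self._rewarded_since_sync: set[str] = set()
        self._last_activity: dict[str, float] = {}

    def mark_activity(self, client_id: str) -> None:
        # Called on any chat-completion or begin_episode traffic.
        self._last_activity[client_id] = time.monotonic()

    def mark_rewarded_episode(self, client_id: str) -> None:
        # Called when the client end_episode()'s a rewarded (non-abort) episode.
        self._rewarded_since_sync.add(client_id)
        self.mark_activity(client_id)

    def on_weight_sync(self) -> None:
        # A weight sync resets the "rewarded since last sync" requirement.
        self._rewarded_since_sync.clear()

    def active_clients(self) -> list[str]:
        now = time.monotonic()
        return [
            cid
            for cid in self._rewarded_since_sync
            if now - self._last_activity.get(cid, 0.0) < ACTIVE_TIMEOUT_S
        ]
```

Note how both conditions must hold at once: a client with recent traffic but no rewarded episode since the last sync is still not "active".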
8 changes: 8 additions & 0 deletions ajet/copilot/monitor-with-tmux/SKILL.md
Expand Up @@ -180,3 +180,11 @@ $ python3 /tmp/tmux_wait.py ajet_session 240 && tmux capture-pane -t ajet_sessio
tmux kill-session -t ajet_session

```


## For AgentJet Swarm

- You should create seperate tmux session for each agentjet swarm servers and each agentjet swarm clients
- When debugging, please do not restart agentjet swarm servers frequently, that waste too much time
- When you really having difficulty for clearing GPU memory, run `ajet --autokill` to automatically kill all python and ray processes (however, I still recommend using this as a last resort).
Comment on lines +187 to +189
medium

There are a few typos and grammatical errors in this section that could be corrected for clarity.

Suggested change
- You should create seperate tmux session for each agentjet swarm servers and each agentjet swarm clients
- When debugging, please do not restart agentjet swarm servers frequently, that waste too much time
- When you really having difficulty for clearing GPU memory, run `ajet --autokill` to automatically kill all python and ray processes (however, I still recommend using this as a last resort).
- You should create separate tmux sessions for each AgentJet swarm server and each AgentJet swarm client.
- When debugging, please do not restart AgentJet swarm servers frequently, as that wastes a lot of time.
- If you are having difficulty clearing GPU memory, run `ajet --autokill` to automatically kill all Python and Ray processes (however, I still recommend using this as a last resort).

- For AgentJet, always use tmux session names that start with `ajet-*`
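The session-naming rule above ("ajet-" prefix, one session per swarm server and per swarm client) can be sketched as a small helper. This is a hypothetical illustration, not part of AgentJet; `build_tmux_cmd` only constructs the command list and does not spawn tmux:

```python
# Hypothetical helpers following the skill's naming rule; not AgentJet code.
def ajet_session_name(role: str, index: int) -> str:
    # One session per server/client, always prefixed with "ajet-".
    assert role in ("server", "client")
    return f"ajet-{role}-{index}"


def build_tmux_cmd(session_name: str) -> list[str]:
    # -d: detached, -s: session name; pass to subprocess.run to actually launch.
    return ["tmux", "new-session", "-d", "-s", session_name]
```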
1 change: 1 addition & 0 deletions ajet/default_config/ajet_config_schema.py
Expand Up @@ -40,6 +40,7 @@ class AjetModel:
class AjetData:
max_prompt_length: int = 3000
max_response_length: int = 15000
# Note that this value is ignored when swarm_mode_sample_collection_method="rollout_until_all_clients_agree_sync_weight"
train_batch_size: int = 32


Expand Down
7 changes: 7 additions & 0 deletions ajet/default_config/ajet_default.yaml
Expand Up @@ -16,6 +16,7 @@ ajet:
# max number of tokens for response
max_response_length: 15000
# how many tasks per training batch
# Note that this value is ignored when swarm_mode_sample_collection_method="rollout_until_all_clients_agree_sync_weight"
train_batch_size: 32
# [Hint]: The final number of samples per update will be: N_{sample} = (data.train_batch_size * rollout.num_repeat * rollout.multi_turn.expected_steps)

Expand Down Expand Up @@ -345,6 +346,12 @@ ajet:
# "rollout_until_finish_enough_non_dummy_tasks":
# AgentJet will identify the **task_id** of each episode, and stop when it has collected [>= ajet.data.train_batch_size] unique & FINISHED & NON-DUMMY **task_id**.
# (Hint: a **task_id** is considered "NON-DUMMY" if at least one of the **episodes** of that **task_id** has a **different** reward value.)
# "rollout_until_any_client_agree_sync_weight":
# AgentJet defers the stop decision to swarm clients: stop as soon as ANY active swarm client has called `SwarmClient.agree_sync_weight()`.
# (Hint: a swarm client becomes "active" once it has successfully `end_episode`'d a rewarded (non-abort) episode since the last weight sync,
# and falls off the active list if it does no chat-completion / begin_episode for 10 minutes.)
# "rollout_until_all_clients_agree_sync_weight":
# Like the above, but stop only when EVERY active swarm client has agreed (and there is at least one active client).
swarm_mode_sample_collection_method: "rollout_until_finish_enough_tasks"
swarm_mode_sample_collection_max_cached_episodes: 9999

Expand Down
7 changes: 7 additions & 0 deletions ajet/default_config/ajet_swarm_default.yaml
Expand Up @@ -53,6 +53,12 @@ ajet:
# "rollout_until_finish_enough_non_dummy_tasks":
# AgentJet will identify the **task_id** of each episode, and stop when it has collected [>= ajet.data.train_batch_size] unique & FINISHED & NON-DUMMY **task_id**.
# (Hint: a **task_id** is considered "NON-DUMMY" if at least one of the **episodes** of that **task_id** has a **different** reward value.)
# "rollout_until_any_client_agree_sync_weight":
# AgentJet defers the stop decision to swarm clients: stop as soon as ANY active swarm client has called `SwarmClient.agree_sync_weight()`.
# (Hint: a swarm client becomes "active" once it has successfully `end_episode`'d a rewarded (non-abort) episode since the last weight sync,
# and falls off the active list if it does no chat-completion / begin_episode for 10 minutes.)
# "rollout_until_all_clients_agree_sync_weight":
# Like the above, but stop only when EVERY active swarm client has agreed (and there is at least one active client).
swarm_mode_sample_collection_method: "rollout_until_finish_enough_tasks"

data:
Expand All @@ -61,6 +67,7 @@ ajet:
# max number of tokens for response
max_response_length: 15000
# how many tasks per training batch
# Note that this value is ignored when swarm_mode_sample_collection_method="rollout_until_all_clients_agree_sync_weight"
train_batch_size: 32
# [Hint]: The final number of samples per update will be: N_{sample} = (data.train_batch_size * rollout.num_repeat * rollout.multi_turn.expected_steps)

Expand Down
29 changes: 29 additions & 0 deletions ajet/swarm_cli.py
Expand Up @@ -5,6 +5,7 @@
from dotenv import load_dotenv
from loguru import logger

from ajet.utils.cleaner import fast_kill_by_keyword_bash
from ajet.utils.config_utils import prepare_experiment_config
from ajet.utils.launch_utils import (
dict_to_namespace,
Expand Down Expand Up @@ -41,6 +42,21 @@ def start_swarm_server(env, config, port):

def cmd_start(args):
"""Handle the 'start' subcommand."""
if args.autokill:
args.kill = "ray|vllm|VLLM|python"
high

Using python as a keyword for autokill is very broad and could unintentionally terminate other important Python processes running on the system, not just those related to the experiment. This could lead to unexpected behavior or data loss. Consider using more specific keywords to target only the intended processes, for example by looking for processes running specific scripts related to ajet.


if args.kill:
logger.info(f"Killing processes matching keywords: {args.kill}")
for keyword in args.kill.split("|"):
logger.info(f"Killing processes matching keyword: {keyword}")
killed_pids = fast_kill_by_keyword_bash(keyword)
if killed_pids:
logger.success(
f"Successfully killed processes with PIDs: {killed_pids}"
)
else:
logger.warning(f"No processes found matching keyword: {keyword}")

# Use default config if not provided
exp_base_dir = args.exp_dir or DEFAULT_DIR
if not args.conf:
Expand Down Expand Up @@ -126,6 +142,19 @@ def main():
required=False,
help="Debug tags; enables Ray post-mortem and DEBUG_TAGS env",
)
parser_start.add_argument(
"--kill",
type=str,
default="",
required=False,
help="list of keywords for killing processes",
)
parser_start.add_argument(
"--autokill",
action="store_true",
default=False,
help="Kill system processes (ray + vllm + python) that may block the current experiment",
)

parser_start.set_defaults(func=cmd_start)

Expand Down
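The `--kill` flag above takes a `|`-separated keyword list and kills matches one keyword at a time, with `--autokill` expanding to `"ray|vllm|VLLM|python"`. A side-effect-free sketch of that flow, where `kill_fn` stands in for the real `fast_kill_by_keyword_bash` (which actually terminates matching processes):

```python
from typing import Callable

# Sketch of the --kill / --autokill dispatch above. kill_fn is injected so the
# example stays side-effect-free; the real code calls fast_kill_by_keyword_bash.
AUTOKILL_KEYWORDS = "ray|vllm|VLLM|python"


def kill_by_keywords(keywords: str, kill_fn: Callable[[str], list[int]]) -> dict[str, list[int]]:
    """Split the '|'-separated keyword list and kill matches per keyword."""
    results: dict[str, list[int]] = {}
    for keyword in keywords.split("|"):
        # An empty PID list means no process matched this keyword.
        results[keyword] = kill_fn(keyword)
    return results
```

As the review comment notes, a keyword as broad as `python` matches far more than the experiment's own processes, so the autokill list is best treated as a last resort.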
6 changes: 3 additions & 3 deletions ajet/task_rollout/async_llm_bridge.py
Expand Up @@ -122,7 +122,7 @@ async def llm_chat_verl(
):

parsed_tool_calls = self.tool_parser.extract_tool_calls(decoded_text, None) # type: ignore
parsed_tool_calls = parsed_tool_calls.model_dump()
parsed_tool_calls = parsed_tool_calls.model_dump(mode='json')

model_called = parsed_tool_calls["tools_called"]
if model_called:
Expand Down Expand Up @@ -155,7 +155,7 @@ async def llm_chat_verl(
"completion_tokens": len(token_array), # type: ignore
"total_tokens": len(prompt_token_ids) + len(token_array), # type: ignore
}
# from ajet import bp; bp("DECODE")

return {
"role": "assistant",
"request_id": request_id,
Expand Down Expand Up @@ -327,7 +327,7 @@ async def chat_completion_request(
episode_uuid: str,
):
from openai.types.chat.chat_completion import ChatCompletion
req_as_dict = req.model_dump()
req_as_dict = req.model_dump(mode='json')

# infer + process with context tracker
llm_output = await self.run_infer(
Expand Down
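The diff swaps `model_dump()` for `model_dump(mode='json')` in two places. In Pydantic v2, `mode="json"` coerces non-JSON-native values (datetimes, UUIDs, enums, ...) into JSON-safe primitives, while the default `"python"` mode keeps the original Python objects. A minimal illustration with a hypothetical model (not AgentJet's request type):

```python
from datetime import datetime

from pydantic import BaseModel


class ChatRequest(BaseModel):
    # Hypothetical model for illustration only.
    created: datetime
    prompt: str


req = ChatRequest(created=datetime(2024, 1, 1, 12, 0), prompt="hi")

python_dump = req.model_dump()           # "created" stays a datetime object
json_dump = req.model_dump(mode="json")  # "created" becomes an ISO-8601 string
```

This matters when the dumped dict is later serialized or sent over HTTP: the `"python"`-mode dict can fail `json.dumps`, while the `"json"`-mode dict is safe by construction.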
43 changes: 40 additions & 3 deletions ajet/task_rollout/native_parallel_worker.py
Expand Up @@ -26,8 +26,9 @@
from ajet.context_tracker.single_agent_tracking import SingleAgentContextTracker
from ajet.tuner_lib.experimental.interchange_utils import (
http_change_engine_status,
http_update_rollout_pool_information,
http_update_rollout_pool_information_and_fetch_instruction,
CurrentBatchRolloutPoolInformation,
SwarmClientInstruction,
)


Expand Down Expand Up @@ -292,6 +293,14 @@ def rollout_swarm( # noqa: C901
completed_task_id_map_ct: Dict[str, List[SingleAgentContextTracker]] = IterationSafeDict()
executor_lock = threading.Lock()

accept_client_control = ("client" in self.config.ajet.swarm_mode_sample_collection_method)
if accept_client_control:
# Latest active-client / agreed-sync-weight snapshot from the swarm server. Refreshed on every pool-information update;
# consumed by the `rollout_until_*_agree_sync_weight` stop conditions.
latest_swarm_client_instructions: Dict[str, SwarmClientInstruction | None] = {"swarm_clients": None}
else:
latest_swarm_client_instructions = None

# count tasks to see whether we have reach the finish line for next weight update
def count_tasks(completed_task_id_map_ct):
total_completed_episodes = 0
Expand Down Expand Up @@ -340,6 +349,20 @@ def enough_finished_task_stop_condition(completed_task_id_map_ct) -> bool:
completed_task_id_map_ct.clear()
return (total_completed_tasks >= n_batch_task)

def any_client_agree_sync_weight_stop_condition(completed_task_id_map_ct) -> bool:
# ajet.swarm_mode_sample_collection_method == "rollout_until_any_client_agree_sync_weight"
instr = latest_swarm_client_instructions["swarm_clients"]
if instr is None:
return False
return any(c.allowed_sync_weight for c in instr.active_clients)

def all_clients_agree_sync_weight_stop_condition(completed_task_id_map_ct) -> bool:
# ajet.swarm_mode_sample_collection_method == "rollout_until_all_clients_agree_sync_weight"
instr = latest_swarm_client_instructions["swarm_clients"]
if instr is None or not instr.active_clients:
return False
return all(c.allowed_sync_weight for c in instr.active_clients)

def enough_non_dummy_task_stop_condition(completed_task_id_map_ct) -> bool:
# ajet.swarm_mode_sample_collection_method == "rollout_until_finish_enough_non_dummy_tasks"
counts = count_tasks(completed_task_id_map_ct)
Expand Down Expand Up @@ -372,6 +395,10 @@ def enough_non_dummy_task_stop_condition(completed_task_id_map_ct) -> bool:
stop_condition = enough_finished_task_stop_condition
elif self.config.ajet.swarm_mode_sample_collection_method == "rollout_until_finish_enough_non_dummy_tasks":
stop_condition = enough_non_dummy_task_stop_condition
elif self.config.ajet.swarm_mode_sample_collection_method == "rollout_until_any_client_agree_sync_weight":
stop_condition = any_client_agree_sync_weight_stop_condition
elif self.config.ajet.swarm_mode_sample_collection_method == "rollout_until_all_clients_agree_sync_weight":
stop_condition = all_clients_agree_sync_weight_stop_condition
else:
logger.error(f"Invalid swarm_mode_sample_collection_method: {self.config.ajet.swarm_mode_sample_collection_method}, fallback to default method: rollout_until_finish_enough_tasks")
stop_condition = enough_finished_task_stop_condition
Expand Down Expand Up @@ -442,9 +469,16 @@ def update_rollout_result_array_preview(observation_window, completed_task_id_ma
buffer += f"Total completed tasks: {counts['total_completed_tasks']} (target {n_batch_task})\n"
buffer += f"Total completed non-dummy tasks: {counts['total_completed_non_dummy_tasks']} (target {n_batch_task})\n"
buffer += f"Current stop condition: {self.config.ajet.swarm_mode_sample_collection_method}\n"
if accept_client_control:
sc_inst = latest_swarm_client_instructions["swarm_clients"]
if sc_inst is not None:
n_active = len(sc_inst.active_clients)
n_agreed = sum(1 for c in sc_inst.active_clients if c.allowed_sync_weight)
buffer += f"Active clients: {n_active} (agreed: {n_agreed})\n"
observation_window["info"][-1] = buffer

# Update rollout pool information via API
# Update rollout pool information via API and pull the latest
# active-client / agreed-sync-weight instruction from the server.
pool_info = CurrentBatchRolloutPoolInformation(
sample_collection_method=self.config.ajet.swarm_mode_sample_collection_method,
completed_episodes=counts['total_completed_episodes'],
Expand All @@ -457,7 +491,10 @@ def update_rollout_result_array_preview(observation_window, completed_task_id_ma
completed_tasks_details=completed_tasks_details,
completed_tasks_rewards=completed_tasks_rewards,
)
http_update_rollout_pool_information(self.config, pool_info)
if accept_client_control:
instruction = http_update_rollout_pool_information_and_fetch_instruction(self.config, pool_info)
if instruction is not None:
latest_swarm_client_instructions["swarm_clients"] = instruction
Comment on lines +494 to +497

high

The call to http_update_rollout_pool_information_and_fetch_instruction is now inside an if accept_client_control: block. This means that for sample collection methods that are not client-controlled (e.g., rollout_until_finish_enough_tasks), the rollout pool information will no longer be updated on the server. This will likely break monitoring tools like the swarm overwatch that rely on this information. The call to update pool information should probably happen regardless of the collection method.

Suggested change
if accept_client_control:
instruction = http_update_rollout_pool_information_and_fetch_instruction(self.config, pool_info)
if instruction is not None:
latest_swarm_client_instructions["swarm_clients"] = instruction
instruction = http_update_rollout_pool_information_and_fetch_instruction(self.config, pool_info)
if accept_client_control:
if instruction is not None:
latest_swarm_client_instructions["swarm_clients"] = instruction

return

update_rollout_result_array_preview(observation_window, completed_task_id_map_ct)
Expand Down
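The two client-driven stop conditions added in this file reduce to `any()`/`all()` over the active-client snapshot, with one asymmetry: the all-clients variant must also refuse to stop when the active list is empty. A standalone sketch mirroring that logic, where `SwarmClientState` is an illustrative stand-in for the real instruction objects:

```python
from dataclasses import dataclass


@dataclass
class SwarmClientState:
    # Illustrative stand-in for the server's per-client instruction entry.
    client_id: str
    allowed_sync_weight: bool


def any_client_agrees(active_clients: "list[SwarmClientState] | None") -> bool:
    # No snapshot received from the server yet -> keep rolling out.
    if active_clients is None:
        return False
    return any(c.allowed_sync_weight for c in active_clients)


def all_clients_agree(active_clients: "list[SwarmClientState] | None") -> bool:
    # Extra guard: an empty (or missing) active list must NOT trigger a stop,
    # since all() over an empty iterable would vacuously return True.
    if not active_clients:
        return False
    return all(c.allowed_sync_weight for c in active_clients)
```

Without the emptiness guard, `all([])` is vacuously true and the trainer would sync weights the moment every client went inactive.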