add cocktail training#23

Open
binary-husky wants to merge 2 commits into main from dev

@binary-husky
Collaborator

No description provided.

… scheduling

- Implemented AppWorld and AIME swarm clients for the new cocktail RL version.
- Introduced CocktailV2Config as a single source of truth for configuration values.
- Created train_appworld_as_swarm_client_0 and train_aime_as_swarm_client_1 scripts for running the respective clients.
- Added cocktail_v2_runner to manage shared functionality between clients.
- Included readme.md for setup instructions and configuration details.
- Enhanced evaluation and logging mechanisms for better performance tracking.
@binary-husky changed the title from "merge dev" to "add cocktail training" on May 8, 2026

@gemini-code-assist gemini-code-assist Bot left a comment

Code Review

This pull request introduces a swarm-mode feature that allows rollout clients to control training stop conditions through a weight-sync agreement mechanism. It adds two new sample collection methods, infrastructure for tracking active clients on the swarm server, and updates to the overwatch UI for monitoring client status. Additionally, it includes new CLI options for process management and several tutorials for multi-domain RL (Cocktail RL). Feedback focuses on the risks of broad process termination in the autokill feature, potential monitoring breakage due to conditional pool updates, and the need for explicit error handling instead of assertions in API routes.

Comment thread ajet/swarm_cli.py
def cmd_start(args):
    """Handle the 'start' subcommand."""
    if args.autokill:
        args.kill = "ray|vllm|VLLM|python"

high

Using `python` as a keyword for autokill is very broad and could unintentionally terminate other important Python processes running on the system, not just those related to the experiment. This could lead to unexpected behavior or data loss. Consider using more specific keywords to target only the intended processes, for example by looking for processes running specific scripts related to ajet.
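
To make the concern concrete, here is a minimal sketch (not AgentJet's implementation) comparing the pattern from the diff with a narrower one. It assumes the kill string is applied as a regular-expression search over process command lines, which this excerpt does not confirm. The `select_targets` helper, the sample command lines, and the narrower pattern are all hypothetical.

```python
import re

BROAD_PATTERN = "ray|vllm|VLLM|python"        # value assigned by --autokill in the diff
NARROW_PATTERN = r"ray|vllm|VLLM|\bajet\b"    # hypothetical narrower alternative


def select_targets(pattern, command_lines):
    """Return the command lines that a regex search on `pattern` would match."""
    regex = re.compile(pattern)
    return [cmd for cmd in command_lines if regex.search(cmd)]


# Hypothetical command lines on a shared machine (illustration only).
COMMAND_LINES = [
    "python -m ajet.swarm_cli start",
    "ray::IDLE",
    "python -m vllm.entrypoints.openai.api_server",
    "python /home/alice/notebooks/train_other_model.py",
    "/usr/bin/python3 /usr/bin/unattended-upgrade",
]

broad = select_targets(BROAD_PATTERN, COMMAND_LINES)
narrow = select_targets(NARROW_PATTERN, COMMAND_LINES)

# The broad pattern also selects the two unrelated Python processes.
unrelated_hits = [cmd for cmd in broad if cmd not in narrow]
```

Under these assumptions the broad pattern selects every listed process, while the narrower one leaves the unrelated notebook job and the system updater alone. Short keywords such as `ray` can still match unrelated command lines by substring, so matching on full script or module names would likely be safer still.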

Comment on lines +494 to +497
if accept_client_control:
    instruction = http_update_rollout_pool_information_and_fetch_instruction(self.config, pool_info)
    if instruction is not None:
        latest_swarm_client_instructions["swarm_clients"] = instruction

high

The call to `http_update_rollout_pool_information_and_fetch_instruction` is now inside an `if accept_client_control:` block. This means that for sample collection methods that are not client-controlled (e.g., `rollout_until_finish_enough_tasks`), the rollout pool information will no longer be updated on the server. This will likely break monitoring tools like the swarm overwatch that rely on this information. The call to update pool information should probably happen regardless of the collection method.

Suggested change
- if accept_client_control:
-     instruction = http_update_rollout_pool_information_and_fetch_instruction(self.config, pool_info)
-     if instruction is not None:
-         latest_swarm_client_instructions["swarm_clients"] = instruction
+ instruction = http_update_rollout_pool_information_and_fetch_instruction(self.config, pool_info)
+ if accept_client_control:
+     if instruction is not None:
+         latest_swarm_client_instructions["swarm_clients"] = instruction
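
The intent of this suggestion can be sketched in isolation: report pool information on every call, but only act on the returned instruction when client control is accepted. The function `report_and_maybe_apply`, the `fake_update` stand-in for the HTTP helper, and the dictionary shapes below are assumptions for illustration, not code from the PR.

```python
def report_and_maybe_apply(pool_info, accept_client_control, update_fn, instructions):
    """Always report pool info; honor the returned instruction only under client control."""
    # The server receives pool information in every collection mode,
    # so monitoring keeps working.
    instruction = update_fn(pool_info)
    if accept_client_control and instruction is not None:
        instructions["swarm_clients"] = instruction
    return instructions


# Hypothetical stand-in for the HTTP helper: records what the server would receive.
received = []


def fake_update(pool_info):
    received.append(pool_info)
    return {"stop": False}


# Not client-controlled: pool info is still reported, the instruction is ignored.
state_a = report_and_maybe_apply({"running": 3}, False, fake_update, {})
# Client-controlled: pool info is reported and the instruction is stored.
state_b = report_and_maybe_apply({"running": 4}, True, fake_update, {})
```

In this sketch both calls reach the stand-in server, which is the property the overwatch-style monitoring depends on.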

Comment on lines +187 to +189
- You should create seperate tmux session for each agentjet swarm servers and each agentjet swarm clients
- When debugging, please do not restart agentjet swarm servers frequently, that waste too much time
- When you really having difficulty for clearing GPU memory, run `ajet --autokill` to automatically kill all python and ray processes (however, I still recommend using this as a last resort).

medium

There are a few typos and grammatical errors in this section that could be corrected for clarity.

Suggested change
- - You should create seperate tmux session for each agentjet swarm servers and each agentjet swarm clients
- - When debugging, please do not restart agentjet swarm servers frequently, that waste too much time
- - When you really having difficulty for clearing GPU memory, run `ajet --autokill` to automatically kill all python and ray processes (however, I still recommend using this as a last resort).
+ - You should create separate tmux sessions for each AgentJet swarm server and each AgentJet swarm client.
+ - When debugging, please do not restart AgentJet swarm servers frequently, as that wastes a lot of time.
+ - If you are having difficulty clearing GPU memory, run `ajet --autokill` to automatically kill all Python and Ray processes (however, I still recommend using this as a last resort).

Comment on lines +829 to +834
assert pool_info.sample_collection_method in AGREE_SYNC_WEIGHT_VALID_METHODS, (
    f"agree_sync_weight is only valid when "
    f"ajet.swarm_mode_sample_collection_method is one of "
    f"{AGREE_SYNC_WEIGHT_VALID_METHODS}, but the trainer is currently "
    f"running with '{pool_info.sample_collection_method}'."
)

medium

Using `assert` for request validation can lead to an unhandled exception and a 500 Internal Server Error if the condition is not met. It's better to handle this validation explicitly and return a proper HTTP error, like a 400 Bad Request, with a clear error message for the client.

Suggested change
- assert pool_info.sample_collection_method in AGREE_SYNC_WEIGHT_VALID_METHODS, (
-     f"agree_sync_weight is only valid when "
-     f"ajet.swarm_mode_sample_collection_method is one of "
-     f"{AGREE_SYNC_WEIGHT_VALID_METHODS}, but the trainer is currently "
-     f"running with '{pool_info.sample_collection_method}'."
- )
+ if pool_info.sample_collection_method not in AGREE_SYNC_WEIGHT_VALID_METHODS:
+     return BoolResponse(
+         success=False,
+         failure_reason=(
+             f"agree_sync_weight is only valid when "
+             f"ajet.swarm_mode_sample_collection_method is one of "
+             f"{AGREE_SYNC_WEIGHT_VALID_METHODS}, but the trainer is currently "
+             f"running with '{pool_info.sample_collection_method}'."
+         )
+     )
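
A self-contained sketch of the explicit-validation pattern follows. The `BoolResponse` dataclass, the placeholder method names in `AGREE_SYNC_WEIGHT_VALID_METHODS`, and the simplified `agree_sync_weight` signature are stand-ins assumed for illustration; the real definitions live in the AgentJet code base and may differ.

```python
from dataclasses import dataclass
from typing import Optional

# Hypothetical placeholder values; the real valid-method names are not shown here.
AGREE_SYNC_WEIGHT_VALID_METHODS = ("method_a", "method_b")


@dataclass
class BoolResponse:
    """Stand-in for the response model used by the route (fields assumed)."""
    success: bool
    failure_reason: Optional[str] = None


def agree_sync_weight(sample_collection_method: str) -> BoolResponse:
    """Validate the request explicitly instead of relying on `assert`."""
    if sample_collection_method not in AGREE_SYNC_WEIGHT_VALID_METHODS:
        return BoolResponse(
            success=False,
            failure_reason=(
                f"agree_sync_weight is only valid when the sample collection "
                f"method is one of {AGREE_SYNC_WEIGHT_VALID_METHODS}, but the "
                f"trainer is currently running with '{sample_collection_method}'."
            ),
        )
    return BoolResponse(success=True)


ok = agree_sync_weight("method_a")
rejected = agree_sync_weight("rollout_until_finish_enough_tasks")
```

Besides giving the client a readable failure reason, an explicit check keeps working when Python runs with the `-O` flag, which skips `assert` statements entirely.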


for _, task in enumerate(self.dataset.generate_training_tasks()):
    for _ in range(self.config.grpo_n):
        _, drained_results = executor.submit_with_periodic_drain( # ✨✨✨✨

medium

The comment contains emojis, which are generally considered unprofessional in source code. Please remove them.

Suggested change
- _, drained_results = executor.submit_with_periodic_drain( # ✨✨✨✨
+ _, drained_results = executor.submit_with_periodic_drain(
