add cocktail training #23
… scheduling

- Implemented AppWorld and AIME swarm clients for the new cocktail RL version.
- Introduced CocktailV2Config as a single source of truth for configuration values.
- Created train_appworld_as_swarm_client_0 and train_aime_as_swarm_client_1 scripts for running the respective clients.
- Added cocktail_v2_runner to manage shared functionality between clients.
- Included readme.md for setup instructions and configuration details.
- Enhanced evaluation and logging mechanisms for better performance tracking.
Code Review
This pull request introduces a swarm-mode feature that allows rollout clients to control training stop conditions through a weight-sync agreement mechanism. It adds two new sample collection methods, infrastructure for tracking active clients on the swarm server, and updates to the overwatch UI for monitoring client status. Additionally, it includes new CLI options for process management and several tutorials for multi-domain RL (Cocktail RL). Feedback focuses on the risks of broad process termination in the autokill feature, potential monitoring breakage due to conditional pool updates, and the need for explicit error handling instead of assertions in API routes.
```python
def cmd_start(args):
    """Handle the 'start' subcommand."""
    if args.autokill:
        args.kill = "ray|vllm|VLLM|python"
```
Using `python` as a keyword for autokill is very broad and could unintentionally terminate other important Python processes running on the system, not just those related to the experiment. This could lead to unexpected behavior or data loss. Consider using more specific keywords that target only the intended processes, for example by matching processes that run specific ajet-related scripts.
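As a rough illustration of what more specific matching could look like, the sketch below filters a process table by command line instead of interpreter name. The pattern list and function names here are assumptions for illustration, not the actual ajet implementation:

```python
# Hypothetical sketch: match processes by what they run, not by the bare
# interpreter name, so unrelated Python processes survive an autokill pass.
# The patterns below (e.g. "ajet", "swarm_client") are illustrative assumptions.
EXPERIMENT_PATTERNS = ("ray", "vllm", "ajet", "swarm_client")

def match_experiment_processes(proc_table, patterns=EXPERIMENT_PATTERNS):
    """Given (pid, cmdline) pairs (e.g. parsed from `ps -eo pid,args`),
    return only the PIDs whose command line matches an experiment pattern."""
    return [pid for pid, cmd in proc_table
            if any(p in cmd.lower() for p in patterns)]
```

Fed the output of `ps -eo pid,args`, this would select only processes whose command line mentions an experiment script, leaving e.g. an unrelated Jupyter server untouched.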
```python
if accept_client_control:
    instruction = http_update_rollout_pool_information_and_fetch_instruction(self.config, pool_info)
    if instruction is not None:
        latest_swarm_client_instructions["swarm_clients"] = instruction
```
The call to `http_update_rollout_pool_information_and_fetch_instruction` is now inside an `if accept_client_control:` block. This means that for sample collection methods that are not client-controlled (e.g., `rollout_until_finish_enough_tasks`), the rollout pool information will no longer be updated on the server, which will likely break monitoring tools like the swarm overwatch that rely on it. The pool-information update should probably happen regardless of the collection method.
Suggested change:

```diff
-if accept_client_control:
-    instruction = http_update_rollout_pool_information_and_fetch_instruction(self.config, pool_info)
-    if instruction is not None:
-        latest_swarm_client_instructions["swarm_clients"] = instruction
+instruction = http_update_rollout_pool_information_and_fetch_instruction(self.config, pool_info)
+if accept_client_control:
+    if instruction is not None:
+        latest_swarm_client_instructions["swarm_clients"] = instruction
```
```markdown
- You should create seperate tmux session for each agentjet swarm servers and each agentjet swarm clients
- When debugging, please do not restart agentjet swarm servers frequently, that waste too much time
- When you really having difficulty for clearing GPU memory, run `ajet --autokill` to automatically kill all python and ray processes (however, I still recommend using this as a last resort).
```
There are a few typos and grammatical errors in this section that could be corrected for clarity.
Suggested change:

```diff
-- You should create seperate tmux session for each agentjet swarm servers and each agentjet swarm clients
-- When debugging, please do not restart agentjet swarm servers frequently, that waste too much time
-- When you really having difficulty for clearing GPU memory, run `ajet --autokill` to automatically kill all python and ray processes (however, I still recommend using this as a last resort).
+- You should create separate tmux sessions for each AgentJet swarm server and each AgentJet swarm client.
+- When debugging, please do not restart AgentJet swarm servers frequently, as that wastes a lot of time.
+- If you are having difficulty clearing GPU memory, run `ajet --autokill` to automatically kill all Python and Ray processes (however, I still recommend using this as a last resort).
```
```python
assert pool_info.sample_collection_method in AGREE_SYNC_WEIGHT_VALID_METHODS, (
    f"agree_sync_weight is only valid when "
    f"ajet.swarm_mode_sample_collection_method is one of "
    f"{AGREE_SYNC_WEIGHT_VALID_METHODS}, but the trainer is currently "
    f"running with '{pool_info.sample_collection_method}'."
)
```
Using `assert` for request validation can lead to an unhandled exception and a 500 Internal Server Error when the condition is not met. It is better to handle this validation explicitly and return a proper HTTP error, such as a 400 Bad Request, with a clear error message for the client.
Suggested change:

```diff
-assert pool_info.sample_collection_method in AGREE_SYNC_WEIGHT_VALID_METHODS, (
-    f"agree_sync_weight is only valid when "
-    f"ajet.swarm_mode_sample_collection_method is one of "
-    f"{AGREE_SYNC_WEIGHT_VALID_METHODS}, but the trainer is currently "
-    f"running with '{pool_info.sample_collection_method}'."
-)
+if pool_info.sample_collection_method not in AGREE_SYNC_WEIGHT_VALID_METHODS:
+    return BoolResponse(
+        success=False,
+        failure_reason=(
+            f"agree_sync_weight is only valid when "
+            f"ajet.swarm_mode_sample_collection_method is one of "
+            f"{AGREE_SYNC_WEIGHT_VALID_METHODS}, but the trainer is currently "
+            f"running with '{pool_info.sample_collection_method}'."
+        ),
+    )
```
|
|
```python
for _, task in enumerate(self.dataset.generate_training_tasks()):
    for _ in range(self.config.grpo_n):
        _, drained_results = executor.submit_with_periodic_drain(  # ✨✨✨✨
```
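The excerpt above submits `grpo_n` rollouts per task through `submit_with_periodic_drain`. The real executor's API is not shown in this diff; as a rough illustration of the pattern, the sketch below (all names and the `max_in_flight` parameter are assumptions, not the actual ajet implementation) submits work while periodically draining completed futures so results flow back without waiting for the whole batch:

```python
from concurrent.futures import FIRST_COMPLETED, ThreadPoolExecutor, wait

# Hypothetical sketch of a submit-with-periodic-drain loop: whenever the
# number of in-flight futures reaches a cap, block until at least one
# completes and collect its result, then keep submitting.
def submit_with_periodic_drain(executor, fn, args_iter, max_in_flight=4):
    in_flight, drained = set(), []
    for args in args_iter:
        in_flight.add(executor.submit(fn, args))
        if len(in_flight) >= max_in_flight:
            # Drain at least one finished future before submitting more.
            done, in_flight = wait(in_flight, return_when=FIRST_COMPLETED)
            drained.extend(f.result() for f in done)
    # Final drain: wait for everything still in flight.
    done, _ = wait(in_flight)
    drained.extend(f.result() for f in done)
    return drained
```

With `ThreadPoolExecutor(max_workers=4)` and `fn=lambda x: x * x` over `range(8)`, this returns all eight squares, in completion order rather than submission order.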