-
Notifications
You must be signed in to change notification settings - Fork 19
add cocktail training #23
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -185,3 +185,5 @@ research_*.json | |
| research_*.jsonc | ||
| daemon_logs* | ||
| paper | ||
| val_results.md | ||
| cocktail_vs_separate* | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -5,6 +5,7 @@ | |
| from dotenv import load_dotenv | ||
| from loguru import logger | ||
|
|
||
| from ajet.utils.cleaner import fast_kill_by_keyword_bash | ||
| from ajet.utils.config_utils import prepare_experiment_config | ||
| from ajet.utils.launch_utils import ( | ||
| dict_to_namespace, | ||
|
|
@@ -41,6 +42,21 @@ def start_swarm_server(env, config, port): | |
|
|
||
| def cmd_start(args): | ||
| """Handle the 'start' subcommand.""" | ||
| if args.autokill: | ||
| args.kill = "ray|vllm|VLLM|python" | ||
|
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Using |
||
|
|
||
| if args.kill: | ||
| logger.info(f"Killing processes matching keywords: {args.kill}") | ||
| for keyword in args.kill.split("|"): | ||
| logger.info(f"Killing processes matching keyword: {keyword}") | ||
| killed_pids = fast_kill_by_keyword_bash(keyword) | ||
| if killed_pids: | ||
| logger.success( | ||
| f"Successfully killed processes with PIDs: {killed_pids}" | ||
| ) | ||
| else: | ||
| logger.warning(f"No processes found matching keyword: {keyword}") | ||
|
|
||
| # Use default config if not provided | ||
| exp_base_dir = args.exp_dir or DEFAULT_DIR | ||
| if not args.conf: | ||
|
|
@@ -126,6 +142,19 @@ def main(): | |
| required=False, | ||
| help="Debug tags; enables Ray post-mortem and DEBUG_TAGS env", | ||
| ) | ||
| parser_start.add_argument( | ||
| "--kill", | ||
| type=str, | ||
| default="", | ||
| required=False, | ||
| help="list of keywords for killing processes", | ||
| ) | ||
| parser_start.add_argument( | ||
| "--autokill", | ||
| action="store_true", | ||
| default=False, | ||
| help="Kill system processes (ray + vllm + python) that may block the current experiment", | ||
| ) | ||
|
|
||
| parser_start.set_defaults(func=cmd_start) | ||
|
|
||
|
|
||
| Original file line number | Diff line number | Diff line change | ||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -26,8 +26,9 @@ | |||||||||||||||||
| from ajet.context_tracker.single_agent_tracking import SingleAgentContextTracker | ||||||||||||||||||
| from ajet.tuner_lib.experimental.interchange_utils import ( | ||||||||||||||||||
| http_change_engine_status, | ||||||||||||||||||
| http_update_rollout_pool_information, | ||||||||||||||||||
| http_update_rollout_pool_information_and_fetch_instruction, | ||||||||||||||||||
| CurrentBatchRolloutPoolInformation, | ||||||||||||||||||
| SwarmClientInstruction, | ||||||||||||||||||
| ) | ||||||||||||||||||
|
|
||||||||||||||||||
|
|
||||||||||||||||||
|
|
@@ -292,6 +293,14 @@ def rollout_swarm( # noqa: C901 | |||||||||||||||||
| completed_task_id_map_ct: Dict[str, List[SingleAgentContextTracker]] = IterationSafeDict() | ||||||||||||||||||
| executor_lock = threading.Lock() | ||||||||||||||||||
|
|
||||||||||||||||||
| accept_client_control = ("client" in self.config.ajet.swarm_mode_sample_collection_method) | ||||||||||||||||||
| if accept_client_control: | ||||||||||||||||||
| # Latest active-client / agreed-sync-weight snapshot from the swarm server. Refreshed on every pool-information update; | ||||||||||||||||||
| # consumed by the `rollout_until_*_agree_sync_weight` stop conditions. | ||||||||||||||||||
| latest_swarm_client_instructions: Dict[str, SwarmClientInstruction | None] = {"swarm_clients": None} | ||||||||||||||||||
| else: | ||||||||||||||||||
| latest_swarm_client_instructions = None | ||||||||||||||||||
|
|
||||||||||||||||||
| # count tasks to see whether we have reached the finish line for the next weight update | ||||||||||||||||||
| def count_tasks(completed_task_id_map_ct): | ||||||||||||||||||
| total_completed_episodes = 0 | ||||||||||||||||||
|
|
@@ -340,6 +349,20 @@ def enough_finished_task_stop_condition(completed_task_id_map_ct) -> bool: | |||||||||||||||||
| completed_task_id_map_ct.clear() | ||||||||||||||||||
| return (total_completed_tasks >= n_batch_task) | ||||||||||||||||||
|
|
||||||||||||||||||
| def any_client_agree_sync_weight_stop_condition(completed_task_id_map_ct) -> bool: | ||||||||||||||||||
| # ajet.swarm_mode_sample_collection_method == "rollout_until_any_client_agree_sync_weight" | ||||||||||||||||||
| instr = latest_swarm_client_instructions["swarm_clients"] | ||||||||||||||||||
| if instr is None: | ||||||||||||||||||
| return False | ||||||||||||||||||
| return any(c.allowed_sync_weight for c in instr.active_clients) | ||||||||||||||||||
|
|
||||||||||||||||||
| def all_clients_agree_sync_weight_stop_condition(completed_task_id_map_ct) -> bool: | ||||||||||||||||||
| # ajet.swarm_mode_sample_collection_method == "rollout_until_all_clients_agree_sync_weight" | ||||||||||||||||||
| instr = latest_swarm_client_instructions["swarm_clients"] | ||||||||||||||||||
| if instr is None or not instr.active_clients: | ||||||||||||||||||
| return False | ||||||||||||||||||
| return all(c.allowed_sync_weight for c in instr.active_clients) | ||||||||||||||||||
|
|
||||||||||||||||||
| def enough_non_dummy_task_stop_condition(completed_task_id_map_ct) -> bool: | ||||||||||||||||||
| # ajet.swarm_mode_sample_collection_method == "rollout_until_finish_enough_non_dummy_tasks" | ||||||||||||||||||
| counts = count_tasks(completed_task_id_map_ct) | ||||||||||||||||||
|
|
@@ -372,6 +395,10 @@ def enough_non_dummy_task_stop_condition(completed_task_id_map_ct) -> bool: | |||||||||||||||||
| stop_condition = enough_finished_task_stop_condition | ||||||||||||||||||
| elif self.config.ajet.swarm_mode_sample_collection_method == "rollout_until_finish_enough_non_dummy_tasks": | ||||||||||||||||||
| stop_condition = enough_non_dummy_task_stop_condition | ||||||||||||||||||
| elif self.config.ajet.swarm_mode_sample_collection_method == "rollout_until_any_client_agree_sync_weight": | ||||||||||||||||||
| stop_condition = any_client_agree_sync_weight_stop_condition | ||||||||||||||||||
| elif self.config.ajet.swarm_mode_sample_collection_method == "rollout_until_all_clients_agree_sync_weight": | ||||||||||||||||||
| stop_condition = all_clients_agree_sync_weight_stop_condition | ||||||||||||||||||
| else: | ||||||||||||||||||
| logger.error(f"Invalid swarm_mode_sample_collection_method: {self.config.ajet.swarm_mode_sample_collection_method}, fallback to default method: rollout_until_finish_enough_tasks") | ||||||||||||||||||
| stop_condition = enough_finished_task_stop_condition | ||||||||||||||||||
|
|
@@ -442,9 +469,16 @@ def update_rollout_result_array_preview(observation_window, completed_task_id_ma | |||||||||||||||||
| buffer += f"Total completed tasks: {counts['total_completed_tasks']} (target {n_batch_task})\n" | ||||||||||||||||||
| buffer += f"Total completed non-dummy tasks: {counts['total_completed_non_dummy_tasks']} (target {n_batch_task})\n" | ||||||||||||||||||
| buffer += f"Current stop condition: {self.config.ajet.swarm_mode_sample_collection_method}\n" | ||||||||||||||||||
| if accept_client_control: | ||||||||||||||||||
| sc_inst = latest_swarm_client_instructions["swarm_clients"] | ||||||||||||||||||
| if sc_inst is not None: | ||||||||||||||||||
| n_active = len(sc_inst.active_clients) | ||||||||||||||||||
| n_agreed = sum(1 for c in sc_inst.active_clients if c.allowed_sync_weight) | ||||||||||||||||||
| buffer += f"Active clients: {n_active} (agreed: {n_agreed})\n" | ||||||||||||||||||
| observation_window["info"][-1] = buffer | ||||||||||||||||||
|
|
||||||||||||||||||
| # Update rollout pool information via API | ||||||||||||||||||
| # Update rollout pool information via API and pull the latest | ||||||||||||||||||
| # active-client / agreed-sync-weight instruction from the server. | ||||||||||||||||||
| pool_info = CurrentBatchRolloutPoolInformation( | ||||||||||||||||||
| sample_collection_method=self.config.ajet.swarm_mode_sample_collection_method, | ||||||||||||||||||
| completed_episodes=counts['total_completed_episodes'], | ||||||||||||||||||
|
|
@@ -457,7 +491,10 @@ def update_rollout_result_array_preview(observation_window, completed_task_id_ma | |||||||||||||||||
| completed_tasks_details=completed_tasks_details, | ||||||||||||||||||
| completed_tasks_rewards=completed_tasks_rewards, | ||||||||||||||||||
| ) | ||||||||||||||||||
| http_update_rollout_pool_information(self.config, pool_info) | ||||||||||||||||||
| if accept_client_control: | ||||||||||||||||||
| instruction = http_update_rollout_pool_information_and_fetch_instruction(self.config, pool_info) | ||||||||||||||||||
| if instruction is not None: | ||||||||||||||||||
| latest_swarm_client_instructions["swarm_clients"] = instruction | ||||||||||||||||||
|
Comment on lines
+494
to
+497
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. The call to
Suggested change
|
||||||||||||||||||
| return | ||||||||||||||||||
|
|
||||||||||||||||||
| update_rollout_result_array_preview(observation_window, completed_task_id_map_ct) | ||||||||||||||||||
|
|
||||||||||||||||||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
There are a few typos and grammatical errors in this section that could be corrected for clarity.