Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -172,3 +172,5 @@ swarmexp
swarmlog
werewolves_swarm
.claude
tensorboard_log
tutorial/**/*.json
2 changes: 1 addition & 1 deletion ajet/backbone/main_trinity.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ def patched_trainer_get_actor(cls, config: Config):
Trainer.get_actor = classmethod(patched_trainer_get_actor)

if ajet_config.ajet.enable_interchange_server:
from ajet.tuner_lib.experimental.as_oai_model_server import start_interchange_server
from ajet.tuner_lib.experimental.oai_model_server import start_interchange_server
start_interchange_server(ajet_config)


Expand Down
2 changes: 1 addition & 1 deletion ajet/backbone/main_verl.py
Original file line number Diff line number Diff line change
Expand Up @@ -251,7 +251,7 @@ def run(self, config):
from ajet.backbone.trainer_verl import AjetRayPPOTrainer

if config.ajet.enable_interchange_server:
from ajet.tuner_lib.experimental.as_oai_model_server import start_interchange_server
from ajet.tuner_lib.experimental.oai_model_server import start_interchange_server
start_interchange_server(config)

# Initialize the PPO trainer.
Expand Down
2 changes: 1 addition & 1 deletion ajet/backbone/main_vllm.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,7 +187,7 @@ def main(config):
# atexit.register(lambda: print("Process exiting, performing cleanup..."))

if config.ajet.enable_interchange_server:
from ajet.tuner_lib.experimental.as_oai_model_server import start_interchange_server
from ajet.tuner_lib.experimental.oai_model_server import start_interchange_server
start_interchange_server(config)
if config.ajet.enable_swarm_mode:
from ajet.tuner_lib.experimental.interchange_utils import http_change_engine_status
Expand Down
2 changes: 1 addition & 1 deletion ajet/backbone/trainer_verl.py
Original file line number Diff line number Diff line change
Expand Up @@ -859,7 +859,7 @@ def fit(self): # noqa: C901

# # when enabled oai request interchange, we need to clear the cache from time to time
# if self.config.ajet.enable_interchange_server:
# from ajet.tuner_lib.experimental.as_oai_model_server import ensure_dat_interchange_server_cache_clear
# from ajet.tuner_lib.experimental.oai_model_server import ensure_dat_interchange_server_cache_clear
# ensure_dat_interchange_server_cache_clear()

if is_last_step:
Expand Down
17 changes: 10 additions & 7 deletions ajet/context_tracker/multiagent_tracking.py
Original file line number Diff line number Diff line change
Expand Up @@ -603,24 +603,27 @@ def check_context_token_num_safe(
add_generation_prompt=True,
tokenize=False,
)
length = len(self.tokenizer(prompt_text, return_tensors="pt", padding=False)["input_ids"][0]) # type: ignore
max_response_length = self.config.ajet.rollout.max_response_length_in_one_turn
prompt_token_length = len(self.tokenizer(prompt_text, return_tensors="pt", padding=False)["input_ids"][0]) # type: ignore
max_response_length_in_one_turn = self.config.ajet.rollout.max_response_length_in_one_turn
max_model_len: int = self.config.ajet.rollout.max_model_len
max_seq_length: int = max_model_len - max_response_length
if length < max_seq_length:
max_seq_length: int = max_model_len - max_response_length_in_one_turn
# prompt_token_length: the prompt_token_length of current all previous context
# max_seq_length: max_model_len - max_response_length_in_one_turn
if prompt_token_length < max_seq_length:
token_overflow = False
else:
token_overflow = True
if self.should_interrupt_soft_fn():
ret = (False, token_overflow, "externally_interrupted")
elif self.already_mad_flag and self.config.ajet.rollout.agent_madness_termination:
ret = (False, token_overflow, "already_mad")
elif length < max_seq_length:
elif prompt_token_length < max_seq_length:
ret = (
True,
token_overflow,
f"safe[{length} < {max_model_len} - {max_response_length}]",
f"safe[{prompt_token_length} < {max_model_len} - {max_response_length_in_one_turn}]",
)
else:
ret = (False, token_overflow, "token_overflow")
ret = (False, token_overflow,
f"token_overflow(prompt_token_length.{prompt_token_length}>=max_model_len.{max_model_len}-max_response_length_in_one_turn.{max_response_length_in_one_turn})")
return ret
4 changes: 2 additions & 2 deletions ajet/context_tracker/single_agent_tracking.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,9 +185,9 @@ def compute_step_level_reward(
def to_role_content(self, ext_msg_array: List[ExtendedMessage]) -> List:
result = []
for ext_msg in ext_msg_array:
d = {
d: dict = {
"role": ext_msg.role,
"content": ext_msg.content_for_future,
"content": ext_msg.content_for_compare,
}
if ext_msg.tool_calls:
d.update({"tool_calls": ext_msg.tool_calls})
Expand Down
12 changes: 6 additions & 6 deletions ajet/context_tracker/timeline_merging/timeline_merging.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,8 @@ def is_timeline_mergeable(
for i in range(len(target_timeline)):
if timeline_compare_level == "text":
same = (
source_timeline[i].content_for_future
== target_timeline[i].content_for_future
source_timeline[i].content_for_compare
== target_timeline[i].content_for_compare
)
elif timeline_compare_level == "token":
same = source_timeline[i].token_arr == target_timeline[i].token_arr
Expand Down Expand Up @@ -52,12 +52,12 @@ def is_timeline_mergeable(
# all_msg_match = False
# for i in range(len(target_timeline)):
# d = {}
# d["source"] = source_timeline[i].content_for_future
# d["target"] = target_timeline[i].content_for_future
# d["source"] = source_timeline[i].content_for_compare
# d["target"] = target_timeline[i].content_for_compare
# if timeline_compare_level == "text":
# same = (
# source_timeline[i].content_for_future
# == target_timeline[i].content_for_future
# source_timeline[i].content_for_compare
# == target_timeline[i].content_for_compare
# )
# elif timeline_compare_level == "token":
# same = source_timeline[i].token_arr == target_timeline[i].token_arr
Expand Down
65 changes: 51 additions & 14 deletions ajet/copilot/job.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,24 +10,15 @@
import os
import time
import yaml
import tempfile

from types import SimpleNamespace
from typing import Any, Callable, Union, cast
from loguru import logger
from ajet.default_config.ajet_default import Config
from ajet.utils.config_utils import (
expand_ajet_hierarchical_config,
prepare_experiment_config,
read_ajet_hierarchical_config,
)
from ajet.utils.dynamic_import import cls_to_path
from ajet.utils.launch_utils import (
execute_training_process,
check_avail_gpu,
get_backbone_target,
setup_environment_vars,
)


def override_current_yaml_value_if_given(override_value, current_value):
Expand All @@ -48,10 +39,27 @@ def _get_nested_attr(obj, attr_path: str):
return obj

class AgentJetJob:
"""
arg: base_yaml_config + **kwargs (yaml config, then override with kwargs)
arg: base_yaml_config (yaml config)
arg: **kwargs (yaml config, then override with kwargs)
"""Programmatic interface for configuring and launching AgentJet training jobs.

Args:
base_yaml_config: Path to base YAML configuration file. If None, uses default config (at ./ajet/default_config/ajet_ts_default.yaml).
experiment_dir: Directory where experiment outputs will be saved.
project_name: Name of the project for organizing experiments.
experiment_name: Unique name for this specific experiment run.
logging: Logging backend to use, e.g. "swanlab" or "tensorboard".
n_gpu: Number of GPUs to use per node for training.
model: Path or identifier of the model to train.
algorithm: Advantage estimator algorithm (e.g., 'gae', 'vtrace').
num_repeat: Tell the swarm server how many repeated samples it should expect for the same task (same meaning the task_id is identical).
batch_size: Training batch size for the model (the watermark at which the buffer pool is emptied and the LLM weights are updated).
swarm_mode: Whether to enable swarm mode for distributed sample collection.
swarm_mode_sample_collection_method: Method for collecting samples in swarm mode.
max_env_worker: An estimate of how many episodes will run in parallel (all swarm clients combined).
backbone: Training backbone framework (e.g., 'verl').
max_prompt_length: Maximum token length for input prompts (token length before the first llm-generated token).
max_response_length: Maximum token length for model responses (token length after the first llm-generated token).
max_model_len: Maximum total token length (prompt + response) the model can handle (bigger => more GPU memory).
mini_batch_num: Number of mini-batches to split training batch into (how many mini steps, i.e. how many times the `optimizer.step` should be executed, per big train batch).
"""

def __init__(
Expand All @@ -60,6 +68,7 @@ def __init__(
experiment_dir: str | None = None,
project_name: str | None = None,
experiment_name: str | None = None,
logging: str | None = None,
n_gpu: int | None = None,
model: str | None = None,
algorithm: str | None = None,
Expand All @@ -69,20 +78,34 @@ def __init__(
swarm_mode_sample_collection_method: str | None = None,
max_env_worker: int | None = None,
backbone: str | None = None,
max_prompt_length: int | None = None,
max_response_length: int | None = None,
max_response_length_in_one_turn: int | None = None,
max_model_len: int | None = None,
mini_batch_num: int | None = None,
) -> None:

if base_yaml_config is None:
base_yaml_config = os.path.abspath(os.path.join(os.path.dirname(__file__), '..', "default_config/ajet_ts_default.yaml"))
else:
logger.warning(f"Reading config from {base_yaml_config}.")
time.sleep(1)
if not os.path.exists(base_yaml_config):
raise ValueError(f"Configuration yaml is absent! {base_yaml_config}")

# Validate: max_prompt_length, max_response_length, max_model_len must all be None or all be non-None
length_params = [max_prompt_length, max_response_length, max_model_len, max_response_length_in_one_turn]
if not (all(p is None for p in length_params) or all(p is not None for p in length_params)):
raise ValueError("(`max_prompt_length`, `max_response_length`, `max_model_len`, `max_response_length_in_one_turn`) must all be None or all be non-None")

self.config_as_dict: dict = self.build_job_from_yaml(base_yaml_config)
self.config = Config.update_from_dict_recursive(Config(), self.config_as_dict)

self.base_yaml_config: str = cast(str, base_yaml_config) # currently may be None, but will be set later
self.experiment_dir: str = cast(str, experiment_dir)
self.project_name: str = cast(str, project_name)
self.experiment_name: str = cast(str, experiment_name)
self.logging: str = cast(str, logging)
self.n_gpu: int = cast(int, n_gpu)
self.model: str = cast(str, model)
self.algorithm: str = cast(str, algorithm)
Expand All @@ -92,12 +115,19 @@ def __init__(
self.swarm_mode_sample_collection_method: str = cast(str, swarm_mode_sample_collection_method)
self.max_env_worker: int = cast(int, max_env_worker)
self.backbone: str = cast(str, backbone)
self.max_prompt_length: int = cast(int, max_prompt_length)
self.max_response_length_in_one_turn: int = cast(int, max_response_length_in_one_turn)
self.max_response_length: int = cast(int, max_response_length)
self.max_model_len: int = cast(int, max_model_len)
self.mini_batch_num: int = cast(int, mini_batch_num)

# see `ajet/default_config/ajet_ts_default.yaml`
overrides = {
# left: [yaml key navigation] right: [AgentJetJob self attr]
"ajet.experiment_dir": "experiment_dir",
"ajet.project_name": "project_name",
"ajet.experiment_name": "experiment_name",
"ajet.trainer_common.logger": "logging",
"ajet.model.path": "model",
"ajet.trainer_common.n_gpus_per_node": "n_gpu",
"ajet.trainer_common.algorithm.adv_estimator": "algorithm",
Expand All @@ -107,6 +137,11 @@ def __init__(
"ajet.swarm_mode_sample_collection_method": "swarm_mode_sample_collection_method",
"ajet.rollout.max_env_worker": "max_env_worker",
"ajet.backbone": "backbone",
"ajet.data.max_prompt_length": "max_prompt_length",
"ajet.data.max_response_length": "max_response_length",
"ajet.rollout.max_response_length_in_one_turn": "max_response_length_in_one_turn",
"ajet.rollout.max_model_len": "max_model_len",
"ajet.trainer_common.mini_batch_num": "mini_batch_num",
}

# if any value given in kwargs, override the corresponding value in config
Expand All @@ -127,6 +162,9 @@ def __init__(
# >> e.g. self.model = new_model
setattr(self, override_val, new_val)


assert self.max_prompt_length + self.max_response_length <= self.max_model_len, "illegal token length"
assert self.max_response_length_in_one_turn <= self.max_response_length
if self.backbone == "trinity":
raise NotImplementedError("Trinity backbone is not yet supported in AgentJetJob.")

Expand Down Expand Up @@ -184,4 +222,3 @@ def set_data(
)

return self

Loading
Loading