Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 25 additions & 0 deletions src/art/local/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -1467,6 +1467,31 @@ async def _experimental_fork_checkpoint(

shutil.copytree(source_checkpoint_dir, dest_checkpoint_dir)

# Make the fork effective for already-created local services. The
# checkpoint copy alone updates disk, but Unsloth may already have a
# cached trainer and a running vLLM server pointed at the fresh step-0
# adapter.
service = await self._get_service(cast(TrainableModel, model))
if hasattr(service, "_state") and "_state" in service.__dict__:
del service.__dict__["_state"]
if verbose:
print("Invalidated service _state cache for forked checkpoint")
service._forked_checkpoint_dir = dest_checkpoint_dir # type: ignore[attr-defined]

server_started = bool(getattr(service, "_vllm_process", None)) or bool(
getattr(service, "_server_task", None)
)
register_lora = getattr(service, "register_lora_for_step", None)
if server_started and callable(register_lora):
await register_lora(selected_step, dest_checkpoint_dir)
if verbose:
print(
f"Registered forked checkpoint {model.name}@{selected_step} "
"with running inference service"
)
elif hasattr(service, "_latest_step"):
service._latest_step = selected_step # type: ignore[attr-defined]

if verbose:
print(
f"Successfully forked checkpoint from {from_model} (step {selected_step}) to {model.name}"
Expand Down
122 changes: 109 additions & 13 deletions src/art/pipeline_trainer/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,13 @@ def __init__(
adam_params: object | None = None,
packed_sequence_length: int | None = None,
max_steps: int | None = None,
# KL-penalized advantage adjustment
kl_penalty_coef: float = 0.0,
kl_penalty_reference_step: int | None = None,
kl_ref_adapter_path: str | None = None,
kl_window_size: int | None = None,
kl_window_base_step: int = 0,
kl_window_base_adapter_path: str | None = None,
# Discard handling
discard_queue_multiplier: int = 100,
# Status output
Expand All @@ -90,6 +97,7 @@ def __init__(
eval_every_n_steps: int = 20,
eval_at_start: bool = True,
save_checkpoint: bool = True,
save_checkpoint_artifact: bool = False,
# Resumption
resume: bool = True,
) -> None:
Expand All @@ -109,10 +117,20 @@ def __init__(
raise ValueError("eval_every_n_steps must be >= 0")
if max_steps is not None and max_steps < 0:
raise ValueError("max_steps must be >= 0")
if kl_penalty_coef < 0:
raise ValueError("kl_penalty_coef must be >= 0")
if kl_penalty_reference_step is not None and kl_penalty_reference_step < 0:
raise ValueError("kl_penalty_reference_step must be >= 0")
if kl_window_size is not None and kl_window_size < 0:
raise ValueError("kl_window_size must be >= 0")
if kl_window_base_step < 0:
raise ValueError("kl_window_base_step must be >= 0")
if log_interval_seconds <= 0:
raise ValueError("log_interval_seconds must be > 0")
if discard_queue_multiplier <= 0:
raise ValueError("discard_queue_multiplier must be > 0")
if save_checkpoint_artifact and not save_checkpoint:
raise ValueError("save_checkpoint_artifact=True requires save_checkpoint=True")
self.model = model
self.backend = backend
self.rollout_fn = rollout_fn
Expand All @@ -132,10 +150,17 @@ def __init__(
self.adam_params = adam_params
self.packed_sequence_length = packed_sequence_length
self.max_steps = max_steps
self.kl_penalty_coef = kl_penalty_coef
self.kl_penalty_reference_step = kl_penalty_reference_step
self.kl_ref_adapter_path = kl_ref_adapter_path
self.kl_window_size = kl_window_size
self.kl_window_base_step = kl_window_base_step
self.kl_window_base_adapter_path = kl_window_base_adapter_path
self._status_log_interval_seconds = log_interval_seconds
self.eval_every_n_steps = eval_every_n_steps
self.eval_at_start = eval_at_start
self.save_checkpoint = save_checkpoint
self.save_checkpoint_artifact = save_checkpoint_artifact
self.resume = resume
self.discard_queue_multiplier = discard_queue_multiplier
self._discard_queue: list[TrajectoryGroup] = []
Expand Down Expand Up @@ -374,24 +399,32 @@ async def _rollout_worker(self, worker_id: int) -> None:
token = self.model.activate_metrics_context("train")
rollout_started = time.monotonic()
try:
group = await self.rollout_fn(self.model, scenario, self.config)
result = await self.rollout_fn(self.model, scenario, self.config)
finally:
token.var.reset(token)
rollout_wall_s = time.monotonic() - rollout_started
if not isinstance(group, TrajectoryGroup):
groups = result if isinstance(result, list) else [result]
if not groups or not all(
isinstance(group, TrajectoryGroup) for group in groups
):
errored = True
continue
self._apply_scenario_metadata(group, scenario)
self._apply_policy_versions(
group,
initial_version=initial_version,
final_version=self.state.policy_version,
)
if self.state.done:
break
queue_wait_s = await self._put_output_group(group)
group.metadata[_ROLLOUT_WALL_TIME_KEY] = rollout_wall_s
group.metadata[_ACTOR_IDLE_TIME_KEY] = actor_idle_s + queue_wait_s
rollout_wall_per_group = rollout_wall_s / len(groups)
actor_idle_per_group = actor_idle_s / len(groups)
for group in groups:
self._apply_scenario_metadata(group, scenario)
self._apply_policy_versions(
group,
initial_version=initial_version,
final_version=self.state.policy_version,
)
if self.state.done:
break
queue_wait_s = await self._put_output_group(group)
group.metadata[_ROLLOUT_WALL_TIME_KEY] = rollout_wall_per_group
group.metadata[_ACTOR_IDLE_TIME_KEY] = (
actor_idle_per_group + queue_wait_s
)
except asyncio.CancelledError:
raise
except Exception as exc:
Expand Down Expand Up @@ -464,11 +497,24 @@ async def _training_stage(self) -> None:
}
if self.packed_sequence_length is not None:
train_kwargs["packed_sequence_length"] = self.packed_sequence_length
train_kwargs.update(
self._backend_kl_train_kwargs(current_step=current_step)
)
result = await self.backend.train(
self.model,
batch,
**train_kwargs,
)
checkpoint_path = getattr(result, "checkpoint_path", None)
if (
should_checkpoint
and self.save_checkpoint_artifact
and checkpoint_path is not None
):
self._save_checkpoint_artifact(
checkpoint_path=checkpoint_path,
step=result.step,
)
except Exception:
self._status.note_training_end()
raise
Expand Down Expand Up @@ -810,6 +856,53 @@ def _should_eval_step(self, step: int) -> bool:
return False
return (step - self.state.last_eval_step) >= self.eval_every_n_steps

def _backend_kl_train_kwargs(self, *, current_step: int) -> dict[str, Any]:
if self.kl_penalty_coef <= 0:
return {}

kwargs: dict[str, Any] = {"kl_penalty_coef": self.kl_penalty_coef}
if self.kl_ref_adapter_path is not None:
kwargs["kl_ref_adapter_path"] = self.kl_ref_adapter_path
return kwargs

if self.kl_penalty_reference_step is not None:
kwargs["kl_penalty_reference_step"] = self.kl_penalty_reference_step
return kwargs

if self.kl_window_size is None:
return kwargs

if self.kl_window_size == 0:
if self.kl_window_base_adapter_path is not None:
kwargs["kl_ref_adapter_path"] = self.kl_window_base_adapter_path
return kwargs

target_step = current_step - self.kl_window_size
if target_step <= self.kl_window_base_step:
reference_step = self.kl_window_base_step
elif self.eval_every_n_steps <= 0:
reference_step = target_step
else:
window_steps = (target_step - self.kl_window_base_step) // (
self.eval_every_n_steps
)
reference_step = (
self.kl_window_base_step + window_steps * self.eval_every_n_steps
)
kwargs["kl_penalty_reference_step"] = reference_step
return kwargs

def _save_checkpoint_artifact(self, *, checkpoint_path: str, step: int) -> None:
    """Publish the checkpoint at ``checkpoint_path`` for ``step`` to W&B.

    Delegates to ``deploy_wandb`` with ``local-rl`` provenance; the import
    is deferred so the W&B dependency is only loaded when artifacts are
    actually saved.
    """
    from art.utils.deployment import WandbDeploymentConfig, deploy_wandb

    deployment_config = WandbDeploymentConfig(provenance=["local-rl"])
    deploy_wandb(
        model=self.model,
        checkpoint_path=checkpoint_path,
        step=step,
        config=deployment_config,
        verbose=True,
    )

def _read_pipeline_state(self) -> dict[str, Any]:
    """Return this pipeline's slice of persisted model state, or ``{}``."""
    raw = self.model.read_state()
    if not raw:
        return {}
    return raw.get(PIPELINE_STATE_KEY, {})
Expand All @@ -829,6 +922,9 @@ def _is_scalar_metadata(value: object) -> bool:

async def _put_output_group(self, group: TrajectoryGroup) -> float:
assert self._output_queue is not None
if group.metadata and group.metadata.get("skip_training"):
self._status.note_zero_variance_discarded(1)
return 0.0
queue_wait_started = time.monotonic()
while not self.state.done:
try:
Expand Down
3 changes: 2 additions & 1 deletion src/art/pipeline_trainer/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,8 @@


RolloutFn = Callable[
[art.TrainableModel, ScenarioT, ConfigT], Awaitable[TrajectoryGroup]
[art.TrainableModel, ScenarioT, ConfigT],
Awaitable[TrajectoryGroup | list[TrajectoryGroup]],
]

SingleRolloutFn = Callable[
Expand Down
13 changes: 13 additions & 0 deletions src/art/unsloth/service.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,7 @@ class UnslothService:
output_dir: str
_is_sleeping: bool = False
_latest_step: int = 0
_forked_checkpoint_dir: str | None = None
_lora_id_counter: int = 1 # Start from 1 since 0 is reserved
# Dedicated mode subprocess state
_vllm_process: subprocess.Popen | None = field(default=None, repr=False) # type: ignore[type-arg]
Expand Down Expand Up @@ -571,6 +572,14 @@ async def register_lora_for_step(self, step: int, checkpoint_dir: str) -> None:
self._latest_step = step
await llm.resume_generation()

async def _load_forked_checkpoint_if_needed(self) -> None:
forked_dir = self._forked_checkpoint_dir
if forked_dir is None:
return

self._forked_checkpoint_dir = None
await self._state.load_lora_adapter(forked_dir)

async def train(
self,
disk_packed_tensors: DiskPackedTensors,
Expand Down Expand Up @@ -598,6 +607,8 @@ async def _train_dedicated(
verbose: bool = False,
) -> AsyncIterator[dict[str, float]]:
"""Train in dedicated mode — no sleep/wake, vLLM keeps running on separate GPU."""
await self._load_forked_checkpoint_if_needed()

async for result in run_unsloth_rl_training(
self._state,
disk_packed_tensors=disk_packed_tensors,
Expand Down Expand Up @@ -663,6 +674,8 @@ async def _train_shared(
# Reload training model to GPU (after vLLM is asleep)
self._state.reload_to_gpu()

await self._load_forked_checkpoint_if_needed()

async for result in run_unsloth_rl_training(
self._state,
disk_packed_tensors=disk_packed_tensors,
Expand Down
19 changes: 17 additions & 2 deletions src/art/utils/deployment/wandb.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,18 @@ class WandbDeploymentConfig(DeploymentConfig):
"Qwen/Qwen2.5-14B-Instruct",
]

# Checkpoint base-model ids that map onto an id W&B inference serves:
# unsloth mirrors and "Meta-" prefixed HF ids resolve to the canonical
# meta-llama ids.
WANDB_BASE_MODEL_ALIASES = {
    "unsloth/Meta-Llama-3.1-8B-Instruct": "meta-llama/Llama-3.1-8B-Instruct",
    "meta-llama/Meta-Llama-3.1-8B-Instruct": "meta-llama/Llama-3.1-8B-Instruct",
    "unsloth/Meta-Llama-3.1-70B-Instruct": "meta-llama/Llama-3.1-70B-Instruct",
    "meta-llama/Meta-Llama-3.1-70B-Instruct": "meta-llama/Llama-3.1-70B-Instruct",
}


def get_wandb_base_model(base_model: str) -> str:
    """Resolve ``base_model`` to its W&B inference id, or return it unchanged."""
    if base_model in WANDB_BASE_MODEL_ALIASES:
        return WANDB_BASE_MODEL_ALIASES[base_model]
    return base_model


def deploy_wandb(
model: "TrainableModel",
Expand All @@ -54,7 +66,8 @@ def deploy_wandb(
"""
import wandb

if model.base_model not in WANDB_SUPPORTED_BASE_MODELS:
wandb_base_model = get_wandb_base_model(model.base_model)
if wandb_base_model not in WANDB_SUPPORTED_BASE_MODELS:
raise UnsupportedBaseModelDeploymentError(
message=f"Base model {model.base_model} is not supported for serverless LoRA deployment by W&B. Supported models: {WANDB_SUPPORTED_BASE_MODELS}"
)
Expand All @@ -77,7 +90,9 @@ def deploy_wandb(
settings=wandb.Settings(api_key=os.environ["WANDB_API_KEY"]),
)
try:
metadata: dict[str, object] = {"wandb.base_model": model.base_model}
metadata: dict[str, object] = {"wandb.base_model": wandb_base_model}
if wandb_base_model != model.base_model:
metadata["source_base_model"] = model.base_model
if config is not None:
metadata["wandb.provenance"] = config.provenance
artifact = wandb.Artifact(
Expand Down
Loading
Loading