Skip to content
2 changes: 0 additions & 2 deletions QEfficient/cloud/finetune_experimental.py
Original file line number Diff line number Diff line change
Expand Up @@ -275,9 +275,7 @@ def _create_trainer(
# Clean up training config: remove fields that shouldn't be passed to TrainingArguments
training_config.pop("device", None)
training_config.pop("log_file_name", None)
# Note: torch_dtype was already converted to fp16/bf16 flag in prepare_training_config
training_config.pop("deepspeed_config", None)
training_config.pop("torch_dtype", None)
# Remove PP-specific fields as they're handled via device_map in model loading
training_config.pop("pp_degree", None)

Expand Down
3 changes: 3 additions & 0 deletions QEfficient/finetune/experimental/configs/sft_ddp_config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ model:
model_type: "hf" # Hugging Face model
auto_class_name: "AutoModelForCausalLM" # Auto class to load the model with
model_name: "HuggingFaceTB/SmolLM-135M" # Pretrained model name
torch_dtype: "float16"
use_peft: true # Enable PEFT (Parameter Efficient Fine-Tuning)
peft_config:
lora_r: 8 # LoRA rank
Expand Down Expand Up @@ -42,6 +43,8 @@ dataset:
# Training configuration
training:
type: "sft"
fp16: true
bf16: false
gradient_accumulation_steps: 1 # Number of steps to accumulate gradients
per_device_train_batch_size: 1 # Batch size per device during training
torch_compile: False # Whether to use torch.compile
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ model:
model_type: "hf" # Hugging Face model
auto_class_name: "AutoModelForCausalLM" # Auto class to load the model with
model_name: "HuggingFaceTB/SmolLM-135M" # Pretrained model name
torch_dtype: "float16"
use_peft: true # Enable PEFT (Parameter Efficient Fine-Tuning)
peft_config:
lora_r: 16
Expand All @@ -30,6 +31,8 @@ dataset:
# Training configuration
training:
type: "sft"
fp16: true
bf16: false
gradient_accumulation_steps: 2 # Number of steps to accumulate gradients
per_device_train_batch_size: 2 # Batch size per device during training
num_train_epochs: 1
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ model:
model_type: "hf" # Hugging Face model
auto_class_name: "AutoModelForCausalLM" # Auto class to load the model with
model_name: "HuggingFaceTB/SmolLM-135M" # Pretrained model name
torch_dtype: "float16"
use_peft: true # Enable PEFT (Parameter Efficient Fine-Tuning)
peft_config:
lora_r: 8
Expand All @@ -31,6 +32,8 @@ dataset:
# Training configuration
training:
type: "sft"
fp16: true
bf16: false
gradient_accumulation_steps: 1 # Number of steps to accumulate gradients
per_device_train_batch_size: 1 # Batch size per device during training
num_train_epochs: 1
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ model:
model_type: "hf" # Hugging Face model
auto_class_name: "AutoModelForCausalLM" # Auto class to load the model with
model_name: "HuggingFaceTB/SmolLM-135M" # Pretrained model name
torch_dtype: "float16"
use_peft: true # Enable PEFT (Parameter Efficient Fine-Tuning)
peft_config:
lora_r: 8 # LoRA rank
Expand Down Expand Up @@ -44,6 +45,8 @@ training:
per_device_train_batch_size: 1 # Batch size per device during training
num_train_epochs: 1
torch_compile: False # Whether to use torch.compile
fp16: true
bf16: false

# Optimizer configuration
optimizers:
Expand All @@ -57,3 +60,5 @@ callbacks:
early_stopping:
early_stopping_patience: 3 # Number of epochs to wait before stopping training
early_stopping_threshold: 0.001 # Minimum change in metric to qualify as improvement


168 changes: 147 additions & 21 deletions QEfficient/finetune/experimental/core/callbacks.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from pathlib import Path
from typing import Any, Dict, Optional

import torch
from transformers import (
DefaultFlowCallback,
EarlyStoppingCallback,
Expand All @@ -28,6 +29,10 @@
from QEfficient.finetune.experimental.core.component_registry import ComponentFactory, registry
from QEfficient.finetune.experimental.core.config_manager import ConfigManager
from QEfficient.finetune.experimental.core.logger import Logger
from QEfficient.finetune.experimental.core.utils.dist_utils import (
get_local_rank,
get_world_size,
)
from QEfficient.finetune.experimental.core.utils.profiler_utils import (
get_op_verifier_ctx,
init_qaic_profiling,
Expand Down Expand Up @@ -272,24 +277,108 @@ class QAICProfilerCallback(TrainerCallback):

def __init__(self, *args, **kwargs):
"""
Initialize QAIC profiler settings (start/end steps and target device IDs).
Initialize QAIC profiler settings (start/end steps, trace directory and target device IDs).
"""

self.start_step = kwargs.get("start_step", -1)
self.end_step = kwargs.get("end_step", -1)
self.device_ids = kwargs.get("device_ids", [0])
if self.start_step >= 0 and self.end_step >= 0 and self.end_step < self.start_step:
raise ValueError(f"end_step ({self.end_step}) must be >= start_step ({self.start_step})")
self.trace_dir = kwargs.get(
"trace_dir",
os.path.join(os.environ.get("OUTPUT_DIR", "."), "hw-trace"),
)
self.device_ids = kwargs.get("device_ids")
self._started_device_ids: list[int] = []
self._profile_started = False
self._stop_attempted = False
self._warned_start_failure = False
self._warned_stop_failure = False

def _resolve_device_ids_for_rank(self) -> list[int]:
local_rank = get_local_rank()
world_size = get_world_size()

if not self.device_ids:
return [local_rank]

ids = list(self.device_ids)

# One device-per-rank mapping.
if world_size > 1 and len(ids) == world_size:
return [ids[local_rank % len(ids)]]

return ids

def _start_profiling(self, device_ids: list[int]) -> None:
started_device_ids: list[int] = []
for device_id in device_ids:
try:
init_qaic_profiling(True, f"qaic:{device_id}", trace_dir=self.trace_dir)
started_device_ids.append(device_id)
except Exception as e:
if not self._warned_start_failure:
logger.log_rank_zero(
f"failed to START profiling on qaic:{device_id}: {e}",
level=logging.WARNING,
)
self._warned_start_failure = True
self._started_device_ids = started_device_ids
self._profile_started = bool(started_device_ids)

def on_step_begin(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
"""
Event called at the beginning of a training step. If using gradient accumulation, one training step might take
several inputs.
"""
if state.global_step == self.start_step:
for device_id in self.device_ids:
init_qaic_profiling(True, f"qaic:{device_id}")
elif state.global_step == self.end_step:
for device_id in self.device_ids:
def _stop_profiling(self, device_ids: list[int], reason: str) -> None:
for device_id in device_ids:
try:
stop_qaic_profiling(True, f"qaic:{device_id}")
except Exception as e:
if not self._warned_stop_failure:
logger.log_rank_zero(
f"failed to STOP profiling ({reason}) on qaic:{device_id}: {e}",
level=logging.WARNING,
)
self._warned_stop_failure = True
self._started_device_ids = []
self._profile_started = False

# -------------------------------------------------------------------------
# TrainerCallback hooks
# -------------------------------------------------------------------------

def on_step_begin(
self,
args: TrainingArguments,
state: TrainerState,
control: TrainerControl,
**kwargs,
):
if self.start_step >= 0 and state.global_step == self.start_step and not self._profile_started:
self._start_profiling(self._resolve_device_ids_for_rank())

def on_step_end(
self,
args: TrainingArguments,
state: TrainerState,
control: TrainerControl,
**kwargs,
):
if (
self.end_step >= 0
and state.global_step >= self.end_step
and self._profile_started
and not self._stop_attempted
):
self._stop_attempted = True
self._stop_profiling(self._started_device_ids, f"at end_step={self.end_step}")

def on_train_end(
self,
args: TrainingArguments,
state: TrainerState,
control: TrainerControl,
**kwargs,
):
if self._profile_started and not self._stop_attempted:
self._stop_attempted = True
self._stop_profiling(self._started_device_ids, "on train end")


@registry.callback("qaic_op_by_op_verifier_callback")
Expand All @@ -300,23 +389,58 @@ def __init__(self, *args, **kwargs):
""" "
Initialize QAIC Op-by-Op verifier callback with profiling and tolerance settings.
"""
self.start_step = kwargs.get("start_step", -1)
self.end_step = kwargs.get("end_step", -1)
self.trace_dir = kwargs.get("trace_dir", "qaic_op_by_op_traces")
self.atol = kwargs.get("atol", 1e-1)
self.rtol = kwargs.get("rtol", 1e-5)
try:
self.start_step = int(kwargs.get("start_step", -1))
self.end_step = int(kwargs.get("end_step", -1))
self.atol = float(kwargs.get("atol", 1e-1))
self.rtol = float(kwargs.get("rtol", 1e-5))
except (TypeError, ValueError) as e:
raise ValueError(
"qaic_op_by_op_verifier_callback expects numeric values for start_step, end_step, atol, and rtol."
) from e
trace_dir = kwargs.get("trace_dir", "qaic_op_by_op_traces")
expanded_trace_dir = os.path.expanduser(trace_dir)
if os.path.isabs(expanded_trace_dir):
self.trace_dir = os.path.abspath(expanded_trace_dir)
else:
output_dir = os.environ.get("OUTPUT_DIR", ".")
self.trace_dir = os.path.abspath(os.path.join(output_dir, expanded_trace_dir))
self.op_verifier_ctx_step = None

@staticmethod
def _resolve_ref_dtype(args: TrainingArguments) -> torch.dtype:
# Keep reference dtype aligned with training precision to avoid type mismatches
# in ops that require matching input/grad dtypes (e.g. embedding backward).
if getattr(args, "bf16", False):
return torch.bfloat16
if getattr(args, "fp16", False):
return torch.float16
return torch.float32

def on_step_begin(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
"""
Event called at the beginning of a training step. If using gradient accumulation, one training step might take
several inputs.
"""
if self.start_step <= state.global_step < self.end_step:
if getattr(args, "fp16", False):
raise RuntimeError(
"qaic_op_by_op_verifier_callback is not supported with fp16/GradScaler training. "
"Set training.fp16=false (and optionally training.bf16=true if supported) "
"when using this callback."
)
logger.log_rank_zero(
"QAIC OpByOp verifier active: "
f"step={state.global_step}, "
f"window=[{self.start_step}, {self.end_step}), "
f"dump_dir={self.trace_dir}/mismatches/step_{state.global_step}"
)
self.op_verifier_ctx_step = get_op_verifier_ctx(
use_op_by_op_verifier=True,
device_type="qaic",
dump_dir=self.trace_dir,
step=state.global_step,
ref_dtype=self._resolve_ref_dtype(args),
atol=self.atol,
rtol=self.rtol,
)
Expand All @@ -327,9 +451,11 @@ def on_step_end(self, args: TrainingArguments, state: TrainerState, control: Tra
Event called at the end of a training step. If using gradient accumulation, one training step might take
several inputs.
"""
if self.start_step <= state.global_step < self.end_step:
if self.op_verifier_ctx_step is not None:
self.op_verifier_ctx_step.__exit__(None, None, None)
# Always close a previously-entered verifier context, even when the
# post-step global_step moved to end_step.
if self.op_verifier_ctx_step is not None:
self.op_verifier_ctx_step.__exit__(None, None, None)
self.op_verifier_ctx_step = None


def replace_progress_callback(trainer: Any, callbacks: list[Any], logger: Any = None) -> None:
Expand Down
Loading
Loading