diff --git a/QEfficient/cloud/finetune_experimental.py b/QEfficient/cloud/finetune_experimental.py index e54138e4b2..b8b6d9fe5a 100644 --- a/QEfficient/cloud/finetune_experimental.py +++ b/QEfficient/cloud/finetune_experimental.py @@ -275,9 +275,7 @@ def _create_trainer( # Clean up training config: remove fields that shouldn't be passed to TrainingArguments training_config.pop("device", None) training_config.pop("log_file_name", None) - # Note: torch_dtype was already converted to fp16/bf16 flag in prepare_training_config training_config.pop("deepspeed_config", None) - training_config.pop("torch_dtype", None) # Remove PP-specific fields as they're handled via device_map in model loading training_config.pop("pp_degree", None) diff --git a/QEfficient/finetune/experimental/configs/sft_ddp_config.yaml b/QEfficient/finetune/experimental/configs/sft_ddp_config.yaml index 9094ee4b0f..c43ae4b9a5 100644 --- a/QEfficient/finetune/experimental/configs/sft_ddp_config.yaml +++ b/QEfficient/finetune/experimental/configs/sft_ddp_config.yaml @@ -10,6 +10,7 @@ model: model_type: "hf" # Hugging Face model auto_class_name: "AutoModelForCausalLM" # Auto class to load the model with model_name: "HuggingFaceTB/SmolLM-135M" # Pretrained model name + torch_dtype: "float16" use_peft: true # Enable PEFT (Parameter Efficient Fine-Tuning) peft_config: lora_r: 8 # LoRA rank @@ -42,6 +43,8 @@ dataset: # Training configuration training: type: "sft" + fp16: true + bf16: false gradient_accumulation_steps: 1 # Number of steps to accumulate gradients per_device_train_batch_size: 1 # Batch size per device during training torch_compile: False # Whether to use torch.compile diff --git a/QEfficient/finetune/experimental/configs/sft_single_device_alpaca_config.yaml b/QEfficient/finetune/experimental/configs/sft_single_device_alpaca_config.yaml index 2bdf800bc5..e1c435e1c0 100644 --- a/QEfficient/finetune/experimental/configs/sft_single_device_alpaca_config.yaml +++ b/QEfficient/finetune/experimental/configs/sft_single_device_alpaca_config.yaml @@ -9,6 +9,7 @@ model: model_type: "hf" # Hugging Face model auto_class_name: "AutoModelForCausalLM" # Auto class to load the model with model_name: "HuggingFaceTB/SmolLM-135M" # Pretrained model name + torch_dtype: "float16" use_peft: true # Enable PEFT (Parameter Efficient Fine-Tuning) peft_config: lora_r: 16 @@ -30,6 +31,8 @@ dataset: # Training configuration training: type: "sft" + fp16: true + bf16: false gradient_accumulation_steps: 2 # Number of steps to accumulate gradients per_device_train_batch_size: 2 # Batch size per device during training num_train_epochs: 1 diff --git a/QEfficient/finetune/experimental/configs/sft_single_device_custom_dataset_config.yaml b/QEfficient/finetune/experimental/configs/sft_single_device_custom_dataset_config.yaml index fbdcc88d6b..515a208f54 100644 --- a/QEfficient/finetune/experimental/configs/sft_single_device_custom_dataset_config.yaml +++ b/QEfficient/finetune/experimental/configs/sft_single_device_custom_dataset_config.yaml @@ -10,6 +10,7 @@ model: model_type: "hf" # Hugging Face model auto_class_name: "AutoModelForCausalLM" # Auto class to load the model with model_name: "HuggingFaceTB/SmolLM-135M" # Pretrained model name + torch_dtype: "float16" use_peft: true # Enable PEFT (Parameter Efficient Fine-Tuning) peft_config: lora_r: 8 @@ -31,6 +32,8 @@ dataset: # Training configuration training: type: "sft" + fp16: true + bf16: false gradient_accumulation_steps: 1 # Number of steps to accumulate gradients per_device_train_batch_size: 1 # Batch size per device during training num_train_epochs: 1 diff --git a/QEfficient/finetune/experimental/configs/sft_single_device_gsm8k_config.yaml b/QEfficient/finetune/experimental/configs/sft_single_device_gsm8k_config.yaml index 86c9ec4d13..e9c23c3662 100644 --- a/QEfficient/finetune/experimental/configs/sft_single_device_gsm8k_config.yaml +++ b/QEfficient/finetune/experimental/configs/sft_single_device_gsm8k_config.yaml @@ -9,6 +9,7 @@ model: model_type: "hf" # Hugging Face model auto_class_name: "AutoModelForCausalLM" # Auto class to load the model with model_name: "HuggingFaceTB/SmolLM-135M" # Pretrained model name + torch_dtype: "float16" use_peft: true # Enable PEFT (Parameter Efficient Fine-Tuning) peft_config: lora_r: 8 # LoRA rank @@ -44,6 +45,8 @@ training: per_device_train_batch_size: 1 # Batch size per device during training num_train_epochs: 1 torch_compile: False # Whether to use torch.compile + fp16: true + bf16: false # Optimizer configuration optimizers: @@ -57,3 +60,5 @@ callbacks: early_stopping: early_stopping_patience: 3 # Number of epochs to wait before stopping training early_stopping_threshold: 0.001 # Minimum change in metric to qualify as improvement + + diff --git a/QEfficient/finetune/experimental/core/callbacks.py b/QEfficient/finetune/experimental/core/callbacks.py index 95e5db2b48..673b76cca7 100644 --- a/QEfficient/finetune/experimental/core/callbacks.py +++ b/QEfficient/finetune/experimental/core/callbacks.py @@ -15,6 +15,7 @@ from pathlib import Path from typing import Any, Dict, Optional +import torch from transformers import ( DefaultFlowCallback, EarlyStoppingCallback, @@ -28,6 +29,10 @@ from QEfficient.finetune.experimental.core.component_registry import ComponentFactory, registry from QEfficient.finetune.experimental.core.config_manager import ConfigManager from QEfficient.finetune.experimental.core.logger import Logger +from QEfficient.finetune.experimental.core.utils.dist_utils import ( + get_local_rank, + get_world_size, +) from QEfficient.finetune.experimental.core.utils.profiler_utils import ( get_op_verifier_ctx, init_qaic_profiling, @@ -272,24 +277,108 @@ class QAICProfilerCallback(TrainerCallback): def __init__(self, *args, **kwargs): """ - Initialize QAIC profiler settings (start/end steps and target device IDs). + Initialize QAIC profiler settings (start/end steps, trace directory and target device IDs). """ - self.start_step = kwargs.get("start_step", -1) self.end_step = kwargs.get("end_step", -1) - self.device_ids = kwargs.get("device_ids", [0]) + if self.start_step >= 0 and self.end_step >= 0 and self.end_step < self.start_step: + raise ValueError(f"end_step ({self.end_step}) must be >= start_step ({self.start_step})") + self.trace_dir = kwargs.get( + "trace_dir", + os.path.join(os.environ.get("OUTPUT_DIR", "."), "hw-trace"), + ) + self.device_ids = kwargs.get("device_ids") + self._started_device_ids: list[int] = [] + self._profile_started = False + self._stop_attempted = False + self._warned_start_failure = False + self._warned_stop_failure = False + + def _resolve_device_ids_for_rank(self) -> list[int]: + local_rank = get_local_rank() + world_size = get_world_size() + + if not self.device_ids: + return [local_rank] + + ids = list(self.device_ids) + + # One device-per-rank mapping. + if world_size > 1 and len(ids) == world_size: + return [ids[local_rank % len(ids)]] + + return ids + + def _start_profiling(self, device_ids: list[int]) -> None: + started_device_ids: list[int] = [] + for device_id in device_ids: + try: + init_qaic_profiling(True, f"qaic:{device_id}", trace_dir=self.trace_dir) + started_device_ids.append(device_id) + except Exception as e: + if not self._warned_start_failure: + logger.log_rank_zero( + f"failed to START profiling on qaic:{device_id}: {e}", + level=logging.WARNING, + ) + self._warned_start_failure = True + self._started_device_ids = started_device_ids + self._profile_started = bool(started_device_ids) - def on_step_begin(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs): - """ - Event called at the beginning of a training step. If using gradient accumulation, one training step might take - several inputs. - """ - if state.global_step == self.start_step: - for device_id in self.device_ids: - init_qaic_profiling(True, f"qaic:{device_id}") - elif state.global_step == self.end_step: - for device_id in self.device_ids: + def _stop_profiling(self, device_ids: list[int], reason: str) -> None: + for device_id in device_ids: + try: stop_qaic_profiling(True, f"qaic:{device_id}") + except Exception as e: + if not self._warned_stop_failure: + logger.log_rank_zero( + f"failed to STOP profiling ({reason}) on qaic:{device_id}: {e}", + level=logging.WARNING, + ) + self._warned_stop_failure = True + self._started_device_ids = [] + self._profile_started = False + + # ------------------------------------------------------------------------- + # TrainerCallback hooks + # ------------------------------------------------------------------------- + + def on_step_begin( + self, + args: TrainingArguments, + state: TrainerState, + control: TrainerControl, + **kwargs, + ): + if self.start_step >= 0 and state.global_step == self.start_step and not self._profile_started: + self._start_profiling(self._resolve_device_ids_for_rank()) + + def on_step_end( + self, + args: TrainingArguments, + state: TrainerState, + control: TrainerControl, + **kwargs, + ): + if ( + self.end_step >= 0 + and state.global_step >= self.end_step + and self._profile_started + and not self._stop_attempted + ): + self._stop_attempted = True + self._stop_profiling(self._started_device_ids, f"at end_step={self.end_step}") + + def on_train_end( + self, + args: TrainingArguments, + state: TrainerState, + control: TrainerControl, + **kwargs, + ): + if self._profile_started and not self._stop_attempted: + self._stop_attempted = True + self._stop_profiling(self._started_device_ids, "on train end") @registry.callback("qaic_op_by_op_verifier_callback") @@ -300,11 +389,33 @@ def __init__(self, *args, **kwargs): """ " Initialize QAIC Op-by-Op verifier callback with profiling and tolerance settings. """ - self.start_step = kwargs.get("start_step", -1) - self.end_step = kwargs.get("end_step", -1) - self.trace_dir = kwargs.get("trace_dir", "qaic_op_by_op_traces") - self.atol = kwargs.get("atol", 1e-1) - self.rtol = kwargs.get("rtol", 1e-5) + try: + self.start_step = int(kwargs.get("start_step", -1)) + self.end_step = int(kwargs.get("end_step", -1)) + self.atol = float(kwargs.get("atol", 1e-1)) + self.rtol = float(kwargs.get("rtol", 1e-5)) + except (TypeError, ValueError) as e: + raise ValueError( + "qaic_op_by_op_verifier_callback expects numeric values for start_step, end_step, atol, and rtol." + ) from e + trace_dir = kwargs.get("trace_dir", "qaic_op_by_op_traces") + expanded_trace_dir = os.path.expanduser(trace_dir) + if os.path.isabs(expanded_trace_dir): + self.trace_dir = os.path.abspath(expanded_trace_dir) + else: + output_dir = os.environ.get("OUTPUT_DIR", ".") + self.trace_dir = os.path.abspath(os.path.join(output_dir, expanded_trace_dir)) + self.op_verifier_ctx_step = None + + @staticmethod + def _resolve_ref_dtype(args: TrainingArguments) -> torch.dtype: + # Keep reference dtype aligned with training precision to avoid type mismatches + # in ops that require matching input/grad dtypes (e.g. embedding backward). + if getattr(args, "bf16", False): + return torch.bfloat16 + if getattr(args, "fp16", False): + return torch.float16 + return torch.float32 def on_step_begin(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs): """ @@ -312,11 +423,24 @@ def on_step_begin(self, args: TrainingArguments, state: TrainerState, control: T several inputs. """ if self.start_step <= state.global_step < self.end_step: + if getattr(args, "fp16", False): + raise RuntimeError( + "qaic_op_by_op_verifier_callback is not supported with fp16/GradScaler training. " + "Set training.fp16=false (and optionally training.bf16=true if supported) " + "when using this callback." + ) + logger.log_rank_zero( + "QAIC OpByOp verifier active: " + f"step={state.global_step}, " + f"window=[{self.start_step}, {self.end_step}), " + f"dump_dir={self.trace_dir}/mismatches/step_{state.global_step}" + ) self.op_verifier_ctx_step = get_op_verifier_ctx( use_op_by_op_verifier=True, device_type="qaic", dump_dir=self.trace_dir, step=state.global_step, + ref_dtype=self._resolve_ref_dtype(args), atol=self.atol, rtol=self.rtol, ) @@ -327,9 +451,11 @@ def on_step_end(self, args: TrainingArguments, state: TrainerState, control: Tra Event called at the end of a training step. If using gradient accumulation, one training step might take several inputs. """ - if self.start_step <= state.global_step < self.end_step: - if self.op_verifier_ctx_step is not None: - self.op_verifier_ctx_step.__exit__(None, None, None) + # Always close a previously-entered verifier context, even when the + # post-step global_step moved to end_step. + if self.op_verifier_ctx_step is not None: + self.op_verifier_ctx_step.__exit__(None, None, None) + self.op_verifier_ctx_step = None def replace_progress_callback(trainer: Any, callbacks: list[Any], logger: Any = None) -> None: diff --git a/QEfficient/finetune/experimental/core/config_manager.py b/QEfficient/finetune/experimental/core/config_manager.py index f500aeb2c8..3063d207c9 100644 --- a/QEfficient/finetune/experimental/core/config_manager.py +++ b/QEfficient/finetune/experimental/core/config_manager.py @@ -265,6 +265,12 @@ class ModelConfig: default=None, metadata={"help": "The device map to use for model distribution (e.g., 'auto')."}, ) + torch_dtype: Optional[str] = field( + default="float16", + metadata={ + "help": "Torch dtype passed to model.from_pretrained (e.g., 'float16', 'bfloat16', 'float32', or 'auto')." + }, + ) @dataclass @@ -418,9 +424,13 @@ class TrainingConfig: default="qaic", metadata={"help": "The device to use for training ('cuda', 'cpu', etc.)."}, ) - torch_dtype: str = field( - default="fp16", - metadata={"help": "The torch data type to use for model weights (e.g., 'fp32', 'fp16', 'bf16')."}, + fp16: bool = field( + default=True, + metadata={"help": "Whether to enable fp16 mixed precision in training arguments."}, + ) + bf16: bool = field( + default=False, + metadata={"help": "Whether to enable bf16 mixed precision in training arguments."}, ) torch_compile: bool = field( default=False, @@ -542,11 +552,8 @@ def __init__( logger.log_rank_zero("Using default configuration...") self.config = asdict(self.config) self.config = MasterConfig(**self.config) - # Validate loaded config - try: - self.validate_config() - except Exception as e: - logger.log_rank_zero(f"Config validation failed with error: {e}") + + self.validate_config() def _build_cli_parser(self) -> HfArgumentParser: return HfArgumentParser( @@ -778,9 +785,8 @@ def validate_config(self) -> None: self._push(errors, not model.get("model_name"), "model.model_name is required.") # Device valid_devices = ["cpu", "cuda", "qaic"] - training_device = model.get("device", "qaic") - if training_device not in valid_devices: - self._push(errors, training_device not in valid_devices, f"training.device must be one of {valid_devices}.") + training_device = training.get("device", "qaic") + self._push(errors, training_device not in valid_devices, f"training.device must be one of {valid_devices}.") if training_device == "qaic": try: import torch_qaic # noqa: F401 @@ -819,21 +825,43 @@ def validate_config(self) -> None: self._push(errors, not dataset.get("dataset_name"), "dataset.dataset_name is required.") self._push(errors, not dataset.get("tokenizer_name"), "dataset.tokenizer_name is required.") - # ---------- Training ---------- + # ---------- Model ---------- # torch_dtype validation - torch_dtype = training.get("torch_dtype") - valid_dtypes = {"fp16", "bf16", "fp32"} + torch_dtype = model.get("torch_dtype") + valid_dtypes = {"float16", "bfloat16", "float32", "auto", "fp16", "bf16", "fp32"} self._push( errors, not torch_dtype, - "training.torch_dtype is required.", + "model.torch_dtype is required.", ) self._push( errors, torch_dtype and torch_dtype not in valid_dtypes, - f"training.torch_dtype must be one of {valid_dtypes}.", + f"model.torch_dtype must be one of {valid_dtypes}.", ) + # ---------- Training ---------- + fp16 = bool(training.get("fp16", False)) + bf16 = bool(training.get("bf16", False)) + self._push( + errors, + fp16 and bf16, + "training.fp16 and training.bf16 cannot both be true.", + ) + callbacks_cfg = getattr(cfg, "callbacks", {}) + if isinstance(callbacks_cfg, dict): + callback_dict = {} + nested_callbacks = callbacks_cfg.get("callbacks") + if isinstance(nested_callbacks, dict): + callback_dict.update(nested_callbacks) + callback_dict.update({k: v for k, v in callbacks_cfg.items() if k != "callbacks"}) + self._push( + errors, + fp16 and isinstance(callback_dict, dict) and "qaic_op_by_op_verifier_callback" in callback_dict, + "qaic_op_by_op_verifier_callback is not compatible with training.fp16=true. " + "Set training.fp16=false when using this callback.", + ) + # Batch sizes self._push( errors, @@ -912,21 +940,12 @@ def get_dataset_config(self) -> Dict[str, Any]: def get_model_config(self) -> Dict[str, Any]: """ Get model configuration as dictionary. - - Automatically handles torch_dtype conversion from training config if not set in model config. """ - model_config = self.config.model - - # Get torch_dtype from training config and convert - # To do: check if it can be moved from training config to model config instead - if model_config.get("torch_dtype") is None: - training_config = self.get_training_config() - training_dtype = training_config.get("torch_dtype") - if training_dtype: - # Convert from training format (fp16/bf16) to model format (float16/bfloat16) - dtype_mapping = dtype_mapping = constants.DTYPE_MAPPING - model_config["torch_dtype"] = dtype_mapping.get(training_dtype, "auto") - + model_config = dict(self.config.model) + dtype_mapping = constants.DTYPE_MAPPING + torch_dtype = model_config.get("torch_dtype") + if torch_dtype in dtype_mapping: + model_config["torch_dtype"] = dtype_mapping[torch_dtype] return model_config def to_dict(self) -> Dict[str, Any]: diff --git a/QEfficient/finetune/experimental/core/utils/constants.py b/QEfficient/finetune/experimental/core/utils/constants.py index 76b3ccd2b1..49aacfe189 100644 --- a/QEfficient/finetune/experimental/core/utils/constants.py +++ b/QEfficient/finetune/experimental/core/utils/constants.py @@ -5,4 +5,4 @@ # # ----------------------------------------------------------------------------- -DTYPE_MAPPING = {"fp16": "float16", "bf16": "bfloat16"} +DTYPE_MAPPING = {"fp16": "float16", "bf16": "bfloat16", "fp32": "float32"} diff --git a/QEfficient/finetune/experimental/core/utils/profiler_utils.py b/QEfficient/finetune/experimental/core/utils/profiler_utils.py index e24508e831..84de7e2990 100644 --- a/QEfficient/finetune/experimental/core/utils/profiler_utils.py +++ b/QEfficient/finetune/experimental/core/utils/profiler_utils.py @@ -6,8 +6,9 @@ # ----------------------------------------------------------------------------- +import os from contextlib import nullcontext -from typing import ContextManager +from typing import ContextManager, Optional import torch @@ -45,7 +46,7 @@ def get_op_verifier_ctx( Returns: ContextManager: Instance of context manager used to verify the operators. """ - if (not use_op_by_op_verifier) or ("qaic" in device_type): + if (not use_op_by_op_verifier) or ("qaic" not in device_type): return nullcontext() # Lazily imported qaic_debug when it is actually needed. @@ -64,19 +65,26 @@ def get_op_verifier_ctx( ) -def init_qaic_profiling(use_profiler: bool, device_type: str) -> None: +def init_qaic_profiling(use_profiler: bool, device_type: str, trace_dir: Optional[str] = None) -> None: """Initialize the qaic profiling tool. Note: The profiler is only works for qaic backend. Args: use_profiler (bool): Boolean flag to enable profiler. device_type (str): Device on which the model is being executed. + trace_dir (str, optional): Optional output directory for hardware traces. """ if (use_profiler) and ("qaic" in device_type): # Lazily imported qaic's qaic_profile when it is actually needed. import torch_qaic.profile as qaic_profile - qaic_profile.start_profiling(device_type, 1) + if trace_dir is None: + qaic_profile.start_profiling(device_type, 1) + return + + trace_dir = os.path.abspath(os.path.expanduser(trace_dir)) + os.makedirs(trace_dir, exist_ok=True) + qaic_profile.start_profiling(device_type, 1, path=trace_dir) def stop_qaic_profiling(use_profiler: bool, device_type: str) -> None: diff --git a/QEfficient/finetune/experimental/core/utils/training_config_utils.py b/QEfficient/finetune/experimental/core/utils/training_config_utils.py index 1cd6704e44..b002365140 100644 --- a/QEfficient/finetune/experimental/core/utils/training_config_utils.py +++ b/QEfficient/finetune/experimental/core/utils/training_config_utils.py @@ -31,18 +31,8 @@ def prepare_training_config( # Get training config as dict and create mutable copy to avoid mutating original training_config = dict(config_manager.get_training_config()) - # Handle dtype conversion - # To do: (For Tanisha) Check if torch_dtype should rather be added directly in model_config only in config_manager.py - - torch_dtype = training_config.pop("torch_dtype", None) - if torch_dtype is None: - raise ValueError("'torch_dtype' field is required in training configuration. Expected one of: ['fp16', 'bf16']") - training_config[torch_dtype] = True training_config["data_seed"] = training_config.get("seed") - # Restoring the "torch_dtype" after torch_dtype conversion using the saved value - training_config["torch_dtype"] = torch_dtype - # Handle scheduler configuration scheduler_config = config_manager.get_scheduler_config() training_config.setdefault("lr_scheduler_type", scheduler_config.get("scheduler_name")) diff --git a/QEfficient/finetune/experimental/examples/example_config.yaml b/QEfficient/finetune/experimental/examples/example_config.yaml index 809a47ebd1..0b430534a2 100644 --- a/QEfficient/finetune/experimental/examples/example_config.yaml +++ b/QEfficient/finetune/experimental/examples/example_config.yaml @@ -13,6 +13,7 @@ model: model_type: "hf" # Hugging Face model auto_class_name: "AutoModelForCausalLM" # Auto class to load the model with model_name: "HuggingFaceTB/SmolLM-135M" # Pretrained model name + torch_dtype: "float16" use_peft: true # Enable PEFT (Parameter Efficient Fine-Tuning) peft_config: lora_r: 16 @@ -39,6 +40,8 @@ dataset: # Training configuration training: type: "sft" + fp16: true + bf16: false gradient_accumulation_steps: 2 # Number of steps to accumulate gradients per_device_train_batch_size: 2 # Batch size per device during training num_train_epochs: 2 diff --git a/QEfficient/finetune/experimental/tests/test_callback.py b/QEfficient/finetune/experimental/tests/test_callback.py index e085da9c9e..63b2883a15 100644 --- a/QEfficient/finetune/experimental/tests/test_callback.py +++ b/QEfficient/finetune/experimental/tests/test_callback.py @@ -5,11 +5,26 @@ # # ----------------------------------------------------------------------------- +import os +import shutil +from pathlib import Path +from types import SimpleNamespace + import pytest +import torch from transformers import TrainerCallback +from QEfficient.finetune.experimental.core import callbacks as callbacks_module +from QEfficient.finetune.experimental.core.callbacks import QAICOpByOpVerifierCallback, QAICProfilerCallback from QEfficient.finetune.experimental.core.component_registry import ComponentFactory, registry +PROJECT_ROOT = Path(__file__).resolve().parents[4] +OUTPUT_DIR = PROJECT_ROOT / "training_results" +QAIC_PROFILER_TRACE_DIR = OUTPUT_DIR / "hw-trace" +QAIC_OP_TRACE_DIR = OUTPUT_DIR / "qaic_op_by_op_traces" +QAIC_CUSTOM_OP_TRACE_DIR = OUTPUT_DIR / "custom-op-trace" +QAIC_ABSOLUTE_OP_TRACE_DIR = OUTPUT_DIR / "absolute-op-trace" + class ModelSummaryCallback(TrainerCallback): def __init__(self): @@ -17,18 +32,34 @@ def __init__(self): # Setup test data -CALLBACK_CONFIGS = { - "early_stopping": { - "name": "early_stopping", - "early_stopping_patience": 3, - "early_stopping_threshold": 0.001, - }, - "tensorboard": {"name": "tensorboard", "tb_writer": "SummaryWriter"}, - "model_summary": { - "name": "model_summary", - "max_depth": 1, - }, -} +CALLBACK_CONFIGS = [ + pytest.param( + { + "name": "early_stopping", + "early_stopping_patience": 3, + "early_stopping_threshold": 0.001, + }, + id="early_stopping", + ), + pytest.param({"name": "tensorboard", "tb_writer": "SummaryWriter"}, id="tensorboard"), + pytest.param( + { + "name": "model_summary", + "max_depth": 1, + }, + id="model_summary", + ), + pytest.param( + { + "name": "qaic_profiler_callback", + "start_step": 0, + "end_step": 1, + "trace_dir": str(QAIC_PROFILER_TRACE_DIR), + "device_ids": [0], + }, + id="qaic_profiler", + ), +] REGISTRY_CALLBACK_CONFIGS = { "model_summary": { @@ -39,11 +70,26 @@ def __init__(self): } -@pytest.mark.parametrize("callback_name", CALLBACK_CONFIGS.keys()) -def test_callbacks(callback_name): +@pytest.fixture(autouse=True, scope="module") +def setup_training_output_dir(): + prev_output_dir = os.environ.get("OUTPUT_DIR") + shutil.rmtree(OUTPUT_DIR, ignore_errors=True) + OUTPUT_DIR.mkdir(parents=True, exist_ok=True) + os.environ["OUTPUT_DIR"] = str(OUTPUT_DIR) + try: + yield + finally: + if prev_output_dir is None: + os.environ.pop("OUTPUT_DIR", None) + else: + os.environ["OUTPUT_DIR"] = prev_output_dir + shutil.rmtree(OUTPUT_DIR, ignore_errors=True) + + +@pytest.mark.parametrize("config", CALLBACK_CONFIGS) +def test_callbacks(config): """Test that registered callbacks that can be created with their configs.""" # Create callbacks using the factory - config = CALLBACK_CONFIGS[callback_name] try: callback_inst = ComponentFactory.create_callback(**config) except ValueError as e: @@ -60,3 +106,286 @@ def test_callbacks_registery(callback_name, callback_class): callback = registry.get_callback(callback_name) assert callback is not None assert callback == callback_class + + +def test_qaic_profiler_uses_user_trace_dir(): + callback = QAICProfilerCallback(trace_dir=str(QAIC_PROFILER_TRACE_DIR)) + assert callback.trace_dir == str(QAIC_PROFILER_TRACE_DIR) + + +def test_qaic_profiler_starts_with_trace_dir(monkeypatch): + calls = [] + + monkeypatch.setattr(callbacks_module, "get_local_rank", lambda: 0) + monkeypatch.setattr(callbacks_module, "get_world_size", lambda: 1) + + def _mock_start(use_profiler, device_type, trace_dir=None): + calls.append((use_profiler, device_type, trace_dir)) + + monkeypatch.setattr(callbacks_module, "init_qaic_profiling", _mock_start) + + callback = QAICProfilerCallback(start_step=3, end_step=9, trace_dir=str(QAIC_PROFILER_TRACE_DIR), device_ids=[2]) + state = SimpleNamespace(global_step=3) + + callback.on_step_begin(args=None, state=state, control=None) + + assert callback._profile_started is True + assert calls == [(True, "qaic:2", str(QAIC_PROFILER_TRACE_DIR))] + + +def test_qaic_profiler_stops_once_at_end_step(monkeypatch): + start_calls = [] + stop_calls = [] + + monkeypatch.setattr(callbacks_module, "get_local_rank", lambda: 0) + monkeypatch.setattr(callbacks_module, "get_world_size", lambda: 1) + monkeypatch.setattr( + callbacks_module, + "init_qaic_profiling", + lambda use_profiler, device_type, trace_dir=None: start_calls.append((use_profiler, device_type, trace_dir)), + ) + monkeypatch.setattr( + callbacks_module, + "stop_qaic_profiling", + lambda use_profiler, device_type: stop_calls.append((use_profiler, device_type)), + ) + + callback = QAICProfilerCallback(start_step=1, end_step=2, trace_dir=str(QAIC_PROFILER_TRACE_DIR), device_ids=[0]) + + callback.on_step_begin(args=None, state=SimpleNamespace(global_step=1), control=None) + callback.on_step_end(args=None, state=SimpleNamespace(global_step=2), control=None) + callback.on_step_end(args=None, state=SimpleNamespace(global_step=3), control=None) + + assert len(start_calls) == 1 + assert stop_calls == [(True, "qaic:0")] + assert callback._profile_started is False + + +def test_qaic_profiler_stops_on_train_end_when_not_stopped(monkeypatch): + stop_calls = [] + + monkeypatch.setattr(callbacks_module, "get_local_rank", lambda: 0) + monkeypatch.setattr(callbacks_module, "get_world_size", lambda: 1) + monkeypatch.setattr(callbacks_module, "init_qaic_profiling", lambda *args, **kwargs: None) + monkeypatch.setattr( + callbacks_module, + "stop_qaic_profiling", + lambda use_profiler, device_type: stop_calls.append((use_profiler, device_type)), + ) + + callback = QAICProfilerCallback(start_step=0, end_step=100, device_ids=[4]) + callback.on_step_begin(args=None, state=SimpleNamespace(global_step=0), control=None) + callback.on_train_end(args=None, state=SimpleNamespace(global_step=1), control=None) + + assert stop_calls == [(True, "qaic:4")] + + +def test_qaic_profiler_uses_local_rank_when_device_ids_not_set(monkeypatch): + calls = [] + + monkeypatch.setattr(callbacks_module, "get_local_rank", lambda: 3) + monkeypatch.setattr(callbacks_module, "get_world_size", lambda: 8) + monkeypatch.setattr( + callbacks_module, + "init_qaic_profiling", + lambda use_profiler, device_type, trace_dir=None: calls.append((use_profiler, device_type, trace_dir)), + ) + + callback = QAICProfilerCallback(start_step=0, trace_dir=str(QAIC_PROFILER_TRACE_DIR)) + callback.on_step_begin(args=None, state=SimpleNamespace(global_step=0), control=None) + + assert calls == [(True, "qaic:3", str(QAIC_PROFILER_TRACE_DIR))] + + +def test_qaic_profiler_maps_rank_to_device_id(monkeypatch): + calls = [] + + monkeypatch.setattr(callbacks_module, "get_local_rank", lambda: 1) + monkeypatch.setattr(callbacks_module, "get_world_size", lambda: 2) + monkeypatch.setattr( + callbacks_module, + "init_qaic_profiling", + lambda use_profiler, device_type, trace_dir=None: calls.append((use_profiler, device_type, trace_dir)), + ) + + callback = QAICProfilerCallback(start_step=5, trace_dir=str(QAIC_PROFILER_TRACE_DIR), device_ids=[10, 11]) + callback.on_step_begin(args=None, state=SimpleNamespace(global_step=5), control=None) + + assert calls == [(True, "qaic:11", str(QAIC_PROFILER_TRACE_DIR))] + + +def test_qaic_profiler_invalid_step_range_raises(): + with pytest.raises(ValueError, match="end_step .* must be >= start_step"): + QAICProfilerCallback(start_step=10, end_step=5) + + +def test_qaic_profiler_stops_only_started_devices(monkeypatch): + start_calls = [] + stop_calls = [] + + monkeypatch.setattr(callbacks_module, "get_local_rank", lambda: 0) + monkeypatch.setattr(callbacks_module, "get_world_size", lambda: 1) + + def _mock_start(use_profiler, device_type, trace_dir=None): + start_calls.append((use_profiler, device_type, trace_dir)) + if device_type == "qaic:1": + raise RuntimeError("start failure") + + monkeypatch.setattr(callbacks_module, "init_qaic_profiling", _mock_start) + monkeypatch.setattr( + callbacks_module, + "stop_qaic_profiling", + lambda use_profiler, device_type: stop_calls.append((use_profiler, device_type)), + ) + + callback = QAICProfilerCallback(start_step=0, end_step=1, trace_dir=str(QAIC_PROFILER_TRACE_DIR), device_ids=[0, 1]) + callback.on_step_begin(args=None, state=SimpleNamespace(global_step=0), control=None) + callback.on_step_end(args=None, state=SimpleNamespace(global_step=1), control=None) + + assert start_calls == [ + (True, "qaic:0", str(QAIC_PROFILER_TRACE_DIR)), + (True, "qaic:1", str(QAIC_PROFILER_TRACE_DIR)), + ] + assert stop_calls == [(True, "qaic:0")] + + +def test_qaic_profiler_resolves_rank_at_start_time(monkeypatch): + calls = [] + rank_state = {"local_rank": 0} + + monkeypatch.setattr(callbacks_module, "get_local_rank", lambda: rank_state["local_rank"]) + monkeypatch.setattr(callbacks_module, "get_world_size", lambda: 2) + monkeypatch.setattr( + callbacks_module, + "init_qaic_profiling", + lambda use_profiler, device_type, trace_dir=None: calls.append((use_profiler, device_type, trace_dir)), + ) + + callback = QAICProfilerCallback(start_step=0, trace_dir=str(QAIC_PROFILER_TRACE_DIR), device_ids=[10, 11]) + rank_state["local_rank"] = 1 + callback.on_step_begin(args=None, state=SimpleNamespace(global_step=0), control=None) + + assert calls == [(True, "qaic:11", str(QAIC_PROFILER_TRACE_DIR))] + + +def test_qaic_op_by_op_verifier_on_step_end_without_initialized_ctx(): + callback = QAICOpByOpVerifierCallback(start_step=0, end_step=5, trace_dir=str(QAIC_OP_TRACE_DIR)) + state = SimpleNamespace(global_step=1) + + # Should not raise when on_step_end is hit before any context is initialized. + callback.on_step_end(args=None, state=state, control=None) + + +def test_qaic_op_by_op_verifier_default_trace_dir_is_under_output_dir(): + callback = QAICOpByOpVerifierCallback(start_step=0, end_step=1) + assert callback.trace_dir == os.path.abspath(str(OUTPUT_DIR / "qaic_op_by_op_traces")) + + +def test_qaic_op_by_op_verifier_relative_trace_dir_is_under_output_dir(): + callback = QAICOpByOpVerifierCallback(start_step=0, end_step=1, trace_dir="./custom-op-trace") + assert callback.trace_dir == os.path.abspath(str(QAIC_CUSTOM_OP_TRACE_DIR)) + + +def test_qaic_op_by_op_verifier_absolute_trace_dir_is_preserved(): + callback = QAICOpByOpVerifierCallback(start_step=0, end_step=1, trace_dir=str(QAIC_ABSOLUTE_OP_TRACE_DIR)) + assert callback.trace_dir == os.path.abspath(str(QAIC_ABSOLUTE_OP_TRACE_DIR)) + + +def test_qaic_op_by_op_verifier_casts_numeric_config(monkeypatch): + captured = {} + + class _DummyCtx: + def __enter__(self): + return self + + def __exit__(self, exc_type, exc, tb): + return False + + def _mock_get_op_verifier_ctx(**kwargs): + captured.update(kwargs) + return _DummyCtx() + + monkeypatch.setattr(callbacks_module, "get_op_verifier_ctx", _mock_get_op_verifier_ctx) + + callback = QAICOpByOpVerifierCallback( + start_step="0", + end_step="2", + trace_dir=str(QAIC_OP_TRACE_DIR), + atol="0.1", + rtol="1e-5", + ) + callback.on_step_begin(args=None, state=SimpleNamespace(global_step=0), control=None) + + assert isinstance(captured["atol"], float) + assert isinstance(captured["rtol"], float) + assert captured["atol"] == 0.1 + assert captured["rtol"] == 1e-5 + + +@pytest.mark.parametrize( + "args,expected_dtype", + [ + (SimpleNamespace(fp16=False, bf16=True), torch.bfloat16), + (SimpleNamespace(fp16=False, bf16=False), torch.float32), + ], +) +def test_qaic_op_by_op_verifier_uses_training_precision_for_ref_dtype(monkeypatch, args, expected_dtype): + captured = {} + + class _DummyCtx: + def __enter__(self): + return self + + def __exit__(self, exc_type, exc, tb): + return False + + def _mock_get_op_verifier_ctx(**kwargs): + captured.update(kwargs) + return _DummyCtx() + + monkeypatch.setattr(callbacks_module, "get_op_verifier_ctx", _mock_get_op_verifier_ctx) + + callback = QAICOpByOpVerifierCallback(start_step=0, end_step=2, trace_dir=str(QAIC_OP_TRACE_DIR)) + callback.on_step_begin(args=args, state=SimpleNamespace(global_step=0), control=None) + + assert captured["ref_dtype"] == expected_dtype + + +def test_qaic_op_by_op_verifier_rejects_fp16_mode(): + callback = QAICOpByOpVerifierCallback(start_step=0, end_step=2, trace_dir=str(QAIC_OP_TRACE_DIR)) + + with pytest.raises(RuntimeError, match="not supported with fp16/GradScaler"): + callback.on_step_begin( + args=SimpleNamespace(fp16=True, bf16=False), + state=SimpleNamespace(global_step=0), + control=None, + ) + + +def test_qaic_op_by_op_verifier_exits_ctx_when_global_step_reaches_end_step(monkeypatch): + events = [] + + class _DummyCtx: + def __enter__(self): + events.append("enter") + return self + + def __exit__(self, exc_type, exc, tb): + events.append("exit") + return False + + monkeypatch.setattr(callbacks_module, "get_op_verifier_ctx", lambda **kwargs: _DummyCtx()) + + callback = QAICOpByOpVerifierCallback(start_step=0, end_step=2, trace_dir=str(QAIC_OP_TRACE_DIR)) + + # Enter at step 1 (still inside [start_step, end_step) ). + callback.on_step_begin( + args=SimpleNamespace(fp16=False, bf16=False), + state=SimpleNamespace(global_step=1), + control=None, + ) + # At on_step_end, HF Trainer may already report global_step == end_step. + callback.on_step_end(args=None, state=SimpleNamespace(global_step=2), control=None) + + assert events == ["enter", "exit"] + assert callback.op_verifier_ctx_step is None diff --git a/QEfficient/finetune/experimental/tests/test_config.yaml b/QEfficient/finetune/experimental/tests/test_config.yaml index aab402b483..efe642dbb9 100644 --- a/QEfficient/finetune/experimental/tests/test_config.yaml +++ b/QEfficient/finetune/experimental/tests/test_config.yaml @@ -10,6 +10,7 @@ model: model_type: "hf" auto_class_name: "AutoModelForCausalLM" model_name: "HuggingFaceTB/SmolLM-135M" # Pretrained model name + torch_dtype: "float16" use_peft: true peft_config: lora_r: 16 @@ -41,7 +42,8 @@ training: seed: 42 device: "qaic" do_eval: True - torch_dtype: "fp16" + fp16: true + bf16: false eval_strategy: "epoch" eval_steps: 100 per_device_train_batch_size: 1 diff --git a/QEfficient/finetune/experimental/tests/test_config_manager.py b/QEfficient/finetune/experimental/tests/test_config_manager.py index cc088d0bed..f5c3cb00c3 100644 --- a/QEfficient/finetune/experimental/tests/test_config_manager.py +++ b/QEfficient/finetune/experimental/tests/test_config_manager.py @@ -229,26 +229,72 @@ def test_config(config_path): def test_torch_dtype_validation(): """Test that torch_dtype validation works correctly.""" - # Test with default config - should have torch_dtype set to fp16 by default + # Test with default config - should have model torch_dtype set to float16 by default config_manager = ConfigManager() - training_config = config_manager.get_training_config() - assert training_config.get("torch_dtype") == "fp16" + model_config = config_manager.get_model_config() + assert model_config.get("torch_dtype") == "float16" # Validation should pass with default config config_manager.validate_config() # Should not raise -def test_torch_dtype_invalid(): - """Test that invalid torch_dtype raises validation error.""" - from QEfficient.finetune.experimental.core.config_manager import MasterConfig, TrainingConfig +def test_torch_dtype_invalid(monkeypatch): + """Test that invalid torch_dtype is reported via exception or logged validation failure.""" + from QEfficient.finetune.experimental.core import config_manager as config_manager_module + from QEfficient.finetune.experimental.core.config_manager import MasterConfig, ModelConfig + + captured_logs = [] + + def _capture_log(message, level=None): + captured_logs.append((str(message), level)) + + monkeypatch.setattr(config_manager_module.logger, "log_rank_zero", _capture_log) + + # Create config with invalid model torch_dtype + model_config = ModelConfig(torch_dtype="invalid_dtype") + master_config = MasterConfig(model=model_config) + try: + ConfigManager(config=master_config) + except ValueError as exc_info: + assert "torch_dtype must be one of" in str(exc_info) + return + + assert any( + "Config validation failed with error" in msg and "torch_dtype must be one of" in msg for msg, _ in captured_logs + ), "Expected torch_dtype validation failure to be logged when ConfigManager does not raise." + + +def test_fp16_bf16_mutually_exclusive(monkeypatch): + from QEfficient.finetune.experimental.core import config_manager as config_manager_module + + captured_logs = [] + + def _capture_log(message, level=None): + captured_logs.append((str(message), level)) + + monkeypatch.setattr(config_manager_module.logger, "log_rank_zero", _capture_log) + + training_config = TrainingConfig(fp16=True, bf16=True) + master_config = MasterConfig(training=training_config) + try: + ConfigManager(config=master_config) + except ValueError as exc_info: + assert "training.fp16 and training.bf16 cannot both be true" in str(exc_info) + return + + assert any( + "Config validation failed with error" in msg and "training.fp16 and training.bf16 cannot both be true" in msg + for msg, _ in captured_logs + ), "Expected fp16/bf16 mutual-exclusion validation failure to be logged when ConfigManager does not raise." + - # Create config with invalid torch_dtype - training_config = TrainingConfig(torch_dtype="invalid_dtype") +def test_qaic_op_by_op_verifier_disallowed_with_fp16(): + training_config = TrainingConfig(fp16=True, bf16=False) master_config = MasterConfig(training=training_config) config_manager = ConfigManager(config=master_config) + config_manager.update_config({"callbacks": {"qaic_op_by_op_verifier_callback": {"start_step": 0, "end_step": 1}}}) - # Validation should fail with pytest.raises(ValueError) as exc_info: config_manager.validate_config() - assert "torch_dtype must be one of" in str(exc_info.value) + assert "qaic_op_by_op_verifier_callback is not compatible with training.fp16=true" in str(exc_info.value) diff --git a/QEfficient/finetune/experimental/tests/test_integrated.py b/QEfficient/finetune/experimental/tests/test_integrated.py index 207cf458e6..c23029dda0 100644 --- a/QEfficient/finetune/experimental/tests/test_integrated.py +++ b/QEfficient/finetune/experimental/tests/test_integrated.py @@ -22,6 +22,7 @@ import torch from QEfficient.cloud.finetune_experimental import FineTuningPipeline +from QEfficient.finetune.experimental.core import config_manager as config_manager_module from QEfficient.finetune.experimental.core.config_manager import ( ConfigManager, DatasetConfig, @@ -73,6 +74,11 @@ # ============================================================================ +@pytest.fixture(autouse=True) +def bypass_nsp_free_check(monkeypatch): + monkeypatch.setattr(config_manager_module, "is_nsp_free", lambda: None) + + def clean_up(path): if os.path.isdir(path) and os.path.exists(path): shutil.rmtree(path) diff --git a/docs/source/config.md b/docs/source/config.md index 9fc9ecf554..1545710cba 100644 --- a/docs/source/config.md +++ b/docs/source/config.md @@ -18,6 +18,7 @@ Model-related parameters for loading and fine-tuning. | `use_cache` | `false` | Uses the past key/values cache for faster decoding during generation. | | `attn_implementation` | `"sdpa"` | Attention implementation. Common values: `sdpa`, `eager`. | | `device_map` | `None` | Specifies how to distribute the model across devices. | +| `torch_dtype` | `float16` | Torch dtype passed to `model.from_pretrained` (for example `float16`, `bfloat16`, `float32`, `auto`, `fp16`, `bf16`, `fp32`). | | `use_peft` | `true` | Enables PEFT for parameter-efficient fine-tuning. | | `peft_config` | - | Defines LoRA parameters when `use_peft` is true. | @@ -173,7 +174,8 @@ This section defines core parameters for fine-tuning and evaluation. | `do_eval` | `true` | Enables evaluation during training. | | `eval_strategy` | `epoch` | When to run evaluation. | | `gradient_accumulation_steps` | `1` | Accumulates gradients over multiple steps. | -| `dtype` | `fp16` | Mixed precision setting. | +| `fp16` | `true` | Enables fp16 mixed precision in training arguments. | +| `bf16` | `false` | Enables bf16 mixed precision in training arguments. | | `seed` | `42` | Random seed for reproducibility. | | `device` | `qaic` | Device to use for training. | | `per_device_train_batch_size` | `1` | Batch size per device during training. | @@ -210,6 +212,11 @@ This section defines core parameters for fine-tuning and evaluation. Optional distributed configs: FSDP, DeepSpeed, or DDP for multi-QAIC or large-scale training. +Precision note: +- `model.torch_dtype` controls model loading dtype. +- `training.fp16` and `training.bf16` control training-time mixed precision flags. +- `training.fp16` and `training.bf16` cannot both be `true`. + 📁 **Output Directory Structure** output_dir/ @@ -258,6 +265,11 @@ Callbacks allow custom actions during training, such as logging, early stopping, | `QAICProfilerCallback` | Profiles QAIC devices over a specified training step range. | | `QAICOpByOpVerifierCallback` | Verifies QAIC operations step-by-step for correctness and debugging. | +**QAIC callback usage recommendations** + +- Use `qaic_profiler_callback` for only `1-3` steps. +- Use `qaic_op_by_op_verifier_callback` with `training.fp16: false` and `model.torch_dtype: fp32`, for only `1-3` steps. + **References to some commonly used Hugging Face callbacks**: https://huggingface.co/docs/transformers/en/main_classes/callback ***