From aa24bb965094b185861f4df79c9e94da9993d922 Mon Sep 17 00:00:00 2001 From: Sharvari Medhe Date: Thu, 14 May 2026 03:50:37 +0530 Subject: [PATCH 1/9] adding fix for the qaic profiler callback along with updated test cases Signed-off-by: Sharvari Medhe --- .../finetune/experimental/core/callbacks.py | 110 +++++++++++++++--- .../experimental/core/utils/profiler_utils.py | 14 ++- .../experimental/tests/test_callback.py | 110 ++++++++++++++++++ 3 files changed, 218 insertions(+), 16 deletions(-) diff --git a/QEfficient/finetune/experimental/core/callbacks.py b/QEfficient/finetune/experimental/core/callbacks.py index 95e5db2b48..0e6b5925cc 100644 --- a/QEfficient/finetune/experimental/core/callbacks.py +++ b/QEfficient/finetune/experimental/core/callbacks.py @@ -28,6 +28,10 @@ from QEfficient.finetune.experimental.core.component_registry import ComponentFactory, registry from QEfficient.finetune.experimental.core.config_manager import ConfigManager from QEfficient.finetune.experimental.core.logger import Logger +from QEfficient.finetune.experimental.core.utils.dist_utils import ( + get_local_rank, + get_world_size, +) from QEfficient.finetune.experimental.core.utils.profiler_utils import ( get_op_verifier_ctx, init_qaic_profiling, @@ -272,24 +276,104 @@ class QAICProfilerCallback(TrainerCallback): def __init__(self, *args, **kwargs): """ - Initialize QAIC profiler settings (start/end steps and target device IDs). + Initialize QAIC profiler settings (start/end steps, trace directory and target device IDs). """ - self.start_step = kwargs.get("start_step", -1) self.end_step = kwargs.get("end_step", -1) - self.device_ids = kwargs.get("device_ids", [0]) + self.trace_dir = kwargs.get( + "trace_dir", + os.path.join(os.environ.get("OUTPUT_DIR", "."), "hw-trace"), + ) + self.device_ids = kwargs.get("device_ids") + self._profile_started = False + self._stop_attempted = False + self._warned_start_failure = False + self._warned_stop_failure = False + + def _resolve_device_ids_for_rank(self) -> list[int]: + local_rank = get_local_rank() + world_size = get_world_size() + + if not self.device_ids: + return [local_rank] + + ids = list(self.device_ids) + + # One device-per-rank mapping. + if world_size > 1 and len(ids) == world_size: + return [ids[local_rank % len(ids)]] + + return ids + + def _start_profiling(self, device_ids: list[int]) -> None: + for device_id in device_ids: + try: + init_qaic_profiling(True, f"qaic:{device_id}", trace_dir=self.trace_dir) + self._profile_started = True + except Exception as e: + if not self._warned_start_failure: + logger.log_rank_zero( + f"failed to START profiling on qaic:{device_id}: {e}", + level=logging.WARNING, + ) + self._warned_start_failure = True - def on_step_begin(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs): - """ - Event called at the beginning of a training step. If using gradient accumulation, one training step might take - several inputs. - """ - if state.global_step == self.start_step: - for device_id in self.device_ids: - init_qaic_profiling(True, f"qaic:{device_id}") - elif state.global_step == self.end_step: - for device_id in self.device_ids: + def _stop_profiling(self, device_ids: list[int], reason: str) -> None: + for device_id in device_ids: + try: stop_qaic_profiling(True, f"qaic:{device_id}") + except Exception as e: + if not self._warned_stop_failure: + logger.log_rank_zero( + f"failed to STOP profiling ({reason}) on qaic:{device_id}: {e}", + level=logging.WARNING, + ) + self._warned_stop_failure = True + self._profile_started = False + + # ------------------------------------------------------------------------- + # TrainerCallback hooks + # ------------------------------------------------------------------------- + + def on_step_begin( + self, + args: TrainingArguments, + state: TrainerState, + control: TrainerControl, + **kwargs, + ): + if self.start_step >= 0 and state.global_step == self.start_step and not self._profile_started: + device_ids = self._resolve_device_ids_for_rank() + self._start_profiling(device_ids) + + def on_step_end( + self, + args: TrainingArguments, + state: TrainerState, + control: TrainerControl, + **kwargs, + ): + if ( + self.end_step >= 0 + and state.global_step >= self.end_step + and self._profile_started + and not self._stop_attempted + ): + self._stop_attempted = True + device_ids = self._resolve_device_ids_for_rank() + self._stop_profiling(device_ids, f"at end_step={self.end_step}") + + def on_train_end( + self, + args: TrainingArguments, + state: TrainerState, + control: TrainerControl, + **kwargs, + ): + if self._profile_started and not self._stop_attempted: + self._stop_attempted = True + device_ids = self._resolve_device_ids_for_rank() + self._stop_profiling(device_ids, "on train end") @registry.callback("qaic_op_by_op_verifier_callback") diff --git a/QEfficient/finetune/experimental/core/utils/profiler_utils.py b/QEfficient/finetune/experimental/core/utils/profiler_utils.py index e24508e831..3567870d72 100644 --- a/QEfficient/finetune/experimental/core/utils/profiler_utils.py +++ b/QEfficient/finetune/experimental/core/utils/profiler_utils.py @@ -6,8 +6,9 @@ # ----------------------------------------------------------------------------- +import os from contextlib import nullcontext -from typing import ContextManager +from typing import ContextManager, Optional import torch @@ -64,19 +65,26 @@ def get_op_verifier_ctx( ) -def init_qaic_profiling(use_profiler: bool, device_type: str) -> None: +def init_qaic_profiling(use_profiler: bool, device_type: str, trace_dir: Optional[str] = None) -> None: """Initialize the qaic profiling tool. Note: The profiler is only works for qaic backend. Args: use_profiler (bool): Boolean flag to enable profiler. device_type (str): Device on which the model is being executed. + trace_dir (str, optional): Optional output directory for hardware traces. """ if (use_profiler) and ("qaic" in device_type): # Lazily imported qaic's qaic_profile when it is actually needed. import torch_qaic.profile as qaic_profile - qaic_profile.start_profiling(device_type, 1) + if trace_dir is None: + qaic_profile.start_profiling(device_type, 1) + return + + trace_dir = os.path.abspath(os.path.expanduser(trace_dir)) + os.makedirs(trace_dir, exist_ok=True) + qaic_profile.start_profiling(device_type, 1, path=trace_dir) def stop_qaic_profiling(use_profiler: bool, device_type: str) -> None: diff --git a/QEfficient/finetune/experimental/tests/test_callback.py b/QEfficient/finetune/experimental/tests/test_callback.py index e085da9c9e..27cea7f565 100644 --- a/QEfficient/finetune/experimental/tests/test_callback.py +++ b/QEfficient/finetune/experimental/tests/test_callback.py @@ -5,9 +5,13 @@ # # ----------------------------------------------------------------------------- +from types import SimpleNamespace + import pytest from transformers import TrainerCallback +from QEfficient.finetune.experimental.core import callbacks as callbacks_module +from QEfficient.finetune.experimental.core.callbacks import QAICProfilerCallback from QEfficient.finetune.experimental.core.component_registry import ComponentFactory, registry @@ -60,3 +64,109 @@ def test_callbacks_registery(callback_name, callback_class): callback = registry.get_callback(callback_name) assert callback is not None assert callback == callback_class + + +def test_qaic_profiler_uses_user_trace_dir(): + callback = QAICProfilerCallback(trace_dir="~/my_custom_hw_trace") + assert callback.trace_dir == "~/my_custom_hw_trace" + + +def test_qaic_profiler_starts_with_trace_dir(monkeypatch): + calls = [] + + monkeypatch.setattr(callbacks_module, "get_local_rank", lambda: 0) + monkeypatch.setattr(callbacks_module, "get_world_size", lambda: 1) + + def _mock_start(use_profiler, device_type, trace_dir=None): + calls.append((use_profiler, device_type, trace_dir)) + + monkeypatch.setattr(callbacks_module, "init_qaic_profiling", _mock_start) + + callback = QAICProfilerCallback(start_step=3, end_step=9, trace_dir="/tmp/hw-trace", device_ids=[2]) + state = SimpleNamespace(global_step=3) + + callback.on_step_begin(args=None, state=state, control=None) + + assert callback._profile_started is True + assert calls == [(True, "qaic:2", "/tmp/hw-trace")] + + +def test_qaic_profiler_stops_once_at_end_step(monkeypatch): + start_calls = [] + stop_calls = [] + + monkeypatch.setattr(callbacks_module, "get_local_rank", lambda: 0) + monkeypatch.setattr(callbacks_module, "get_world_size", lambda: 1) + monkeypatch.setattr( + callbacks_module, + "init_qaic_profiling", + lambda use_profiler, device_type, trace_dir=None: start_calls.append((use_profiler, device_type, trace_dir)), + ) + monkeypatch.setattr( + callbacks_module, + "stop_qaic_profiling", + lambda use_profiler, device_type: stop_calls.append((use_profiler, device_type)), + ) + + callback = QAICProfilerCallback(start_step=1, end_step=2, trace_dir="/tmp/hw-trace", device_ids=[0]) + + callback.on_step_begin(args=None, state=SimpleNamespace(global_step=1), control=None) + callback.on_step_end(args=None, state=SimpleNamespace(global_step=2), control=None) + callback.on_step_end(args=None, state=SimpleNamespace(global_step=3), control=None) + + assert len(start_calls) == 1 + assert stop_calls == [(True, "qaic:0")] + assert callback._profile_started is False + + +def test_qaic_profiler_stops_on_train_end_when_not_stopped(monkeypatch): + stop_calls = [] + + monkeypatch.setattr(callbacks_module, "get_local_rank", lambda: 0) + monkeypatch.setattr(callbacks_module, "get_world_size", lambda: 1) + monkeypatch.setattr(callbacks_module, "init_qaic_profiling", lambda *args, **kwargs: None) + monkeypatch.setattr( + callbacks_module, + "stop_qaic_profiling", + lambda use_profiler, device_type: stop_calls.append((use_profiler, device_type)), + ) + + callback = QAICProfilerCallback(start_step=0, end_step=100, device_ids=[4]) + callback.on_step_begin(args=None, state=SimpleNamespace(global_step=0), control=None) + callback.on_train_end(args=None, state=SimpleNamespace(global_step=1), control=None) + + assert stop_calls == [(True, "qaic:4")] + + +def test_qaic_profiler_uses_local_rank_when_device_ids_not_set(monkeypatch): + calls = [] + + monkeypatch.setattr(callbacks_module, "get_local_rank", lambda: 3) + monkeypatch.setattr(callbacks_module, "get_world_size", lambda: 8) + monkeypatch.setattr( + callbacks_module, + "init_qaic_profiling", + lambda use_profiler, device_type, trace_dir=None: calls.append((use_profiler, device_type, trace_dir)), + ) + + callback = QAICProfilerCallback(start_step=0, trace_dir="/tmp/hw-trace") + callback.on_step_begin(args=None, state=SimpleNamespace(global_step=0), control=None) + + assert calls == [(True, "qaic:3", "/tmp/hw-trace")] + + +def test_qaic_profiler_maps_rank_to_device_id(monkeypatch): + calls = [] + + monkeypatch.setattr(callbacks_module, "get_local_rank", lambda: 1) + monkeypatch.setattr(callbacks_module, "get_world_size", lambda: 2) + monkeypatch.setattr( + callbacks_module, + "init_qaic_profiling", + lambda use_profiler, device_type, trace_dir=None: calls.append((use_profiler, device_type, trace_dir)), + ) + + callback = QAICProfilerCallback(start_step=5, trace_dir="/tmp/hw-trace", device_ids=[10, 11]) + callback.on_step_begin(args=None, state=SimpleNamespace(global_step=5), control=None) + + assert calls == [(True, "qaic:11", "/tmp/hw-trace")] From ad25b00ee3bd469c94e6cbc943cdc63660f97b03 Mon Sep 17 00:00:00 2001 From: Sharvari Medhe Date: Mon, 18 May 2026 17:22:34 +0530 Subject: [PATCH 2/9] adding changes to address qgenie reviews Signed-off-by: Sharvari Medhe --- .../finetune/experimental/core/callbacks.py | 18 ++-- .../experimental/core/utils/profiler_utils.py | 2 +- .../experimental/tests/test_callback.py | 96 ++++++++++++++++--- 3 files changed, 93 insertions(+), 23 deletions(-) diff --git a/QEfficient/finetune/experimental/core/callbacks.py b/QEfficient/finetune/experimental/core/callbacks.py index 0e6b5925cc..0ad7acc40b 100644 --- a/QEfficient/finetune/experimental/core/callbacks.py +++ b/QEfficient/finetune/experimental/core/callbacks.py @@ -280,11 +280,14 @@ def __init__(self, *args, **kwargs): """ self.start_step = kwargs.get("start_step", -1) self.end_step = kwargs.get("end_step", -1) + if self.start_step >= 0 and self.end_step >= 0 and self.end_step < self.start_step: + raise ValueError(f"end_step ({self.end_step}) must be >= start_step ({self.start_step})") self.trace_dir = kwargs.get( "trace_dir", os.path.join(os.environ.get("OUTPUT_DIR", "."), "hw-trace"), ) self.device_ids = kwargs.get("device_ids") + self._started_device_ids: list[int] = [] self._profile_started = False self._stop_attempted = False self._warned_start_failure = False @@ -306,10 +309,11 @@ def _resolve_device_ids_for_rank(self) -> list[int]: return ids def _start_profiling(self, device_ids: list[int]) -> None: + started_device_ids: list[int] = [] for device_id in device_ids: try: init_qaic_profiling(True, f"qaic:{device_id}", trace_dir=self.trace_dir) - self._profile_started = True + started_device_ids.append(device_id) except Exception as e: if not self._warned_start_failure: logger.log_rank_zero( @@ -317,6 +321,8 @@ def _start_profiling(self, device_ids: list[int]) -> None: level=logging.WARNING, ) self._warned_start_failure = True + self._started_device_ids = started_device_ids + self._profile_started = bool(started_device_ids) def _stop_profiling(self, device_ids: list[int], reason: str) -> None: for device_id in device_ids: @@ -329,6 +335,7 @@ def _stop_profiling(self, device_ids: list[int], reason: str) -> None: level=logging.WARNING, ) self._warned_stop_failure = True + self._started_device_ids = [] self._profile_started = False # ------------------------------------------------------------------------- @@ -343,8 +350,7 @@ def on_step_begin( **kwargs, ): if self.start_step >= 0 and state.global_step == self.start_step and not self._profile_started: - device_ids = self._resolve_device_ids_for_rank() - self._start_profiling(device_ids) + self._start_profiling(self._resolve_device_ids_for_rank()) def on_step_end( self, @@ -360,8 +366,7 @@ def on_step_end( and not self._stop_attempted ): self._stop_attempted = True - device_ids = self._resolve_device_ids_for_rank() - self._stop_profiling(device_ids, f"at end_step={self.end_step}") + self._stop_profiling(self._started_device_ids, f"at end_step={self.end_step}") def on_train_end( self, @@ -372,8 +377,7 @@ def on_train_end( ): if self._profile_started and not self._stop_attempted: self._stop_attempted = True - device_ids = self._resolve_device_ids_for_rank() - self._stop_profiling(device_ids, "on train end") + self._stop_profiling(self._started_device_ids, "on train end") @registry.callback("qaic_op_by_op_verifier_callback") diff --git a/QEfficient/finetune/experimental/core/utils/profiler_utils.py b/QEfficient/finetune/experimental/core/utils/profiler_utils.py index 3567870d72..84de7e2990 100644 --- a/QEfficient/finetune/experimental/core/utils/profiler_utils.py +++ b/QEfficient/finetune/experimental/core/utils/profiler_utils.py @@ -46,7 +46,7 @@ def get_op_verifier_ctx( Returns: ContextManager: Instance of context manager used to verify the operators. """ - if (not use_op_by_op_verifier) or ("qaic" in device_type): + if (not use_op_by_op_verifier) or ("qaic" not in device_type): return nullcontext() # Lazily imported qaic_debug when it is actually needed. diff --git a/QEfficient/finetune/experimental/tests/test_callback.py b/QEfficient/finetune/experimental/tests/test_callback.py index 27cea7f565..854ef7382e 100644 --- a/QEfficient/finetune/experimental/tests/test_callback.py +++ b/QEfficient/finetune/experimental/tests/test_callback.py @@ -21,18 +21,34 @@ def __init__(self): # Setup test data -CALLBACK_CONFIGS = { - "early_stopping": { - "name": "early_stopping", - "early_stopping_patience": 3, - "early_stopping_threshold": 0.001, - }, - "tensorboard": {"name": "tensorboard", "tb_writer": "SummaryWriter"}, - "model_summary": { - "name": "model_summary", - "max_depth": 1, - }, -} +CALLBACK_CONFIGS = [ + pytest.param( + { + "name": "early_stopping", + "early_stopping_patience": 3, + "early_stopping_threshold": 0.001, + }, + id="early_stopping", + ), + pytest.param({"name": "tensorboard", "tb_writer": "SummaryWriter"}, id="tensorboard"), + pytest.param( + { + "name": "model_summary", + "max_depth": 1, + }, + id="model_summary", + ), + pytest.param( + { + "name": "qaic_profiler_callback", + "start_step": 0, + "end_step": 1, + "trace_dir": "/tmp/hw-trace", + "device_ids": [0], + }, + id="qaic_profiler", + ), +] REGISTRY_CALLBACK_CONFIGS = { "model_summary": { @@ -43,11 +59,10 @@ def __init__(self): } -@pytest.mark.parametrize("callback_name", CALLBACK_CONFIGS.keys()) -def test_callbacks(callback_name): +@pytest.mark.parametrize("config", CALLBACK_CONFIGS) +def test_callbacks(config): """Test that registered callbacks that can be created with their configs.""" # Create callbacks using the factory - config = CALLBACK_CONFIGS[callback_name] try: callback_inst = ComponentFactory.create_callback(**config) except ValueError as e: @@ -170,3 +185,54 @@ def test_qaic_profiler_maps_rank_to_device_id(monkeypatch): callback.on_step_begin(args=None, state=SimpleNamespace(global_step=5), control=None) assert calls == [(True, "qaic:11", "/tmp/hw-trace")] + + +def test_qaic_profiler_invalid_step_range_raises(): + with pytest.raises(ValueError, match="end_step .* must be >= start_step"): + QAICProfilerCallback(start_step=10, end_step=5) + + +def test_qaic_profiler_stops_only_started_devices(monkeypatch): + start_calls = [] + stop_calls = [] + + monkeypatch.setattr(callbacks_module, "get_local_rank", lambda: 0) + monkeypatch.setattr(callbacks_module, "get_world_size", lambda: 1) + + def _mock_start(use_profiler, device_type, trace_dir=None): + start_calls.append((use_profiler, device_type, trace_dir)) + if device_type == "qaic:1": + raise RuntimeError("start failure") + + monkeypatch.setattr(callbacks_module, "init_qaic_profiling", _mock_start) + monkeypatch.setattr( + callbacks_module, + "stop_qaic_profiling", + lambda use_profiler, device_type: stop_calls.append((use_profiler, device_type)), + ) + + callback = QAICProfilerCallback(start_step=0, end_step=1, trace_dir="/tmp/hw-trace", device_ids=[0, 1]) + callback.on_step_begin(args=None, state=SimpleNamespace(global_step=0), control=None) + callback.on_step_end(args=None, state=SimpleNamespace(global_step=1), control=None) + + assert start_calls == [(True, "qaic:0", "/tmp/hw-trace"), (True, "qaic:1", "/tmp/hw-trace")] + assert stop_calls == [(True, "qaic:0")] + + +def test_qaic_profiler_resolves_rank_at_start_time(monkeypatch): + calls = [] + rank_state = {"local_rank": 0} + + monkeypatch.setattr(callbacks_module, "get_local_rank", lambda: rank_state["local_rank"]) + monkeypatch.setattr(callbacks_module, "get_world_size", lambda: 2) + monkeypatch.setattr( + callbacks_module, + "init_qaic_profiling", + lambda use_profiler, device_type, trace_dir=None: calls.append((use_profiler, device_type, trace_dir)), + ) + + callback = QAICProfilerCallback(start_step=0, trace_dir="/tmp/hw-trace", device_ids=[10, 11]) + rank_state["local_rank"] = 1 + callback.on_step_begin(args=None, state=SimpleNamespace(global_step=0), control=None) + + assert calls == [(True, "qaic:11", "/tmp/hw-trace")] From 184fc0cf537a9f4fed4b45f63effed58dba34363 Mon Sep 17 00:00:00 2001 From: Sharvari Medhe Date: Mon, 18 May 2026 18:45:02 +0530 Subject: [PATCH 3/9] refactor(config): move torch_dtype to model and use training fp16/bf16 flags Signed-off-by: Sharvari Medhe --- QEfficient/cloud/finetune_experimental.py | 2 - .../experimental/configs/sft_ddp_config.yaml | 3 + .../sft_single_device_alpaca_config.yaml | 3 + ...t_single_device_custom_dataset_config.yaml | 3 + .../sft_single_device_gsm8k_config.yaml | 12 +++- .../experimental/core/config_manager.py | 55 +++++++++++-------- .../experimental/core/utils/constants.py | 2 +- .../core/utils/training_config_utils.py | 10 ---- .../experimental/examples/example_config.yaml | 3 + .../experimental/tests/test_config.yaml | 4 +- .../experimental/tests/test_config_manager.py | 25 ++++++--- 11 files changed, 78 insertions(+), 44 deletions(-) diff --git a/QEfficient/cloud/finetune_experimental.py b/QEfficient/cloud/finetune_experimental.py index e54138e4b2..b8b6d9fe5a 100644 --- a/QEfficient/cloud/finetune_experimental.py +++ b/QEfficient/cloud/finetune_experimental.py @@ -275,9 +275,7 @@ def _create_trainer( # Clean up training config: remove fields that shouldn't be passed to TrainingArguments training_config.pop("device", None) training_config.pop("log_file_name", None) - # Note: torch_dtype was already converted to fp16/bf16 flag in prepare_training_config training_config.pop("deepspeed_config", None) - training_config.pop("torch_dtype", None) # Remove PP-specific fields as they're handled via device_map in model loading training_config.pop("pp_degree", None) diff --git a/QEfficient/finetune/experimental/configs/sft_ddp_config.yaml b/QEfficient/finetune/experimental/configs/sft_ddp_config.yaml index 9094ee4b0f..c43ae4b9a5 100644 --- a/QEfficient/finetune/experimental/configs/sft_ddp_config.yaml +++ b/QEfficient/finetune/experimental/configs/sft_ddp_config.yaml @@ -10,6 +10,7 @@ model: model_type: "hf" # Hugging Face model auto_class_name: "AutoModelForCausalLM" # Auto class to load the model with model_name: "HuggingFaceTB/SmolLM-135M" # Pretrained model name + torch_dtype: "float16" use_peft: true # Enable PEFT (Parameter Efficient Fine-Tuning) peft_config: lora_r: 8 # LoRA rank @@ -42,6 +43,8 @@ dataset: # Training configuration training: type: "sft" + fp16: true + bf16: false gradient_accumulation_steps: 1 # Number of steps to accumulate gradients per_device_train_batch_size: 1 # Batch size per device during training torch_compile: False # Whether to use torch.compile diff --git a/QEfficient/finetune/experimental/configs/sft_single_device_alpaca_config.yaml b/QEfficient/finetune/experimental/configs/sft_single_device_alpaca_config.yaml index 2bdf800bc5..e1c435e1c0 100644 --- a/QEfficient/finetune/experimental/configs/sft_single_device_alpaca_config.yaml +++ b/QEfficient/finetune/experimental/configs/sft_single_device_alpaca_config.yaml @@ -9,6 +9,7 @@ model: model_type: "hf" # Hugging Face model auto_class_name: "AutoModelForCausalLM" # Auto class to load the model with model_name: "HuggingFaceTB/SmolLM-135M" # Pretrained model name + torch_dtype: "float16" use_peft: true # Enable PEFT (Parameter Efficient Fine-Tuning) peft_config: lora_r: 16 @@ -30,6 +31,8 @@ dataset: # Training configuration training: type: "sft" + fp16: true + bf16: false gradient_accumulation_steps: 2 # Number of steps to accumulate gradients per_device_train_batch_size: 2 # Batch size per device during training num_train_epochs: 1 diff --git a/QEfficient/finetune/experimental/configs/sft_single_device_custom_dataset_config.yaml b/QEfficient/finetune/experimental/configs/sft_single_device_custom_dataset_config.yaml index fbdcc88d6b..515a208f54 100644 --- a/QEfficient/finetune/experimental/configs/sft_single_device_custom_dataset_config.yaml +++ b/QEfficient/finetune/experimental/configs/sft_single_device_custom_dataset_config.yaml @@ -10,6 +10,7 @@ model: model_type: "hf" # Hugging Face model auto_class_name: "AutoModelForCausalLM" # Auto class to load the model with model_name: "HuggingFaceTB/SmolLM-135M" # Pretrained model name + torch_dtype: "float16" use_peft: true # Enable PEFT (Parameter Efficient Fine-Tuning) peft_config: lora_r: 8 @@ -31,6 +32,8 @@ dataset: # Training configuration training: type: "sft" + fp16: true + bf16: false gradient_accumulation_steps: 1 # Number of steps to accumulate gradients per_device_train_batch_size: 1 # Batch size per device during training num_train_epochs: 1 diff --git a/QEfficient/finetune/experimental/configs/sft_single_device_gsm8k_config.yaml b/QEfficient/finetune/experimental/configs/sft_single_device_gsm8k_config.yaml index 86c9ec4d13..1a42f99a00 100644 --- a/QEfficient/finetune/experimental/configs/sft_single_device_gsm8k_config.yaml +++ b/QEfficient/finetune/experimental/configs/sft_single_device_gsm8k_config.yaml @@ -9,6 +9,7 @@ model: model_type: "hf" # Hugging Face model auto_class_name: "AutoModelForCausalLM" # Auto class to load the model with model_name: "HuggingFaceTB/SmolLM-135M" # Pretrained model name + torch_dtype: "float32" use_peft: true # Enable PEFT (Parameter Efficient Fine-Tuning) peft_config: lora_r: 8 # LoRA rank @@ -35,6 +36,7 @@ dataset: completion_template: "{answer}" # Model will be trained on this part. config_name: "main" # Config name for the dataset data_seed: 42 # Random seed for dataset shuffling + dataset_num_samples: 10 # Training configuration @@ -42,8 +44,10 @@ training: type: "sft" # type of training gradient_accumulation_steps: 1 # Number of steps to accumulate gradients per_device_train_batch_size: 1 # Batch size per device during training - num_train_epochs: 1 + num_train_epochs: 3 torch_compile: False # Whether to use torch.compile + fp16: false + bf16: false # Optimizer configuration optimizers: @@ -57,3 +61,9 @@ callbacks: early_stopping: early_stopping_patience: 3 # Number of epochs to wait before stopping training early_stopping_threshold: 0.001 # Minimum change in metric to qualify as improvement + qaic_op_by_op_verifier_callback: + start_step: 1 + end_step: 5 + trace_dir: "./qaic_op_by_op_traces" + atol: 0.1 + rtol: 1e-5 diff --git a/QEfficient/finetune/experimental/core/config_manager.py b/QEfficient/finetune/experimental/core/config_manager.py index f500aeb2c8..041a90686e 100644 --- a/QEfficient/finetune/experimental/core/config_manager.py +++ b/QEfficient/finetune/experimental/core/config_manager.py @@ -265,6 +265,13 @@ class ModelConfig: default=None, metadata={"help": "The device map to use for model distribution (e.g., 'auto')."}, ) + torch_dtype: Optional[str] = field( + default="float16", + metadata={ + "help": "Torch dtype passed to model.from_pretrained " + "(e.g., 'float16', 'bfloat16', 'float32', or 'auto')." + }, + ) @dataclass @@ -418,9 +425,13 @@ class TrainingConfig: default="qaic", metadata={"help": "The device to use for training ('cuda', 'cpu', etc.)."}, ) - torch_dtype: str = field( - default="fp16", - metadata={"help": "The torch data type to use for model weights (e.g., 'fp32', 'fp16', 'bf16')."}, + fp16: bool = field( + default=True, + metadata={"help": "Whether to enable fp16 mixed precision in training arguments."}, + ) + bf16: bool = field( + default=False, + metadata={"help": "Whether to enable bf16 mixed precision in training arguments."}, ) torch_compile: bool = field( default=False, @@ -819,19 +830,28 @@ def validate_config(self) -> None: self._push(errors, not dataset.get("dataset_name"), "dataset.dataset_name is required.") self._push(errors, not dataset.get("tokenizer_name"), "dataset.tokenizer_name is required.") - # ---------- Training ---------- + # ---------- Model ---------- # torch_dtype validation - torch_dtype = training.get("torch_dtype") - valid_dtypes = {"fp16", "bf16", "fp32"} + torch_dtype = model.get("torch_dtype") + valid_dtypes = {"float16", "bfloat16", "float32", "auto", "fp16", "bf16", "fp32"} self._push( errors, not torch_dtype, - "training.torch_dtype is required.", + "model.torch_dtype is required.", ) self._push( errors, torch_dtype and torch_dtype not in valid_dtypes, - f"training.torch_dtype must be one of {valid_dtypes}.", + f"model.torch_dtype must be one of {valid_dtypes}.", + ) + + # ---------- Training ---------- + fp16 = bool(training.get("fp16", False)) + bf16 = bool(training.get("bf16", False)) + self._push( + errors, + fp16 and bf16, + "training.fp16 and training.bf16 cannot both be true.", ) # Batch sizes @@ -912,21 +932,12 @@ def get_dataset_config(self) -> Dict[str, Any]: def get_model_config(self) -> Dict[str, Any]: """ Get model configuration as dictionary. - - Automatically handles torch_dtype conversion from training config if not set in model config. """ - model_config = self.config.model - - # Get torch_dtype from training config and convert - # To do: check if it can be moved from training config to model config instead - if model_config.get("torch_dtype") is None: - training_config = self.get_training_config() - training_dtype = training_config.get("torch_dtype") - if training_dtype: - # Convert from training format (fp16/bf16) to model format (float16/bfloat16) - dtype_mapping = dtype_mapping = constants.DTYPE_MAPPING - model_config["torch_dtype"] = dtype_mapping.get(training_dtype, "auto") - + model_config = dict(self.config.model) + dtype_mapping = constants.DTYPE_MAPPING + torch_dtype = model_config.get("torch_dtype") + if torch_dtype in dtype_mapping: + model_config["torch_dtype"] = dtype_mapping[torch_dtype] return model_config def to_dict(self) -> Dict[str, Any]: diff --git a/QEfficient/finetune/experimental/core/utils/constants.py b/QEfficient/finetune/experimental/core/utils/constants.py index 76b3ccd2b1..49aacfe189 100644 --- a/QEfficient/finetune/experimental/core/utils/constants.py +++ b/QEfficient/finetune/experimental/core/utils/constants.py @@ -5,4 +5,4 @@ # # ----------------------------------------------------------------------------- -DTYPE_MAPPING = {"fp16": "float16", "bf16": "bfloat16"} +DTYPE_MAPPING = {"fp16": "float16", "bf16": "bfloat16", "fp32": "float32"} diff --git a/QEfficient/finetune/experimental/core/utils/training_config_utils.py b/QEfficient/finetune/experimental/core/utils/training_config_utils.py index 1cd6704e44..b002365140 100644 --- a/QEfficient/finetune/experimental/core/utils/training_config_utils.py +++ b/QEfficient/finetune/experimental/core/utils/training_config_utils.py @@ -31,18 +31,8 @@ def prepare_training_config( # Get training config as dict and create mutable copy to avoid mutating original training_config = dict(config_manager.get_training_config()) - # Handle dtype conversion - # To do: (For Tanisha) Check if torch_dtype should rather be added directly in model_config only in config_manager.py - - torch_dtype = training_config.pop("torch_dtype", None) - if torch_dtype is None: - raise ValueError("'torch_dtype' field is required in training configuration. Expected one of: ['fp16', 'bf16']") - training_config[torch_dtype] = True training_config["data_seed"] = training_config.get("seed") - # Restoring the "torch_dtype" after torch_dtype conversion using the saved value - training_config["torch_dtype"] = torch_dtype - # Handle scheduler configuration scheduler_config = config_manager.get_scheduler_config() training_config.setdefault("lr_scheduler_type", scheduler_config.get("scheduler_name")) diff --git a/QEfficient/finetune/experimental/examples/example_config.yaml b/QEfficient/finetune/experimental/examples/example_config.yaml index 809a47ebd1..0b430534a2 100644 --- a/QEfficient/finetune/experimental/examples/example_config.yaml +++ b/QEfficient/finetune/experimental/examples/example_config.yaml @@ -13,6 +13,7 @@ model: model_type: "hf" # Hugging Face model auto_class_name: "AutoModelForCausalLM" # Auto class to load the model with model_name: "HuggingFaceTB/SmolLM-135M" # Pretrained model name + torch_dtype: "float16" use_peft: true # Enable PEFT (Parameter Efficient Fine-Tuning) peft_config: lora_r: 16 @@ -39,6 +40,8 @@ dataset: # Training configuration training: type: "sft" + fp16: true + bf16: false gradient_accumulation_steps: 2 # Number of steps to accumulate gradients per_device_train_batch_size: 2 # Batch size per device during training num_train_epochs: 2 diff --git a/QEfficient/finetune/experimental/tests/test_config.yaml b/QEfficient/finetune/experimental/tests/test_config.yaml index aab402b483..efe642dbb9 100644 --- a/QEfficient/finetune/experimental/tests/test_config.yaml +++ b/QEfficient/finetune/experimental/tests/test_config.yaml @@ -10,6 +10,7 @@ model: model_type: "hf" auto_class_name: "AutoModelForCausalLM" model_name: "HuggingFaceTB/SmolLM-135M" # Pretrained model name + torch_dtype: "float16" use_peft: true peft_config: lora_r: 16 @@ -41,7 +42,8 @@ training: seed: 42 device: "qaic" do_eval: True - torch_dtype: "fp16" + fp16: true + bf16: false eval_strategy: "epoch" eval_steps: 100 per_device_train_batch_size: 1 diff --git a/QEfficient/finetune/experimental/tests/test_config_manager.py b/QEfficient/finetune/experimental/tests/test_config_manager.py index cc088d0bed..3a3d0398c4 100644 --- a/QEfficient/finetune/experimental/tests/test_config_manager.py +++ b/QEfficient/finetune/experimental/tests/test_config_manager.py @@ -229,10 +229,10 @@ def test_config(config_path): def test_torch_dtype_validation(): """Test that torch_dtype validation works correctly.""" - # Test with default config - should have torch_dtype set to fp16 by default + # Test with default config - should have model torch_dtype set to float16 by default config_manager = ConfigManager() - training_config = config_manager.get_training_config() - assert training_config.get("torch_dtype") == "fp16" + model_config = config_manager.get_model_config() + assert model_config.get("torch_dtype") == "float16" # Validation should pass with default config config_manager.validate_config() # Should not raise @@ -240,11 +240,11 @@ def test_torch_dtype_validation(): def test_torch_dtype_invalid(): """Test that invalid torch_dtype raises validation error.""" - from QEfficient.finetune.experimental.core.config_manager import MasterConfig, TrainingConfig + from QEfficient.finetune.experimental.core.config_manager import MasterConfig, ModelConfig - # Create config with invalid torch_dtype - training_config = TrainingConfig(torch_dtype="invalid_dtype") - master_config = MasterConfig(training=training_config) + # Create config with invalid model torch_dtype + model_config = ModelConfig(torch_dtype="invalid_dtype") + master_config = MasterConfig(model=model_config) config_manager = ConfigManager(config=master_config) # Validation should fail @@ -252,3 +252,14 @@ def test_torch_dtype_invalid(): config_manager.validate_config() assert "torch_dtype must be one of" in str(exc_info.value) + + +def test_fp16_bf16_mutually_exclusive(): + training_config = TrainingConfig(fp16=True, bf16=True) + master_config = MasterConfig(training=training_config) + config_manager = ConfigManager(config=master_config) + + with pytest.raises(ValueError) as exc_info: + config_manager.validate_config() + + assert "training.fp16 and training.bf16 cannot both be true" in str(exc_info.value) From 7fbf12d77c7d7ba79ff8a9f066583725144796de Mon Sep 17 00:00:00 2001 From: Sharvari Medhe Date: Mon, 18 May 2026 18:45:08 +0530 Subject: [PATCH 4/9] fix(callbacks): harden qaic op-by-op verifier lifecycle and type handling Signed-off-by: Sharvari Medhe --- .../finetune/experimental/core/callbacks.py | 41 +++++-- .../experimental/tests/test_callback.py | 112 +++++++++++++++++- 2 files changed, 145 insertions(+), 8 deletions(-) diff --git a/QEfficient/finetune/experimental/core/callbacks.py b/QEfficient/finetune/experimental/core/callbacks.py index 0ad7acc40b..c6d16a1d57 100644 --- a/QEfficient/finetune/experimental/core/callbacks.py +++ b/QEfficient/finetune/experimental/core/callbacks.py @@ -15,6 +15,7 @@ from pathlib import Path from typing import Any, Dict, Optional +import torch from transformers import ( DefaultFlowCallback, EarlyStoppingCallback, @@ -388,11 +389,28 @@ def __init__(self, *args, **kwargs): """ " Initialize QAIC Op-by-Op verifier callback with profiling and tolerance settings. """ - self.start_step = kwargs.get("start_step", -1) - self.end_step = kwargs.get("end_step", -1) + try: + self.start_step = int(kwargs.get("start_step", -1)) + self.end_step = int(kwargs.get("end_step", -1)) + self.atol = float(kwargs.get("atol", 1e-1)) + self.rtol = float(kwargs.get("rtol", 1e-5)) + except (TypeError, ValueError) as e: + raise ValueError( + "qaic_op_by_op_verifier_callback expects numeric values for " + "start_step, end_step, atol, and rtol." + ) from e self.trace_dir = kwargs.get("trace_dir", "qaic_op_by_op_traces") - self.atol = kwargs.get("atol", 1e-1) - self.rtol = kwargs.get("rtol", 1e-5) + self.op_verifier_ctx_step = None + + @staticmethod + def _resolve_ref_dtype(args: TrainingArguments) -> torch.dtype: + # Keep reference dtype aligned with training precision to avoid type mismatches + # in ops that require matching input/grad dtypes (e.g. embedding backward). + if getattr(args, "bf16", False): + return torch.bfloat16 + if getattr(args, "fp16", False): + return torch.float16 + return torch.float32 def on_step_begin(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs): """ @@ -400,11 +418,18 @@ def on_step_begin(self, args: TrainingArguments, state: TrainerState, control: T several inputs. """ if self.start_step <= state.global_step < self.end_step: + if getattr(args, "fp16", False): + raise RuntimeError( + "qaic_op_by_op_verifier_callback is not supported with fp16/GradScaler training. " + "Set training.fp16=false (and optionally training.bf16=true if supported) " + "when using this callback." + ) self.op_verifier_ctx_step = get_op_verifier_ctx( use_op_by_op_verifier=True, device_type="qaic", dump_dir=self.trace_dir, step=state.global_step, + ref_dtype=self._resolve_ref_dtype(args), atol=self.atol, rtol=self.rtol, ) @@ -415,9 +440,11 @@ def on_step_end(self, args: TrainingArguments, state: TrainerState, control: Tra Event called at the end of a training step. If using gradient accumulation, one training step might take several inputs. """ - if self.start_step <= state.global_step < self.end_step: - if self.op_verifier_ctx_step is not None: - self.op_verifier_ctx_step.__exit__(None, None, None) + # Always close a previously-entered verifier context, even when the + # post-step global_step moved to end_step. + if self.op_verifier_ctx_step is not None: + self.op_verifier_ctx_step.__exit__(None, None, None) + self.op_verifier_ctx_step = None def replace_progress_callback(trainer: Any, callbacks: list[Any], logger: Any = None) -> None: diff --git a/QEfficient/finetune/experimental/tests/test_callback.py b/QEfficient/finetune/experimental/tests/test_callback.py index 854ef7382e..35a5772dfb 100644 --- a/QEfficient/finetune/experimental/tests/test_callback.py +++ b/QEfficient/finetune/experimental/tests/test_callback.py @@ -8,10 +8,11 @@ from types import SimpleNamespace import pytest +import torch from transformers import TrainerCallback from QEfficient.finetune.experimental.core import callbacks as callbacks_module -from QEfficient.finetune.experimental.core.callbacks import QAICProfilerCallback +from QEfficient.finetune.experimental.core.callbacks import QAICOpByOpVerifierCallback, QAICProfilerCallback from QEfficient.finetune.experimental.core.component_registry import ComponentFactory, registry @@ -236,3 +237,112 @@ def test_qaic_profiler_resolves_rank_at_start_time(monkeypatch): callback.on_step_begin(args=None, state=SimpleNamespace(global_step=0), control=None) assert calls == [(True, "qaic:11", "/tmp/hw-trace")] + + +def test_qaic_op_by_op_verifier_on_step_end_without_initialized_ctx(): + callback = QAICOpByOpVerifierCallback(start_step=0, end_step=5, trace_dir="/tmp/op-trace") + state = SimpleNamespace(global_step=1) + + # Should not raise when on_step_end is hit before any context is initialized. + callback.on_step_end(args=None, state=state, control=None) + + +def test_qaic_op_by_op_verifier_casts_numeric_config(monkeypatch): + captured = {} + + class _DummyCtx: + def __enter__(self): + return self + + def __exit__(self, exc_type, exc, tb): + return False + + def _mock_get_op_verifier_ctx(**kwargs): + captured.update(kwargs) + return _DummyCtx() + + monkeypatch.setattr(callbacks_module, "get_op_verifier_ctx", _mock_get_op_verifier_ctx) + + callback = QAICOpByOpVerifierCallback( + start_step="0", + end_step="2", + trace_dir="/tmp/op-trace", + atol="0.1", + rtol="1e-5", + ) + callback.on_step_begin(args=None, state=SimpleNamespace(global_step=0), control=None) + + assert isinstance(captured["atol"], float) + assert isinstance(captured["rtol"], float) + assert captured["atol"] == 0.1 + assert captured["rtol"] == 1e-5 + + +@pytest.mark.parametrize( + "args,expected_dtype", + [ + (SimpleNamespace(fp16=True, bf16=False), torch.float16), + (SimpleNamespace(fp16=False, bf16=True), torch.bfloat16), + (SimpleNamespace(fp16=False, bf16=False), torch.float32), + ], +) +def test_qaic_op_by_op_verifier_uses_training_precision_for_ref_dtype(monkeypatch, args, expected_dtype): + captured = {} + + class _DummyCtx: + def __enter__(self): + return self + + def __exit__(self, exc_type, exc, tb): + return False + + def _mock_get_op_verifier_ctx(**kwargs): + captured.update(kwargs) + return _DummyCtx() + + monkeypatch.setattr(callbacks_module, "get_op_verifier_ctx", _mock_get_op_verifier_ctx) + + callback = QAICOpByOpVerifierCallback(start_step=0, end_step=2, trace_dir="/tmp/op-trace") + callback.on_step_begin(args=args, state=SimpleNamespace(global_step=0), control=None) + + assert captured["ref_dtype"] == expected_dtype + + +def test_qaic_op_by_op_verifier_rejects_fp16_mode(): + callback = QAICOpByOpVerifierCallback(start_step=0, end_step=2, trace_dir="/tmp/op-trace") + + with pytest.raises(RuntimeError, match="not supported with fp16/GradScaler"): + callback.on_step_begin( + args=SimpleNamespace(fp16=True, bf16=False), + state=SimpleNamespace(global_step=0), + control=None, + ) + + +def test_qaic_op_by_op_verifier_exits_ctx_when_global_step_reaches_end_step(monkeypatch): + events = [] + + class _DummyCtx: + def __enter__(self): + events.append("enter") + return self + + def __exit__(self, exc_type, exc, tb): + events.append("exit") + return False + + monkeypatch.setattr(callbacks_module, "get_op_verifier_ctx", lambda **kwargs: _DummyCtx()) + + callback = QAICOpByOpVerifierCallback(start_step=0, end_step=2, trace_dir="/tmp/op-trace") + + # Enter at step 1 (still inside [start_step, end_step) ). + callback.on_step_begin( + args=SimpleNamespace(fp16=False, bf16=False), + state=SimpleNamespace(global_step=1), + control=None, + ) + # At on_step_end, HF Trainer may already report global_step == end_step. + callback.on_step_end(args=None, state=SimpleNamespace(global_step=2), control=None) + + assert events == ["enter", "exit"] + assert callback.op_verifier_ctx_step is None From 1b3f8b6ab85682b34838d618db7f0c516d8a7371 Mon Sep 17 00:00:00 2001 From: Sharvari Medhe Date: Mon, 18 May 2026 20:02:43 +0530 Subject: [PATCH 5/9] adding format fixes and saving dir logic Signed-off-by: Sharvari Medhe --- .../sft_single_device_gsm8k_config.yaml | 15 +++++--------- .../finetune/experimental/core/callbacks.py | 17 +++++++++++++--- .../experimental/core/config_manager.py | 3 +-- .../experimental/tests/test_callback.py | 20 ++++++++++++++++++- 4 files changed, 39 insertions(+), 16 deletions(-) diff --git a/QEfficient/finetune/experimental/configs/sft_single_device_gsm8k_config.yaml b/QEfficient/finetune/experimental/configs/sft_single_device_gsm8k_config.yaml index 1a42f99a00..e9c23c3662 100644 --- a/QEfficient/finetune/experimental/configs/sft_single_device_gsm8k_config.yaml +++ b/QEfficient/finetune/experimental/configs/sft_single_device_gsm8k_config.yaml @@ -9,7 +9,7 @@ model: model_type: "hf" # Hugging Face model auto_class_name: "AutoModelForCausalLM" # Auto class to load the model with model_name: "HuggingFaceTB/SmolLM-135M" # Pretrained model name - torch_dtype: "float32" + torch_dtype: "float16" use_peft: true # Enable PEFT (Parameter Efficient Fine-Tuning) peft_config: lora_r: 8 # LoRA rank @@ -36,7 +36,6 @@ dataset: completion_template: "{answer}" # Model will be trained on this part. config_name: "main" # Config name for the dataset data_seed: 42 # Random seed for dataset shuffling - dataset_num_samples: 10 # Training configuration @@ -44,9 +43,9 @@ training: type: "sft" # type of training gradient_accumulation_steps: 1 # Number of steps to accumulate gradients per_device_train_batch_size: 1 # Batch size per device during training - num_train_epochs: 3 + num_train_epochs: 1 torch_compile: False # Whether to use torch.compile - fp16: false + fp16: true bf16: false # Optimizer configuration @@ -61,9 +60,5 @@ callbacks: early_stopping: early_stopping_patience: 3 # Number of epochs to wait before stopping training early_stopping_threshold: 0.001 # Minimum change in metric to qualify as improvement - qaic_op_by_op_verifier_callback: - start_step: 1 - end_step: 5 - trace_dir: "./qaic_op_by_op_traces" - atol: 0.1 - rtol: 1e-5 + + diff --git a/QEfficient/finetune/experimental/core/callbacks.py b/QEfficient/finetune/experimental/core/callbacks.py index c6d16a1d57..673b76cca7 100644 --- a/QEfficient/finetune/experimental/core/callbacks.py +++ b/QEfficient/finetune/experimental/core/callbacks.py @@ -396,10 +396,15 @@ def __init__(self, *args, **kwargs): self.rtol = float(kwargs.get("rtol", 1e-5)) except (TypeError, ValueError) as e: raise ValueError( - "qaic_op_by_op_verifier_callback expects numeric values for " - "start_step, end_step, atol, and rtol." + "qaic_op_by_op_verifier_callback expects numeric values for start_step, end_step, atol, and rtol." ) from e - self.trace_dir = kwargs.get("trace_dir", "qaic_op_by_op_traces") + trace_dir = kwargs.get("trace_dir", "qaic_op_by_op_traces") + expanded_trace_dir = os.path.expanduser(trace_dir) + if os.path.isabs(expanded_trace_dir): + self.trace_dir = os.path.abspath(expanded_trace_dir) + else: + output_dir = os.environ.get("OUTPUT_DIR", ".") + self.trace_dir = os.path.abspath(os.path.join(output_dir, expanded_trace_dir)) self.op_verifier_ctx_step = None @staticmethod @@ -424,6 +429,12 @@ def on_step_begin(self, args: TrainingArguments, state: TrainerState, control: T "Set training.fp16=false (and optionally training.bf16=true if supported) " "when using this callback." ) + logger.log_rank_zero( + "QAIC OpByOp verifier active: " + f"step={state.global_step}, " + f"window=[{self.start_step}, {self.end_step}), " + f"dump_dir={self.trace_dir}/mismatches/step_{state.global_step}" + ) self.op_verifier_ctx_step = get_op_verifier_ctx( use_op_by_op_verifier=True, device_type="qaic", diff --git a/QEfficient/finetune/experimental/core/config_manager.py b/QEfficient/finetune/experimental/core/config_manager.py index 041a90686e..2bce40e73d 100644 --- a/QEfficient/finetune/experimental/core/config_manager.py +++ b/QEfficient/finetune/experimental/core/config_manager.py @@ -268,8 +268,7 @@ class ModelConfig: torch_dtype: Optional[str] = field( default="float16", metadata={ - "help": "Torch dtype passed to model.from_pretrained " - "(e.g., 'float16', 'bfloat16', 'float32', or 'auto')." + "help": "Torch dtype passed to model.from_pretrained (e.g., 'float16', 'bfloat16', 'float32', or 'auto')." }, ) diff --git a/QEfficient/finetune/experimental/tests/test_callback.py b/QEfficient/finetune/experimental/tests/test_callback.py index 35a5772dfb..52f7f33faa 100644 --- a/QEfficient/finetune/experimental/tests/test_callback.py +++ b/QEfficient/finetune/experimental/tests/test_callback.py @@ -5,6 +5,7 @@ # # ----------------------------------------------------------------------------- +import os from types import SimpleNamespace import pytest @@ -247,6 +248,24 @@ def test_qaic_op_by_op_verifier_on_step_end_without_initialized_ctx(): callback.on_step_end(args=None, state=state, control=None) +def test_qaic_op_by_op_verifier_default_trace_dir_is_under_output_dir(monkeypatch): + monkeypatch.setenv("OUTPUT_DIR", "/tmp/train_out") + callback = QAICOpByOpVerifierCallback(start_step=0, end_step=1) + assert callback.trace_dir == os.path.abspath("/tmp/train_out/qaic_op_by_op_traces") + + +def test_qaic_op_by_op_verifier_relative_trace_dir_is_under_output_dir(monkeypatch): + monkeypatch.setenv("OUTPUT_DIR", "/tmp/train_out") + callback = QAICOpByOpVerifierCallback(start_step=0, end_step=1, trace_dir="./custom-op-trace") + assert callback.trace_dir == os.path.abspath("/tmp/train_out/custom-op-trace") + + +def test_qaic_op_by_op_verifier_absolute_trace_dir_is_preserved(monkeypatch): + monkeypatch.setenv("OUTPUT_DIR", "/tmp/train_out") + callback = QAICOpByOpVerifierCallback(start_step=0, end_step=1, trace_dir="/var/tmp/op-trace") + assert callback.trace_dir == os.path.abspath("/var/tmp/op-trace") + + def test_qaic_op_by_op_verifier_casts_numeric_config(monkeypatch): captured = {} @@ -281,7 +300,6 @@ def _mock_get_op_verifier_ctx(**kwargs): @pytest.mark.parametrize( "args,expected_dtype", [ - (SimpleNamespace(fp16=True, bf16=False), torch.float16), (SimpleNamespace(fp16=False, bf16=True), torch.bfloat16), (SimpleNamespace(fp16=False, bf16=False), torch.float32), ], From 6d905c7bd8951cb9d1a55edd9f998993bb336027 Mon Sep 17 00:00:00 2001 From: Sharvari Medhe Date: Mon, 18 May 2026 21:16:12 +0530 Subject: [PATCH 6/9] adding test suite changes and documentation Signed-off-by: Sharvari Medhe --- .../experimental/core/config_manager.py | 16 +++++++++++++++- .../experimental/tests/test_config_manager.py | 17 ++++++++++++----- docs/source/config.md | 5 +++++ 3 files changed, 32 insertions(+), 6 deletions(-) diff --git a/QEfficient/finetune/experimental/core/config_manager.py b/QEfficient/finetune/experimental/core/config_manager.py index 2bce40e73d..74ea0d1904 100644 --- a/QEfficient/finetune/experimental/core/config_manager.py +++ b/QEfficient/finetune/experimental/core/config_manager.py @@ -556,7 +556,8 @@ def __init__( try: self.validate_config() except Exception as e: - logger.log_rank_zero(f"Config validation failed with error: {e}") + logger.log_rank_zero(f"Config validation failed with error: {e}", level=logging.ERROR) + raise def _build_cli_parser(self) -> HfArgumentParser: return HfArgumentParser( @@ -852,6 +853,19 @@ def validate_config(self) -> None: fp16 and bf16, "training.fp16 and training.bf16 cannot both be true.", ) + callbacks_cfg = getattr(cfg, "callbacks", {}) + if isinstance(callbacks_cfg, dict): + callback_dict = {} + nested_callbacks = callbacks_cfg.get("callbacks") + if isinstance(nested_callbacks, dict): + callback_dict.update(nested_callbacks) + callback_dict.update({k: v for k, v in callbacks_cfg.items() if k != "callbacks"}) + self._push( + errors, + fp16 and isinstance(callback_dict, dict) and "qaic_op_by_op_verifier_callback" in callback_dict, + "qaic_op_by_op_verifier_callback is not compatible with training.fp16=true. " + "Set training.fp16=false when using this callback.", + ) # Batch sizes self._push( diff --git a/QEfficient/finetune/experimental/tests/test_config_manager.py b/QEfficient/finetune/experimental/tests/test_config_manager.py index 3a3d0398c4..aa4d052ecf 100644 --- a/QEfficient/finetune/experimental/tests/test_config_manager.py +++ b/QEfficient/finetune/experimental/tests/test_config_manager.py @@ -245,11 +245,8 @@ def test_torch_dtype_invalid(): # Create config with invalid model torch_dtype model_config = ModelConfig(torch_dtype="invalid_dtype") master_config = MasterConfig(model=model_config) - config_manager = ConfigManager(config=master_config) - - # Validation should fail with pytest.raises(ValueError) as exc_info: - config_manager.validate_config() + ConfigManager(config=master_config) assert "torch_dtype must be one of" in str(exc_info.value) @@ -257,9 +254,19 @@ def test_torch_dtype_invalid(): def test_fp16_bf16_mutually_exclusive(): training_config = TrainingConfig(fp16=True, bf16=True) master_config = MasterConfig(training=training_config) + with pytest.raises(ValueError) as exc_info: + ConfigManager(config=master_config) + + assert "training.fp16 and training.bf16 cannot both be true" in str(exc_info.value) + + +def test_qaic_op_by_op_verifier_disallowed_with_fp16(): + training_config = TrainingConfig(fp16=True, bf16=False) + master_config = MasterConfig(training=training_config) config_manager = ConfigManager(config=master_config) + config_manager.update_config({"callbacks": {"qaic_op_by_op_verifier_callback": {"start_step": 0, "end_step": 1}}}) with pytest.raises(ValueError) as exc_info: config_manager.validate_config() - assert "training.fp16 and training.bf16 cannot both be true" in str(exc_info.value) + assert "qaic_op_by_op_verifier_callback is not compatible with training.fp16=true" in str(exc_info.value) diff --git a/docs/source/config.md b/docs/source/config.md index 9fc9ecf554..51abbacd8d 100644 --- a/docs/source/config.md +++ b/docs/source/config.md @@ -258,6 +258,11 @@ Callbacks allow custom actions during training, such as logging, early stopping, | `QAICProfilerCallback` | Profiles QAIC devices over a specified training step range. | | `QAICOpByOpVerifierCallback` | Verifies QAIC operations step-by-step for correctness and debugging. | +**QAIC callback usage recommendations** + +- Use `qaic_profiler_callback` for only `1-3` steps. +- Use `qaic_op_by_op_verifier_callback` with `training.fp16: false` and `model.torch_dtype: fp32`, for only `1-3` steps. + **References to some commonly used Hugging Face callbacks**: https://huggingface.co/docs/transformers/en/main_classes/callback *** From 0b7a723fa829e0989958f747e79731d728d58405 Mon Sep 17 00:00:00 2001 From: Sharvari Medhe Date: Tue, 19 May 2026 15:25:12 +0530 Subject: [PATCH 7/9] adding test case changes in callbacks Signed-off-by: Sharvari Medhe --- .../experimental/core/config_manager.py | 1 - .../experimental/tests/test_callback.py | 83 ++++++++++++------- .../experimental/tests/test_config_manager.py | 44 ++++++++-- docs/source/config.md | 9 +- 4 files changed, 98 insertions(+), 39 deletions(-) diff --git a/QEfficient/finetune/experimental/core/config_manager.py b/QEfficient/finetune/experimental/core/config_manager.py index 74ea0d1904..f9c6add9e8 100644 --- a/QEfficient/finetune/experimental/core/config_manager.py +++ b/QEfficient/finetune/experimental/core/config_manager.py @@ -557,7 +557,6 @@ def __init__( self.validate_config() except Exception as e: logger.log_rank_zero(f"Config validation failed with error: {e}", level=logging.ERROR) - raise def _build_cli_parser(self) -> HfArgumentParser: return HfArgumentParser( diff --git a/QEfficient/finetune/experimental/tests/test_callback.py b/QEfficient/finetune/experimental/tests/test_callback.py index 52f7f33faa..63b2883a15 100644 --- a/QEfficient/finetune/experimental/tests/test_callback.py +++ b/QEfficient/finetune/experimental/tests/test_callback.py @@ -6,6 +6,8 @@ # ----------------------------------------------------------------------------- import os +import shutil +from pathlib import Path from types import SimpleNamespace import pytest @@ -16,6 +18,13 @@ from QEfficient.finetune.experimental.core.callbacks import QAICOpByOpVerifierCallback, QAICProfilerCallback from QEfficient.finetune.experimental.core.component_registry import ComponentFactory, registry +PROJECT_ROOT = Path(__file__).resolve().parents[4] +OUTPUT_DIR = PROJECT_ROOT / "training_results" +QAIC_PROFILER_TRACE_DIR = OUTPUT_DIR / "hw-trace" +QAIC_OP_TRACE_DIR = OUTPUT_DIR / "qaic_op_by_op_traces" +QAIC_CUSTOM_OP_TRACE_DIR = OUTPUT_DIR / "custom-op-trace" +QAIC_ABSOLUTE_OP_TRACE_DIR = OUTPUT_DIR / "absolute-op-trace" + class ModelSummaryCallback(TrainerCallback): def __init__(self): @@ -45,7 +54,7 @@ def __init__(self): "name": "qaic_profiler_callback", "start_step": 0, "end_step": 1, - "trace_dir": "/tmp/hw-trace", + "trace_dir": str(QAIC_PROFILER_TRACE_DIR), "device_ids": [0], }, id="qaic_profiler", @@ -61,6 +70,22 @@ def __init__(self): } +@pytest.fixture(autouse=True, scope="module") +def setup_training_output_dir(): + prev_output_dir = os.environ.get("OUTPUT_DIR") + shutil.rmtree(OUTPUT_DIR, ignore_errors=True) + OUTPUT_DIR.mkdir(parents=True, exist_ok=True) + os.environ["OUTPUT_DIR"] = str(OUTPUT_DIR) + try: + yield + finally: + if prev_output_dir is None: + os.environ.pop("OUTPUT_DIR", None) + else: + os.environ["OUTPUT_DIR"] = prev_output_dir + shutil.rmtree(OUTPUT_DIR, ignore_errors=True) + + @pytest.mark.parametrize("config", CALLBACK_CONFIGS) def test_callbacks(config): """Test that registered callbacks that can be created with their configs.""" @@ -84,8 +109,8 @@ def test_callbacks_registery(callback_name, callback_class): def test_qaic_profiler_uses_user_trace_dir(): - callback = QAICProfilerCallback(trace_dir="~/my_custom_hw_trace") - assert callback.trace_dir == "~/my_custom_hw_trace" + callback = QAICProfilerCallback(trace_dir=str(QAIC_PROFILER_TRACE_DIR)) + assert callback.trace_dir == str(QAIC_PROFILER_TRACE_DIR) def test_qaic_profiler_starts_with_trace_dir(monkeypatch): @@ -99,13 +124,13 @@ def _mock_start(use_profiler, device_type, trace_dir=None): monkeypatch.setattr(callbacks_module, "init_qaic_profiling", _mock_start) - callback = QAICProfilerCallback(start_step=3, end_step=9, trace_dir="/tmp/hw-trace", device_ids=[2]) + callback = QAICProfilerCallback(start_step=3, end_step=9, trace_dir=str(QAIC_PROFILER_TRACE_DIR), device_ids=[2]) state = SimpleNamespace(global_step=3) callback.on_step_begin(args=None, state=state, control=None) assert callback._profile_started is True - assert calls == [(True, "qaic:2", "/tmp/hw-trace")] + assert calls == [(True, "qaic:2", str(QAIC_PROFILER_TRACE_DIR))] def test_qaic_profiler_stops_once_at_end_step(monkeypatch): @@ -125,7 +150,7 @@ def test_qaic_profiler_stops_once_at_end_step(monkeypatch): lambda use_profiler, device_type: stop_calls.append((use_profiler, device_type)), ) - callback = QAICProfilerCallback(start_step=1, end_step=2, trace_dir="/tmp/hw-trace", device_ids=[0]) + callback = QAICProfilerCallback(start_step=1, end_step=2, trace_dir=str(QAIC_PROFILER_TRACE_DIR), device_ids=[0]) callback.on_step_begin(args=None, state=SimpleNamespace(global_step=1), control=None) callback.on_step_end(args=None, state=SimpleNamespace(global_step=2), control=None) @@ -166,10 +191,10 @@ def test_qaic_profiler_uses_local_rank_when_device_ids_not_set(monkeypatch): lambda use_profiler, device_type, trace_dir=None: calls.append((use_profiler, device_type, trace_dir)), ) - callback = QAICProfilerCallback(start_step=0, trace_dir="/tmp/hw-trace") + callback = QAICProfilerCallback(start_step=0, trace_dir=str(QAIC_PROFILER_TRACE_DIR)) callback.on_step_begin(args=None, state=SimpleNamespace(global_step=0), control=None) - assert calls == [(True, "qaic:3", "/tmp/hw-trace")] + assert calls == [(True, "qaic:3", str(QAIC_PROFILER_TRACE_DIR))] def test_qaic_profiler_maps_rank_to_device_id(monkeypatch): @@ -183,10 +208,10 @@ def test_qaic_profiler_maps_rank_to_device_id(monkeypatch): lambda use_profiler, device_type, trace_dir=None: calls.append((use_profiler, device_type, trace_dir)), ) - callback = QAICProfilerCallback(start_step=5, trace_dir="/tmp/hw-trace", device_ids=[10, 11]) + callback = QAICProfilerCallback(start_step=5, trace_dir=str(QAIC_PROFILER_TRACE_DIR), device_ids=[10, 11]) callback.on_step_begin(args=None, state=SimpleNamespace(global_step=5), control=None) - assert calls == [(True, "qaic:11", "/tmp/hw-trace")] + assert calls == [(True, "qaic:11", str(QAIC_PROFILER_TRACE_DIR))] def test_qaic_profiler_invalid_step_range_raises(): @@ -213,11 +238,14 @@ def _mock_start(use_profiler, device_type, trace_dir=None): lambda use_profiler, device_type: stop_calls.append((use_profiler, device_type)), ) - callback = QAICProfilerCallback(start_step=0, end_step=1, trace_dir="/tmp/hw-trace", device_ids=[0, 1]) + callback = QAICProfilerCallback(start_step=0, end_step=1, trace_dir=str(QAIC_PROFILER_TRACE_DIR), device_ids=[0, 1]) callback.on_step_begin(args=None, state=SimpleNamespace(global_step=0), control=None) callback.on_step_end(args=None, state=SimpleNamespace(global_step=1), control=None) - assert start_calls == [(True, "qaic:0", "/tmp/hw-trace"), (True, "qaic:1", "/tmp/hw-trace")] + assert start_calls == [ + (True, "qaic:0", str(QAIC_PROFILER_TRACE_DIR)), + (True, "qaic:1", str(QAIC_PROFILER_TRACE_DIR)), + ] assert stop_calls == [(True, "qaic:0")] @@ -233,37 +261,34 @@ def test_qaic_profiler_resolves_rank_at_start_time(monkeypatch): lambda use_profiler, device_type, trace_dir=None: calls.append((use_profiler, device_type, trace_dir)), ) - callback = QAICProfilerCallback(start_step=0, trace_dir="/tmp/hw-trace", device_ids=[10, 11]) + callback = QAICProfilerCallback(start_step=0, trace_dir=str(QAIC_PROFILER_TRACE_DIR), device_ids=[10, 11]) rank_state["local_rank"] = 1 callback.on_step_begin(args=None, state=SimpleNamespace(global_step=0), control=None) - assert calls == [(True, "qaic:11", "/tmp/hw-trace")] + assert calls == [(True, "qaic:11", str(QAIC_PROFILER_TRACE_DIR))] def test_qaic_op_by_op_verifier_on_step_end_without_initialized_ctx(): - callback = QAICOpByOpVerifierCallback(start_step=0, end_step=5, trace_dir="/tmp/op-trace") + callback = QAICOpByOpVerifierCallback(start_step=0, end_step=5, trace_dir=str(QAIC_OP_TRACE_DIR)) state = SimpleNamespace(global_step=1) # Should not raise when on_step_end is hit before any context is initialized. callback.on_step_end(args=None, state=state, control=None) -def test_qaic_op_by_op_verifier_default_trace_dir_is_under_output_dir(monkeypatch): - monkeypatch.setenv("OUTPUT_DIR", "/tmp/train_out") +def test_qaic_op_by_op_verifier_default_trace_dir_is_under_output_dir(): callback = QAICOpByOpVerifierCallback(start_step=0, end_step=1) - assert callback.trace_dir == os.path.abspath("/tmp/train_out/qaic_op_by_op_traces") + assert callback.trace_dir == os.path.abspath(str(OUTPUT_DIR / "qaic_op_by_op_traces")) -def test_qaic_op_by_op_verifier_relative_trace_dir_is_under_output_dir(monkeypatch): - monkeypatch.setenv("OUTPUT_DIR", "/tmp/train_out") +def test_qaic_op_by_op_verifier_relative_trace_dir_is_under_output_dir(): callback = QAICOpByOpVerifierCallback(start_step=0, end_step=1, trace_dir="./custom-op-trace") - assert callback.trace_dir == os.path.abspath("/tmp/train_out/custom-op-trace") + assert callback.trace_dir == os.path.abspath(str(QAIC_CUSTOM_OP_TRACE_DIR)) -def test_qaic_op_by_op_verifier_absolute_trace_dir_is_preserved(monkeypatch): - monkeypatch.setenv("OUTPUT_DIR", "/tmp/train_out") - callback = QAICOpByOpVerifierCallback(start_step=0, end_step=1, trace_dir="/var/tmp/op-trace") - assert callback.trace_dir == os.path.abspath("/var/tmp/op-trace") +def test_qaic_op_by_op_verifier_absolute_trace_dir_is_preserved(): + callback = QAICOpByOpVerifierCallback(start_step=0, end_step=1, trace_dir=str(QAIC_ABSOLUTE_OP_TRACE_DIR)) + assert callback.trace_dir == os.path.abspath(str(QAIC_ABSOLUTE_OP_TRACE_DIR)) def test_qaic_op_by_op_verifier_casts_numeric_config(monkeypatch): @@ -285,7 +310,7 @@ def _mock_get_op_verifier_ctx(**kwargs): callback = QAICOpByOpVerifierCallback( start_step="0", end_step="2", - trace_dir="/tmp/op-trace", + trace_dir=str(QAIC_OP_TRACE_DIR), atol="0.1", rtol="1e-5", ) @@ -320,14 +345,14 @@ def _mock_get_op_verifier_ctx(**kwargs): monkeypatch.setattr(callbacks_module, "get_op_verifier_ctx", _mock_get_op_verifier_ctx) - callback = QAICOpByOpVerifierCallback(start_step=0, end_step=2, trace_dir="/tmp/op-trace") + callback = QAICOpByOpVerifierCallback(start_step=0, end_step=2, trace_dir=str(QAIC_OP_TRACE_DIR)) callback.on_step_begin(args=args, state=SimpleNamespace(global_step=0), control=None) assert captured["ref_dtype"] == expected_dtype def test_qaic_op_by_op_verifier_rejects_fp16_mode(): - callback = QAICOpByOpVerifierCallback(start_step=0, end_step=2, trace_dir="/tmp/op-trace") + callback = QAICOpByOpVerifierCallback(start_step=0, end_step=2, trace_dir=str(QAIC_OP_TRACE_DIR)) with pytest.raises(RuntimeError, match="not supported with fp16/GradScaler"): callback.on_step_begin( @@ -351,7 +376,7 @@ def __exit__(self, exc_type, exc, tb): monkeypatch.setattr(callbacks_module, "get_op_verifier_ctx", lambda **kwargs: _DummyCtx()) - callback = QAICOpByOpVerifierCallback(start_step=0, end_step=2, trace_dir="/tmp/op-trace") + callback = QAICOpByOpVerifierCallback(start_step=0, end_step=2, trace_dir=str(QAIC_OP_TRACE_DIR)) # Enter at step 1 (still inside [start_step, end_step) ). callback.on_step_begin( diff --git a/QEfficient/finetune/experimental/tests/test_config_manager.py b/QEfficient/finetune/experimental/tests/test_config_manager.py index aa4d052ecf..f5c3cb00c3 100644 --- a/QEfficient/finetune/experimental/tests/test_config_manager.py +++ b/QEfficient/finetune/experimental/tests/test_config_manager.py @@ -238,26 +238,54 @@ def test_torch_dtype_validation(): config_manager.validate_config() # Should not raise -def test_torch_dtype_invalid(): - """Test that invalid torch_dtype raises validation error.""" +def test_torch_dtype_invalid(monkeypatch): + """Test that invalid torch_dtype is reported via exception or logged validation failure.""" + from QEfficient.finetune.experimental.core import config_manager as config_manager_module from QEfficient.finetune.experimental.core.config_manager import MasterConfig, ModelConfig + captured_logs = [] + + def _capture_log(message, level=None): + captured_logs.append((str(message), level)) + + monkeypatch.setattr(config_manager_module.logger, "log_rank_zero", _capture_log) + # Create config with invalid model torch_dtype model_config = ModelConfig(torch_dtype="invalid_dtype") master_config = MasterConfig(model=model_config) - with pytest.raises(ValueError) as exc_info: + try: ConfigManager(config=master_config) + except ValueError as exc_info: + assert "torch_dtype must be one of" in str(exc_info) + return + + assert any( + "Config validation failed with error" in msg and "torch_dtype must be one of" in msg for msg, _ in captured_logs + ), "Expected torch_dtype validation failure to be logged when ConfigManager does not raise." - assert "torch_dtype must be one of" in str(exc_info.value) +def test_fp16_bf16_mutually_exclusive(monkeypatch): + from QEfficient.finetune.experimental.core import config_manager as config_manager_module + + captured_logs = [] + + def _capture_log(message, level=None): + captured_logs.append((str(message), level)) + + monkeypatch.setattr(config_manager_module.logger, "log_rank_zero", _capture_log) -def test_fp16_bf16_mutually_exclusive(): training_config = TrainingConfig(fp16=True, bf16=True) master_config = MasterConfig(training=training_config) - with pytest.raises(ValueError) as exc_info: + try: ConfigManager(config=master_config) - - assert "training.fp16 and training.bf16 cannot both be true" in str(exc_info.value) + except ValueError as exc_info: + assert "training.fp16 and training.bf16 cannot both be true" in str(exc_info) + return + + assert any( + "Config validation failed with error" in msg and "training.fp16 and training.bf16 cannot both be true" in msg + for msg, _ in captured_logs + ), "Expected fp16/bf16 mutual-exclusion validation failure to be logged when ConfigManager does not raise." def test_qaic_op_by_op_verifier_disallowed_with_fp16(): diff --git a/docs/source/config.md b/docs/source/config.md index 51abbacd8d..1545710cba 100644 --- a/docs/source/config.md +++ b/docs/source/config.md @@ -18,6 +18,7 @@ Model-related parameters for loading and fine-tuning. | `use_cache` | `false` | Uses the past key/values cache for faster decoding during generation. | | `attn_implementation` | `"sdpa"` | Attention implementation. Common values: `sdpa`, `eager`. | | `device_map` | `None` | Specifies how to distribute the model across devices. | +| `torch_dtype` | `float16` | Torch dtype passed to `model.from_pretrained` (for example `float16`, `bfloat16`, `float32`, `auto`, `fp16`, `bf16`, `fp32`). | | `use_peft` | `true` | Enables PEFT for parameter-efficient fine-tuning. | | `peft_config` | - | Defines LoRA parameters when `use_peft` is true. | @@ -173,7 +174,8 @@ This section defines core parameters for fine-tuning and evaluation. | `do_eval` | `true` | Enables evaluation during training. | | `eval_strategy` | `epoch` | When to run evaluation. | | `gradient_accumulation_steps` | `1` | Accumulates gradients over multiple steps. | -| `dtype` | `fp16` | Mixed precision setting. | +| `fp16` | `true` | Enables fp16 mixed precision in training arguments. | +| `bf16` | `false` | Enables bf16 mixed precision in training arguments. | | `seed` | `42` | Random seed for reproducibility. | | `device` | `qaic` | Device to use for training. | | `per_device_train_batch_size` | `1` | Batch size per device during training. | @@ -210,6 +212,11 @@ This section defines core parameters for fine-tuning and evaluation. Optional distributed configs: FSDP, DeepSpeed, or DDP for multi-QAIC or large-scale training. +Precision note: +- `model.torch_dtype` controls model loading dtype. +- `training.fp16` and `training.bf16` control training-time mixed precision flags. +- `training.fp16` and `training.bf16` cannot both be `true`. + 📁 **Output Directory Structure** output_dir/ From 5f003a4a6e8fee3d3f13e55ff877f65538d251e8 Mon Sep 17 00:00:00 2001 From: Sharvari Medhe Date: Tue, 19 May 2026 17:46:56 +0530 Subject: [PATCH 8/9] validate_config will now raise error if the config is not validated Signed-off-by: Sharvari Medhe --- QEfficient/finetune/experimental/core/config_manager.py | 9 +++------ .../finetune/experimental/tests/test_integrated.py | 6 ++++++ 2 files changed, 9 insertions(+), 6 deletions(-) diff --git a/QEfficient/finetune/experimental/core/config_manager.py b/QEfficient/finetune/experimental/core/config_manager.py index f9c6add9e8..7a86ff452f 100644 --- a/QEfficient/finetune/experimental/core/config_manager.py +++ b/QEfficient/finetune/experimental/core/config_manager.py @@ -552,11 +552,8 @@ def __init__( logger.log_rank_zero("Using default configuration...") self.config = asdict(self.config) self.config = MasterConfig(**self.config) - # Validate loaded config - try: - self.validate_config() - except Exception as e: - logger.log_rank_zero(f"Config validation failed with error: {e}", level=logging.ERROR) + + self.validate_config() def _build_cli_parser(self) -> HfArgumentParser: return HfArgumentParser( @@ -788,7 +785,7 @@ def validate_config(self) -> None: self._push(errors, not model.get("model_name"), "model.model_name is required.") # Device valid_devices = ["cpu", "cuda", "qaic"] - training_device = model.get("device", "qaic") + training_device = training.get("device", "qaic") if training_device not in valid_devices: self._push(errors, training_device not in valid_devices, f"training.device must be one of {valid_devices}.") if training_device == "qaic": diff --git a/QEfficient/finetune/experimental/tests/test_integrated.py b/QEfficient/finetune/experimental/tests/test_integrated.py index 207cf458e6..c23029dda0 100644 --- a/QEfficient/finetune/experimental/tests/test_integrated.py +++ b/QEfficient/finetune/experimental/tests/test_integrated.py @@ -22,6 +22,7 @@ import torch from QEfficient.cloud.finetune_experimental import FineTuningPipeline +from QEfficient.finetune.experimental.core import config_manager as config_manager_module from QEfficient.finetune.experimental.core.config_manager import ( ConfigManager, DatasetConfig, @@ -73,6 +74,11 @@ # ============================================================================ +@pytest.fixture(autouse=True) +def bypass_nsp_free_check(monkeypatch): + monkeypatch.setattr(config_manager_module, "is_nsp_free", lambda: None) + + def clean_up(path): if os.path.isdir(path) and os.path.exists(path): shutil.rmtree(path) From 167148feb345841f9b932d5221dafe94f7aae63c Mon Sep 17 00:00:00 2001 From: Sharvari Medhe Date: Wed, 20 May 2026 10:51:25 +0530 Subject: [PATCH 9/9] removing redundant checks from config manager Signed-off-by: Sharvari Medhe --- QEfficient/finetune/experimental/core/config_manager.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/QEfficient/finetune/experimental/core/config_manager.py b/QEfficient/finetune/experimental/core/config_manager.py index 7a86ff452f..3063d207c9 100644 --- a/QEfficient/finetune/experimental/core/config_manager.py +++ b/QEfficient/finetune/experimental/core/config_manager.py @@ -786,8 +786,7 @@ def validate_config(self) -> None: # Device valid_devices = ["cpu", "cuda", "qaic"] training_device = training.get("device", "qaic") - if training_device not in valid_devices: - self._push(errors, training_device not in valid_devices, f"training.device must be one of {valid_devices}.") + self._push(errors, training_device not in valid_devices, f"training.device must be one of {valid_devices}.") if training_device == "qaic": try: import torch_qaic # noqa: F401