From 68f1d73afbcde06b4206782c60612da0ae61d391 Mon Sep 17 00:00:00 2001 From: Alex Hunt Date: Thu, 7 May 2026 10:56:19 +0100 Subject: [PATCH 1/2] exllamav3 backend: support MTP draft model arch override MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Adds draft_arch_override and num_draft_tokens to DraftModelConfig so Qwen3.5/3.6 BF16 directories can be loaded as MTP-only draft models (arch_override="Qwen3_5MTPDraftModel"). Threads both options through to Config.from_directory and AsyncGenerator. If draft_arch_override is set but draft_model_name is omitted, treat the main model_directory as the source for the draft model. This covers the case where the same checkpoint contains both the trunk and the mtp.* tensors — no need to point at a separate directory or extract the MTP head into its own dir. --- backends/exllamav3/model.py | 39 +++++++++++++++++++++++++++++-------- common/config_models.py | 16 ++++++++++++++- 2 files changed, 46 insertions(+), 9 deletions(-) diff --git a/backends/exllamav3/model.py b/backends/exllamav3/model.py index f6303a8f..26573d29 100644 --- a/backends/exllamav3/model.py +++ b/backends/exllamav3/model.py @@ -156,26 +156,49 @@ async def create(cls, model_directory: pathlib.Path, hf_model: HFModel, **kwargs # Prepare the draft model config if necessary draft_args = unwrap(kwargs.get("draft_model"), {}) draft_model_name = draft_args.get("draft_model_name") - self.use_draft_model = draft_args and draft_model_name + draft_arch_override = draft_args.get("draft_arch_override") + self._draft_args = draft_args - # Always disable draft if params are incorrectly configured - if draft_args and draft_model_name is None: + # Two ways to enable a draft model: + # 1) Separate dir+name (regular draft, any arch). + # 2) MTP head loaded from the main model's checkpoint: set draft_arch_override + # (e.g. "Qwen3_5MTPDraftModel") and leave draft_model_name unset. + self.use_draft_model = bool(draft_args) and bool(draft_model_name or draft_arch_override) + + # Misconfiguration: draft section present but no way to locate weights + if draft_args and not draft_model_name and not draft_arch_override: xlogger.warning( - "Draft model is disabled because a model name " - "wasn't provided. Please check your config.yml!" + "Draft model section is set but neither draft_model_name nor " + "draft_arch_override is provided. Disabling draft model. " + "Set draft_model_name to load from a separate directory, or " + "draft_arch_override (e.g. Qwen3_5MTPDraftModel) to load the " + "MTP head from the main model directory." ) self.use_draft_model = False if self.use_draft_model: - draft_model_path = pathlib.Path(unwrap(draft_args.get("draft_model_dir"), "models")) - draft_model_path = draft_model_path / draft_model_name + if draft_model_name: + # Separate draft model directory + draft_model_path = pathlib.Path(unwrap(draft_args.get("draft_model_dir"), "models")) + draft_model_path = draft_model_path / draft_model_name + else: + # MTP from the same dir as the main model — checkpoint has both trunk and + # mtp.* tensors; arch_override picks just the MTP weights via Qwen3_5MTPDraftConfig. + draft_model_path = model_directory + xlogger.info("Loading draft model from main model directory (self-spec)") + self.draft_gpu_split = unwrap(draft_args.get("draft_gpu_split"), []) self.draft_model_dir = draft_model_path - self.draft_config = Config.from_directory(str(draft_model_path.resolve())) + self.draft_config = Config.from_directory( + str(draft_model_path.resolve()), + arch_override=draft_arch_override, + ) self.draft_model = Model.from_config(self.draft_config) default_ndt = self.draft_model.caps.get("default_draft_size", 4) self.draft_num_tokens = draft_args.get("draft_num_tokens", default_ndt) xlogger.info(f"Using draft model: {str(draft_model_path.resolve())}") + if draft_arch_override: + xlogger.info(f"Draft arch override: {draft_arch_override}") else: self.draft_model = None self.draft_cache = None diff --git a/common/config_models.py b/common/config_models.py index 6d372757..4ed325db 100644 --- a/common/config_models.py +++ b/common/config_models.py @@ -357,7 +357,9 @@ class DraftModelConfig(BaseConfigModel): draft_model_name: Optional[str] = Field( None, description=( - "An initial draft model to load.\nEnsure the model is in the model directory." + "An initial draft model to load.\nEnsure the model is in the model directory.\n" + "Leave blank when using draft_arch_override to load an MTP head from the\n" + "main model directory (self-spec)." ), ) draft_rope_scale: Optional[float] = Field( @@ -400,6 +402,18 @@ class DraftModelConfig(BaseConfigModel): "(e.g. DFlash with 15 tokens by default) shorter drafts may be preferable." ), ) + draft_arch_override: Optional[str] = Field( + None, + description=( + "Override the architecture string read from the draft model's config.json.\n" + "Use 'Qwen3_5MTPDraftModel' to load only the MTP head. Two ways:\n" + " - With draft_model_name: load the MTP head from a separate directory\n" + " (e.g. point draft_model_name at the original BF16 Qwen3.6 repo).\n" + " - Without draft_model_name: load the MTP head from the SAME directory\n" + " as the main model, when that checkpoint already contains the mtp.*\n" + " tensors alongside the regular trunk weights." + ), + ) class SamplingConfig(BaseConfigModel): From 548c570344a58d1fe776d8027eb835a7d85d83af Mon Sep 17 00:00:00 2001 From: Alex Hunt Date: Fri, 8 May 2026 15:31:55 +0100 Subject: [PATCH 2/2] log draft acceptance rate metrics --- backends/exllamav2/model.py | 2 ++ backends/exllamav3/model.py | 2 ++ common/gen_logging.py | 10 ++++++++++ 3 files changed, 14 insertions(+) diff --git a/backends/exllamav2/model.py b/backends/exllamav2/model.py index ea0ecb4f..4574c29e 100644 --- a/backends/exllamav2/model.py +++ b/backends/exllamav2/model.py @@ -1228,6 +1228,8 @@ def handle_finish_chunk(self, result: dict, request_id: str, full_text: str): "finish_reason": finish_reason, "stop_str": stop_str, "full_text": full_text, + "accepted_draft_tokens": result.get("accepted_draft_tokens", 0), + "rejected_draft_tokens": result.get("rejected_draft_tokens", 0), } return finish_chunk diff --git a/backends/exllamav3/model.py b/backends/exllamav3/model.py index 26573d29..5067d690 100644 --- a/backends/exllamav3/model.py +++ b/backends/exllamav3/model.py @@ -949,6 +949,8 @@ def handle_finish_chunk(self, result: dict, request_id: str, full_text: str): "finish_reason": finish_reason, "stop_str": stop_str, "full_text": full_text, + "accepted_draft_tokens": result.get("accepted_draft_tokens", 0), + "rejected_draft_tokens": result.get("rejected_draft_tokens", 0), } return finish_chunk diff --git a/common/gen_logging.py b/common/gen_logging.py index 8aa54dc3..f0b00fc5 100644 --- a/common/gen_logging.py +++ b/common/gen_logging.py @@ -82,6 +82,16 @@ def log_metrics( itemization.append(f"Generate: {metrics.get('gen_tokens_per_sec')} T/s") + # Add draft token acceptance rate if available + accepted_draft = metrics.get("accepted_draft_tokens", 0) + rejected_draft = metrics.get("rejected_draft_tokens", 0) + total_draft = accepted_draft + rejected_draft + if total_draft > 0: + acceptance_rate = round(accepted_draft / total_draft * 100, 1) + itemization.append( + f"Draft: {accepted_draft}/{total_draft} accepted ({acceptance_rate}% acceptance)" + ) + # Add context (original token count) if context_len: itemization.append(f"Context: {context_len} tokens")