From 189c4ed82691e05347035a93d5e561112d44be33 Mon Sep 17 00:00:00 2001 From: Dongji Gao Date: Tue, 17 Mar 2026 17:24:38 -0700 Subject: [PATCH 1/3] Add vLLM NeMo Speech LM plugin for multimodal inference Add vLLM plugin that enables fast inference for NeMo Speech LM models (encoder + projection + LLM) via vLLM's PagedAttention and continuous batching engine. Files: - nemo/collections/speechlm2/vllm/__init__.py: package marker - nemo/collections/speechlm2/vllm/nemotron_v3/__init__.py: plugin registration (config, model, NemotronH patch) - nemo/collections/speechlm2/vllm/nemotron_v3/config.py: NeMoSpeechLMConfig (wraps text_config from LLM backbone) - nemo/collections/speechlm2/vllm/nemotron_v3/model.py: NeMoSpeechLMForConditionalGeneration (NeMo encoder + vLLM LLM) - pyproject.toml: register vllm.general_plugins entry point Validated on Open ASR Leaderboard (8 datasets, 82K samples): WER matches NeMo checkpoint within 0.1%. Signed-off-by: Dongji Gao --- nemo/collections/speechlm2/vllm/__init__.py | 0 .../speechlm2/vllm/nemotron_v3/__init__.py | 37 ++ .../speechlm2/vllm/nemotron_v3/config.py | 75 +++ .../speechlm2/vllm/nemotron_v3/model.py | 467 ++++++++++++++++++ pyproject.toml | 3 + 5 files changed, 582 insertions(+) create mode 100644 nemo/collections/speechlm2/vllm/__init__.py create mode 100644 nemo/collections/speechlm2/vllm/nemotron_v3/__init__.py create mode 100644 nemo/collections/speechlm2/vllm/nemotron_v3/config.py create mode 100644 nemo/collections/speechlm2/vllm/nemotron_v3/model.py diff --git a/nemo/collections/speechlm2/vllm/__init__.py b/nemo/collections/speechlm2/vllm/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/nemo/collections/speechlm2/vllm/nemotron_v3/__init__.py b/nemo/collections/speechlm2/vllm/nemotron_v3/__init__.py new file mode 100644 index 000000000000..83439e272a66 --- /dev/null +++ b/nemo/collections/speechlm2/vllm/nemotron_v3/__init__.py @@ -0,0 +1,37 @@ +_PKG = "nemo.collections.speechlm2.vllm.nemotron_v3" + + +def register(): + from nemo.collections.speechlm2.vllm.nemotron_v3.config import NeMoSpeechLMConfig + + from transformers import AutoConfig + AutoConfig.register("nemo_speechlm", NeMoSpeechLMConfig) + + from vllm.transformers_utils.config import _CONFIG_REGISTRY + _CONFIG_REGISTRY["nemo_speechlm"] = NeMoSpeechLMConfig + + from vllm.model_executor.models.registry import ModelRegistry + ModelRegistry.register_model( + "NeMoSpeechLMForConditionalGeneration", + f"{_PKG}.model:NeMoSpeechLMForConditionalGeneration", + ) + + try: + from transformers import AutoConfig as _AC + _nhc = _AC.from_pretrained( + "nvidia/NVIDIA-Nemotron-3-Nano-30B-A3B-BF16", + trust_remote_code=True, + ) + NHConfigCls = type(_nhc) + _orig_getattr = getattr(NHConfigCls, '__getattr__', None) + + def _patched_getattr(self, name): + if name == 'rms_norm_eps': + return getattr(self, 'layer_norm_epsilon', 1e-5) + if _orig_getattr: + return _orig_getattr(self, name) + raise AttributeError(name) + + NHConfigCls.__getattr__ = _patched_getattr + except Exception: + pass diff --git a/nemo/collections/speechlm2/vllm/nemotron_v3/config.py b/nemo/collections/speechlm2/vllm/nemotron_v3/config.py new file mode 100644 index 000000000000..ac401ddb803d --- /dev/null +++ b/nemo/collections/speechlm2/vllm/nemotron_v3/config.py @@ -0,0 +1,75 @@ +"""Configuration for NeMo Speech LM models in vLLM. + +Supports any combination of NeMo speech encoder + LLM backbone. +The checkpoint config.json defines which components to use. +""" + +from transformers import AutoConfig, PretrainedConfig + + +class NeMoSpeechLMConfig(PretrainedConfig): + model_type = "nemo_speechlm" + + def __init__( + self, + perception: dict | None = None, + pretrained_llm: str = "nvidia/NVIDIA-Nemotron-3-Nano-30B-A3B-BF16", + pretrained_asr: str = "nvidia/canary-1b-v2", + audio_locator_tag: str = "<|audio|>", + prompt_format: str = "nemotron-nano-v3", + pretrained_weights: bool = True, + **kwargs, + ): + super().__init__(**kwargs) + self.perception = perception or {} + self.pretrained_llm = pretrained_llm + self.pretrained_asr = pretrained_asr + self.audio_locator_tag = audio_locator_tag + self.prompt_format = prompt_format + self.pretrained_weights = pretrained_weights + + self.text_config = AutoConfig.from_pretrained( + pretrained_llm, trust_remote_code=True + ) + self.text_config.architectures = ["NemotronHForCausalLM"] + + if not hasattr(self.text_config, "total_num_kv_heads") or \ + self.text_config.total_num_kv_heads is None: + self.text_config.total_num_kv_heads = getattr( + self.text_config, "num_key_value_heads", 2 + ) + + if not hasattr(self.text_config, "rms_norm_eps"): + self.text_config.rms_norm_eps = getattr( + self.text_config, "layer_norm_epsilon", 1e-5 + ) + + self.text_config.vocab_size = self.text_config.vocab_size + 10 + + def get_text_config(self, decoder=False) -> PretrainedConfig: + return self.text_config + + _ATTR_ALIASES = { + "rms_norm_eps": "layer_norm_epsilon", + "layer_norm_eps": "layer_norm_epsilon", + } + + def __getattr__(self, name): + if name.startswith("_") or name in ( + "perception", "pretrained_llm", "pretrained_asr", + "audio_locator_tag", "prompt_format", "pretrained_weights", + "text_config", "_ATTR_ALIASES", + ): + raise AttributeError(name) + alias = self._ATTR_ALIASES.get(name, name) + try: + return getattr(self.text_config, alias) + except AttributeError: + if alias != name: + try: + return getattr(self.text_config, name) + except AttributeError: + pass + raise AttributeError( + f"'{type(self).__name__}' has no attribute '{name}'" + ) diff --git a/nemo/collections/speechlm2/vllm/nemotron_v3/model.py b/nemo/collections/speechlm2/vllm/nemotron_v3/model.py new file mode 100644 index 000000000000..aa32499e6760 --- /dev/null +++ b/nemo/collections/speechlm2/vllm/nemotron_v3/model.py @@ -0,0 +1,467 @@ +"""Inference-only NeMo Speech LM model for vLLM. + +Architecture: NeMo speech encoder (e.g. FastConformer) + projection + LLM backbone. +Supports any combination of encoder and LLM defined by checkpoint config. +Requires NeMo toolkit for the audio encoder: + pip install nemo_toolkit[asr] +""" + +import re +from collections.abc import Iterable, Mapping +from contextlib import nullcontext +from typing import Annotated, Literal + +import torch +from torch import nn +from transformers import BatchFeature + +from vllm.config import VllmConfig +from vllm.config.multimodal import BaseDummyOptions +from vllm.logger import init_logger +from vllm.model_executor.models.interfaces import ( + IsHybrid, + MultiModalEmbeddings, + SupportsMambaPrefixCaching, + SupportsMultiModal, + SupportsPP, +) +from vllm.model_executor.models.module_mapping import MultiModelKeys +from vllm.model_executor.models.utils import ( + AutoWeightsLoader, + init_vllm_registered_model, + maybe_prefix, +) +from vllm.multimodal import MULTIMODAL_REGISTRY +from vllm.multimodal.inputs import ( + MultiModalDataDict, + MultiModalFieldConfig, + MultiModalKwargsItems, +) +from vllm.multimodal.parse import ( + AudioProcessorItems, + MultiModalDataItems, + MultiModalDataParser, +) +from vllm.multimodal.processing import ( + BaseMultiModalProcessor, + BaseProcessingInfo, + PromptReplacement, + PromptUpdate, + PromptUpdateDetails, +) +from vllm.multimodal.processing.dummy_inputs import BaseDummyInputsBuilder +from vllm.sequence import IntermediateTensors +from vllm.utils.tensor_schema import TensorSchema, TensorShape + +logger = init_logger(__name__) + +_AUDIO_PLACEHOLDER = "<|audio|>" +_SAMPLING_RATE = 16000 +_MAX_AUDIO_DURATION_S = 40.0 + + +def _ensure_special_tokens(tokenizer): + special = [_AUDIO_PLACEHOLDER] + existing = set(tokenizer.get_vocab().keys()) + to_add = [t for t in special if t not in existing] + if to_add: + tokenizer.add_special_tokens({"additional_special_tokens": to_add}) + + +def _load_nemo_perception(perception_cfg: dict, output_dim: int) -> nn.Module: + try: + from nemo.collections.speechlm2.modules import AudioPerceptionModule + from omegaconf import DictConfig + except ImportError as e: + raise ImportError( + "NeMo is required for the audio encoder. " + "Install with: pip install nemo_toolkit[asr]" + ) from e + + cfg = DictConfig(perception_cfg) + if "output_dim" not in cfg: + cfg.output_dim = output_dim + perception = AudioPerceptionModule(cfg) + perception.eval() + return perception + + +class NeMoSpeechLMAudioInputs(TensorSchema): + type: Literal["audio_features"] = "audio_features" + audio_signal: Annotated[ + torch.Tensor | list[torch.Tensor], TensorShape("b", "t") + ] + audio_signal_length: Annotated[torch.Tensor, TensorShape("b")] + + +class NeMoSpeechLMProcessingInfo(BaseProcessingInfo): + + def get_data_parser(self) -> MultiModalDataParser: + return MultiModalDataParser( + target_sr=_SAMPLING_RATE, + expected_hidden_size=self._get_expected_hidden_size(), + ) + + def get_supported_mm_limits(self) -> Mapping[str, int | None]: + return {"audio": 1} + + def get_max_audio_tokens(self) -> int: + return self._estimate_audio_tokens(self.get_max_audio_len()) + + def get_max_audio_len(self) -> int: + return int(_MAX_AUDIO_DURATION_S * _SAMPLING_RATE) + + @staticmethod + def _estimate_audio_tokens(audio_length_samples: int) -> int: + n_fft = 512 + hop_length = 160 + stft_pad = n_fft // 2 + fbank_len = ( + (audio_length_samples + 2 * stft_pad - n_fft) // hop_length + ) + kernel, stride, repeat = 3, 2, 3 + add_pad = 1 + 1 - kernel + length = float(fbank_len) + for _ in range(repeat): + length = (length + add_pad) / stride + 1.0 + return max(1, int(length)) + + +class NeMoSpeechLMMultiModalProcessor( + BaseMultiModalProcessor[NeMoSpeechLMProcessingInfo], +): + + def _get_mm_fields_config( + self, + hf_inputs: BatchFeature, + hf_processor_mm_kwargs: Mapping[str, object], + ) -> Mapping[str, MultiModalFieldConfig]: + return dict( + audio_signal=MultiModalFieldConfig.batched("audio"), + audio_signal_length=MultiModalFieldConfig.batched("audio"), + ) + + def _hf_processor_applies_updates( + self, + prompt_text: str, + mm_items: MultiModalDataItems, + hf_processor_mm_kwargs: Mapping[str, object], + tokenization_kwargs: Mapping[str, object], + ) -> bool: + return False + + def _get_prompt_updates( + self, + mm_items: MultiModalDataItems, + hf_processor_mm_kwargs: Mapping[str, object], + out_mm_kwargs: MultiModalKwargsItems, + ) -> list[PromptUpdate]: + def get_replacement(item_idx: int): + audios = mm_items.get_items("audio", AudioProcessorItems) + audio = audios.get(item_idx) + n_tokens = self.info._estimate_audio_tokens(audio.shape[-1]) + repl_full = _AUDIO_PLACEHOLDER * n_tokens + return PromptUpdateDetails.select_text( + repl_full, _AUDIO_PLACEHOLDER + ) + + return [ + PromptReplacement( + modality="audio", + target=_AUDIO_PLACEHOLDER, + replacement=get_replacement, + ) + ] + + def _call_hf_processor( + self, + prompt: str, + mm_data: Mapping[str, object], + mm_kwargs: Mapping[str, object], + tok_kwargs: Mapping[str, object], + ) -> BatchFeature: + tokenizer = self.info.get_tokenizer() + _ensure_special_tokens(tokenizer) + mm_data = dict(mm_data) + audios = mm_data.pop("audios", []) + + if audios: + audio_list = [] + audio_lengths = [] + parts = re.split( + f"({re.escape(_AUDIO_PLACEHOLDER)})", prompt + ) + audio_idx = 0 + for i, part in enumerate(parts): + if part == _AUDIO_PLACEHOLDER and audio_idx < len(audios): + audio = audios[audio_idx] + audio_tensor = ( + audio if isinstance(audio, torch.Tensor) + else torch.as_tensor(audio, dtype=torch.float32) + ) + if audio_tensor.dim() > 1: + audio_tensor = audio_tensor.squeeze() + n_tokens = self.info._estimate_audio_tokens( + audio_tensor.shape[-1] + ) + parts[i] = _AUDIO_PLACEHOLDER * n_tokens + audio_list.append(audio_tensor) + audio_lengths.append(audio_tensor.shape[-1]) + audio_idx += 1 + + prompt = "".join(parts) + + prompt_ids = tokenizer.encode(prompt, add_special_tokens=True) + result = BatchFeature(dict(input_ids=[prompt_ids]), tensor_type="pt") + + if audios: + result["audio_signal"] = audio_list + result["audio_signal_length"] = torch.tensor(audio_lengths) + return result + + +class NeMoSpeechLMDummyInputsBuilder( + BaseDummyInputsBuilder[NeMoSpeechLMProcessingInfo], +): + + def get_dummy_mm_data( + self, + seq_len: int, + mm_counts: Mapping[str, int], + mm_options: Mapping[str, BaseDummyOptions], + ) -> MultiModalDataDict: + num_audios = mm_counts.get("audio", 0) + return { + "audio": self._get_dummy_audios( + length=self.info.get_max_audio_len(), + num_audios=num_audios, + ) + } + + def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str: + num_audios = mm_counts.get("audio", 0) + return "Transcribe the following: " + _AUDIO_PLACEHOLDER * num_audios + + +@MULTIMODAL_REGISTRY.register_processor( + NeMoSpeechLMMultiModalProcessor, + info=NeMoSpeechLMProcessingInfo, + dummy_inputs=NeMoSpeechLMDummyInputsBuilder, +) +class NeMoSpeechLMForConditionalGeneration( + nn.Module, + SupportsMultiModal, + SupportsPP, + IsHybrid, + SupportsMambaPrefixCaching, +): + + @classmethod + def get_placeholder_str(cls, modality: str, i: int) -> str | None: + if modality.startswith("audio"): + return _AUDIO_PLACEHOLDER + return None + + @classmethod + def get_mamba_state_dtype_from_config(cls, vllm_config): + from vllm.model_executor.models.nemotron_h import NemotronHForCausalLM + return NemotronHForCausalLM.get_mamba_state_dtype_from_config(vllm_config) + + @classmethod + def get_mamba_state_shape_from_config(cls, vllm_config): + from vllm.model_executor.models.nemotron_h import NemotronHForCausalLM + return NemotronHForCausalLM.get_mamba_state_shape_from_config(vllm_config) + + @classmethod + def get_mamba_state_copy_func(cls): + from vllm.model_executor.models.nemotron_h import NemotronHForCausalLM + return NemotronHForCausalLM.get_mamba_state_copy_func() + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + config = vllm_config.model_config.hf_config + self.config = config + + with self._mark_language_model(vllm_config): + self.language_model = init_vllm_registered_model( + vllm_config=vllm_config, + hf_config=config.text_config, + prefix=maybe_prefix(prefix, "language_model"), + architectures=["NemotronHForCausalLM"], + ) + + llm_hidden = config.text_config.hidden_size + + with self._mark_tower_model(vllm_config, {"audio"}): + self.perception = _load_nemo_perception( + config.perception, output_dim=llm_hidden + ) + self.perception = self.perception.to(torch.float32) + + self.make_empty_intermediate_tensors = ( + self.language_model.make_empty_intermediate_tensors + ) + + def _parse_audio_input( + self, **kwargs + ) -> NeMoSpeechLMAudioInputs | None: + audio_signal = kwargs.pop("audio_signal", None) + if audio_signal is None: + return None + audio_signal_length = kwargs.pop("audio_signal_length", None) + + if isinstance(audio_signal, list): + max_len = max(a.shape[-1] for a in audio_signal) + padded = [ + torch.nn.functional.pad(a, (0, max_len - a.shape[-1])) + for a in audio_signal + ] + audio_signal = torch.stack(padded, dim=0) + + if audio_signal_length is None: + audio_signal_length = torch.tensor( + [audio_signal.shape[-1]] * audio_signal.shape[0] + ) + elif not isinstance(audio_signal_length, torch.Tensor): + audio_signal_length = torch.tensor(audio_signal_length) + + return NeMoSpeechLMAudioInputs( + audio_signal=audio_signal, + audio_signal_length=audio_signal_length, + ) + + def _process_audio( + self, audio_input: NeMoSpeechLMAudioInputs + ) -> tuple[torch.Tensor, ...]: + device = next(self.perception.parameters()).device + self.perception = self.perception.to(device) + + audio_signal = audio_input.audio_signal + if isinstance(audio_signal, list): + audio_signal = torch.stack(audio_signal, dim=0) + audio_signal = audio_signal.to(device=device, dtype=torch.float32) + audio_lengths = audio_input.audio_signal_length.to(device=device) + + with torch.no_grad(): + audio_embeds, audio_embed_lens = self.perception( + input_signal=audio_signal, + input_signal_length=audio_lengths, + ) + + audio_embeds = audio_embeds.to(torch.bfloat16) + + return tuple( + audio_embeds[i, : audio_embed_lens[i]] + for i in range(audio_embeds.shape[0]) + ) + + def embed_multimodal(self, **kwargs) -> MultiModalEmbeddings: + audio_input = self._parse_audio_input(**kwargs) + if audio_input is None: + return [] + return self._process_audio(audio_input) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + intermediate_tensors: IntermediateTensors | None = None, + inputs_embeds: torch.Tensor | None = None, + **kwargs, + ) -> torch.Tensor | IntermediateTensors: + if intermediate_tensors is not None: + inputs_embeds = None + return self.language_model( + input_ids, positions, intermediate_tensors, inputs_embeds + ) + + def compute_logits( + self, hidden_states: torch.Tensor + ) -> torch.Tensor | None: + return self.language_model.compute_logits(hidden_states) + + def get_mm_mapping(self) -> MultiModelKeys: + return MultiModelKeys.from_string_field( + language_model="language_model", + connector="perception.proj", + tower_model="perception.encoder", + ) + + def _nemo_to_hf_llm_weights( + self, weights: Iterable[tuple[str, torch.Tensor]] + ) -> Iterable[tuple[str, torch.Tensor]]: + """Convert NeMo checkpoint weight names to HuggingFace NemotronH + format that vLLM's NemotronHForCausalLM.load_weights() expects.""" + for name, tensor in weights: + hf_name = name.replace("llm.model.", "backbone.") + hf_name = hf_name.replace("llm.lm_head", "lm_head") + if hf_name == "backbone.norm.weight": + hf_name = "backbone.norm_f.weight" + + if hf_name.endswith(".experts.down_projs"): + prefix = hf_name.replace(".experts.down_projs", "") + n_experts = tensor.shape[0] + for i in range(n_experts): + yield ( + f"{prefix}.experts.{i}.down_proj.weight", + tensor[i].t(), + ) + elif hf_name.endswith(".experts.gate_and_up_projs"): + prefix = hf_name.replace( + ".experts.gate_and_up_projs", "" + ) + n_experts = tensor.shape[0] + for i in range(n_experts): + yield ( + f"{prefix}.experts.{i}.up_proj.weight", + tensor[i].t(), + ) + elif hf_name in ("backbone.embed_tokens.weight", "lm_head.weight"): + target_vocab = getattr( + self.config.text_config, "vocab_size", tensor.shape[0] + ) + if tensor.shape[0] < target_vocab: + pad = torch.zeros( + target_vocab - tensor.shape[0], + *tensor.shape[1:], + dtype=tensor.dtype, + ) + tensor = torch.cat([tensor, pad], dim=0) + yield (hf_name, tensor) + else: + yield (hf_name, tensor) + + def load_weights( + self, weights: Iterable[tuple[str, torch.Tensor]] + ) -> set[str]: + perception_weights = {} + perception_prefix = "perception." + llm_raw: list[tuple[str, torch.Tensor]] = [] + + for name, tensor in weights: + if "._extra_state" in name: + continue + if name.startswith(perception_prefix): + key = name[len(perception_prefix):] + perception_weights[key] = tensor + else: + llm_raw.append((name, tensor)) + + float32_weights = { + k: v.float() for k, v in perception_weights.items() + } + self.perception.load_state_dict(float32_weights, strict=False) + self.perception = self.perception.to(torch.float32) + loaded_perception = { + perception_prefix + k for k in perception_weights + } + + hf_weights = self._nemo_to_hf_llm_weights(llm_raw) + combined = ( + ("language_model." + n, t) for n, t in hf_weights + ) + + loader = AutoWeightsLoader(self) + loaded_llm = loader.load_weights(combined) + + return loaded_llm | loaded_perception diff --git a/pyproject.toml b/pyproject.toml index 4763f096e0c1..6e2100641a13 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -66,6 +66,9 @@ py-modules = ["nemo"] [project.entry-points."nemo_run.cli"] llm = "nemo.collections.llm" +[project.entry-points."vllm.general_plugins"] +nemo_speechlm = "nemo.collections.speechlm2.vllm.nemotron_v3:register" + [project.urls] Download = "https://github.com/NVIDIA/NeMo/releases" Homepage = "https://github.com/nvidia/nemo" From 6034fc198a9935656dda9f8f1011ae6f5f4eb892 Mon Sep 17 00:00:00 2001 From: Dongji Gao Date: Wed, 18 Mar 2026 10:56:46 -0700 Subject: [PATCH 2/3] Fix style: add license headers, docstrings, remove unused imports - Add Apache 2.0 license headers to all plugin files - Add docstrings to all public classes and register() - Remove unused imports (nullcontext, init_logger) - Run black (line_length=119) and isort formatting Signed-off-by: Dongji Gao --- .../speechlm2/vllm/nemotron_v3/__init__.py | 48 ++++++- .../speechlm2/vllm/nemotron_v3/config.py | 61 +++++--- .../speechlm2/vllm/nemotron_v3/model.py | 133 ++++++++---------- 3 files changed, 142 insertions(+), 100 deletions(-) diff --git a/nemo/collections/speechlm2/vllm/nemotron_v3/__init__.py b/nemo/collections/speechlm2/vllm/nemotron_v3/__init__.py index 83439e272a66..73d7e61674d2 100644 --- a/nemo/collections/speechlm2/vllm/nemotron_v3/__init__.py +++ b/nemo/collections/speechlm2/vllm/nemotron_v3/__init__.py @@ -1,33 +1,73 @@ +# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""vLLM plugin registration for NeMo Speech LM models. + +Registers NeMoSpeechLMConfig and NeMoSpeechLMForConditionalGeneration +into vLLM's model and config registries via the ``vllm.general_plugins`` +entry point. +""" + _PKG = "nemo.collections.speechlm2.vllm.nemotron_v3" def register(): - from nemo.collections.speechlm2.vllm.nemotron_v3.config import NeMoSpeechLMConfig + """Register the NeMo Speech LM model and config with vLLM. + Called automatically by vLLM when ``VLLM_PLUGINS=nemo_speechlm`` + is set, via the ``vllm.general_plugins`` entry point in + ``pyproject.toml``. + """ from transformers import AutoConfig + + from nemo.collections.speechlm2.vllm.nemotron_v3.config import NeMoSpeechLMConfig + AutoConfig.register("nemo_speechlm", NeMoSpeechLMConfig) from vllm.transformers_utils.config import _CONFIG_REGISTRY + _CONFIG_REGISTRY["nemo_speechlm"] = NeMoSpeechLMConfig from vllm.model_executor.models.registry import ModelRegistry + ModelRegistry.register_model( "NeMoSpeechLMForConditionalGeneration", f"{_PKG}.model:NeMoSpeechLMForConditionalGeneration", ) + _apply_backend_patches() + + +def _apply_backend_patches(): + """Apply patches for LLM backends that need them. + + NemotronH's HF config uses ``layer_norm_epsilon`` but vLLM expects + ``rms_norm_eps``. This patches the config class at runtime. + """ try: from transformers import AutoConfig as _AC + _nhc = _AC.from_pretrained( "nvidia/NVIDIA-Nemotron-3-Nano-30B-A3B-BF16", trust_remote_code=True, ) NHConfigCls = type(_nhc) - _orig_getattr = getattr(NHConfigCls, '__getattr__', None) + _orig_getattr = getattr(NHConfigCls, "__getattr__", None) def _patched_getattr(self, name): - if name == 'rms_norm_eps': - return getattr(self, 'layer_norm_epsilon', 1e-5) + if name == "rms_norm_eps": + return getattr(self, "layer_norm_epsilon", 1e-5) if _orig_getattr: return _orig_getattr(self, name) raise AttributeError(name) diff --git a/nemo/collections/speechlm2/vllm/nemotron_v3/config.py b/nemo/collections/speechlm2/vllm/nemotron_v3/config.py index ac401ddb803d..7f0a5d48b6fc 100644 --- a/nemo/collections/speechlm2/vllm/nemotron_v3/config.py +++ b/nemo/collections/speechlm2/vllm/nemotron_v3/config.py @@ -1,13 +1,36 @@ +# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + """Configuration for NeMo Speech LM models in vLLM. -Supports any combination of NeMo speech encoder + LLM backbone. -The checkpoint config.json defines which components to use. +Provides ``NeMoSpeechLMConfig``, a HuggingFace-compatible config class +that wraps the LLM backbone's text config with NeMo-specific fields +(perception, audio_locator_tag, etc.). The checkpoint's ``config.json`` +determines which LLM backbone and encoder are used. """ from transformers import AutoConfig, PretrainedConfig class NeMoSpeechLMConfig(PretrainedConfig): + """HuggingFace config for NeMo Speech LM multimodal models. + + Wraps a pretrained LLM config (e.g. NemotronH, Qwen3) with + additional fields for the speech perception module. The LLM + backbone config is loaded from ``pretrained_llm`` at init time. + """ + model_type = "nemo_speechlm" def __init__( @@ -28,25 +51,22 @@ def __init__( self.prompt_format = prompt_format self.pretrained_weights = pretrained_weights - self.text_config = AutoConfig.from_pretrained( - pretrained_llm, trust_remote_code=True - ) + self.text_config = AutoConfig.from_pretrained(pretrained_llm, trust_remote_code=True) self.text_config.architectures = ["NemotronHForCausalLM"] - if not hasattr(self.text_config, "total_num_kv_heads") or \ - self.text_config.total_num_kv_heads is None: - self.text_config.total_num_kv_heads = getattr( - self.text_config, "num_key_value_heads", 2 - ) + if not hasattr(self.text_config, "total_num_kv_heads") or self.text_config.total_num_kv_heads is None: + self.text_config.total_num_kv_heads = getattr(self.text_config, "num_key_value_heads", 2) if not hasattr(self.text_config, "rms_norm_eps"): - self.text_config.rms_norm_eps = getattr( - self.text_config, "layer_norm_epsilon", 1e-5 - ) + self.text_config.rms_norm_eps = getattr(self.text_config, "layer_norm_epsilon", 1e-5) + # Extend vocab to accommodate audio special tokens added at runtime. + # The embedding layer uses org_num_embeddings for weight loading + # so the checkpoint stays compatible. self.text_config.vocab_size = self.text_config.vocab_size + 10 def get_text_config(self, decoder=False) -> PretrainedConfig: + """Return the LLM backbone's text config.""" return self.text_config _ATTR_ALIASES = { @@ -56,9 +76,14 @@ def get_text_config(self, decoder=False) -> PretrainedConfig: def __getattr__(self, name): if name.startswith("_") or name in ( - "perception", "pretrained_llm", "pretrained_asr", - "audio_locator_tag", "prompt_format", "pretrained_weights", - "text_config", "_ATTR_ALIASES", + "perception", + "pretrained_llm", + "pretrained_asr", + "audio_locator_tag", + "prompt_format", + "pretrained_weights", + "text_config", + "_ATTR_ALIASES", ): raise AttributeError(name) alias = self._ATTR_ALIASES.get(name, name) @@ -70,6 +95,4 @@ def __getattr__(self, name): return getattr(self.text_config, name) except AttributeError: pass - raise AttributeError( - f"'{type(self).__name__}' has no attribute '{name}'" - ) + raise AttributeError(f"'{type(self).__name__}' has no attribute '{name}'") diff --git a/nemo/collections/speechlm2/vllm/nemotron_v3/model.py b/nemo/collections/speechlm2/vllm/nemotron_v3/model.py index aa32499e6760..a157a83b40c6 100644 --- a/nemo/collections/speechlm2/vllm/nemotron_v3/model.py +++ b/nemo/collections/speechlm2/vllm/nemotron_v3/model.py @@ -1,23 +1,33 @@ +# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + """Inference-only NeMo Speech LM model for vLLM. -Architecture: NeMo speech encoder (e.g. FastConformer) + projection + LLM backbone. -Supports any combination of encoder and LLM defined by checkpoint config. -Requires NeMo toolkit for the audio encoder: - pip install nemo_toolkit[asr] +Architecture: NeMo speech encoder (e.g. FastConformer) + projection + LLM backbone +(e.g. NemotronH). Requires NeMo toolkit for the audio encoder: +``pip install nemo_toolkit[asr]`` """ import re from collections.abc import Iterable, Mapping -from contextlib import nullcontext from typing import Annotated, Literal import torch from torch import nn from transformers import BatchFeature - from vllm.config import VllmConfig from vllm.config.multimodal import BaseDummyOptions -from vllm.logger import init_logger from vllm.model_executor.models.interfaces import ( IsHybrid, MultiModalEmbeddings, @@ -53,8 +63,6 @@ from vllm.sequence import IntermediateTensors from vllm.utils.tensor_schema import TensorSchema, TensorShape -logger = init_logger(__name__) - _AUDIO_PLACEHOLDER = "<|audio|>" _SAMPLING_RATE = 16000 _MAX_AUDIO_DURATION_S = 40.0 @@ -70,12 +78,12 @@ def _ensure_special_tokens(tokenizer): def _load_nemo_perception(perception_cfg: dict, output_dim: int) -> nn.Module: try: - from nemo.collections.speechlm2.modules import AudioPerceptionModule from omegaconf import DictConfig + + from nemo.collections.speechlm2.modules import AudioPerceptionModule except ImportError as e: raise ImportError( - "NeMo is required for the audio encoder. " - "Install with: pip install nemo_toolkit[asr]" + "NeMo is required for the audio encoder. " "Install with: pip install nemo_toolkit[asr]" ) from e cfg = DictConfig(perception_cfg) @@ -87,14 +95,15 @@ def _load_nemo_perception(perception_cfg: dict, output_dim: int) -> nn.Module: class NeMoSpeechLMAudioInputs(TensorSchema): + """Typed schema for audio inputs passed through vLLM's multimodal pipeline.""" + type: Literal["audio_features"] = "audio_features" - audio_signal: Annotated[ - torch.Tensor | list[torch.Tensor], TensorShape("b", "t") - ] + audio_signal: Annotated[torch.Tensor | list[torch.Tensor], TensorShape("b", "t")] audio_signal_length: Annotated[torch.Tensor, TensorShape("b")] class NeMoSpeechLMProcessingInfo(BaseProcessingInfo): + """Processing info for NeMo Speech LM: audio token estimation and limits.""" def get_data_parser(self) -> MultiModalDataParser: return MultiModalDataParser( @@ -116,9 +125,7 @@ def _estimate_audio_tokens(audio_length_samples: int) -> int: n_fft = 512 hop_length = 160 stft_pad = n_fft // 2 - fbank_len = ( - (audio_length_samples + 2 * stft_pad - n_fft) // hop_length - ) + fbank_len = (audio_length_samples + 2 * stft_pad - n_fft) // hop_length kernel, stride, repeat = 3, 2, 3 add_pad = 1 + 1 - kernel length = float(fbank_len) @@ -130,6 +137,7 @@ def _estimate_audio_tokens(audio_length_samples: int) -> int: class NeMoSpeechLMMultiModalProcessor( BaseMultiModalProcessor[NeMoSpeechLMProcessingInfo], ): + """Multimodal processor that handles audio tokenization and prompt expansion.""" def _get_mm_fields_config( self, @@ -161,9 +169,7 @@ def get_replacement(item_idx: int): audio = audios.get(item_idx) n_tokens = self.info._estimate_audio_tokens(audio.shape[-1]) repl_full = _AUDIO_PLACEHOLDER * n_tokens - return PromptUpdateDetails.select_text( - repl_full, _AUDIO_PLACEHOLDER - ) + return PromptUpdateDetails.select_text(repl_full, _AUDIO_PLACEHOLDER) return [ PromptReplacement( @@ -188,22 +194,17 @@ def _call_hf_processor( if audios: audio_list = [] audio_lengths = [] - parts = re.split( - f"({re.escape(_AUDIO_PLACEHOLDER)})", prompt - ) + parts = re.split(f"({re.escape(_AUDIO_PLACEHOLDER)})", prompt) audio_idx = 0 for i, part in enumerate(parts): if part == _AUDIO_PLACEHOLDER and audio_idx < len(audios): audio = audios[audio_idx] audio_tensor = ( - audio if isinstance(audio, torch.Tensor) - else torch.as_tensor(audio, dtype=torch.float32) + audio if isinstance(audio, torch.Tensor) else torch.as_tensor(audio, dtype=torch.float32) ) if audio_tensor.dim() > 1: audio_tensor = audio_tensor.squeeze() - n_tokens = self.info._estimate_audio_tokens( - audio_tensor.shape[-1] - ) + n_tokens = self.info._estimate_audio_tokens(audio_tensor.shape[-1]) parts[i] = _AUDIO_PLACEHOLDER * n_tokens audio_list.append(audio_tensor) audio_lengths.append(audio_tensor.shape[-1]) @@ -223,6 +224,7 @@ def _call_hf_processor( class NeMoSpeechLMDummyInputsBuilder( BaseDummyInputsBuilder[NeMoSpeechLMProcessingInfo], ): + """Builds dummy audio inputs for vLLM's model profiling and warmup.""" def get_dummy_mm_data( self, @@ -255,6 +257,12 @@ class NeMoSpeechLMForConditionalGeneration( IsHybrid, SupportsMambaPrefixCaching, ): + """NeMo Speech LM model for vLLM inference. + + Combines a NeMo speech encoder (AudioPerceptionModule) with a vLLM-native + LLM backbone (e.g. NemotronH). Audio is encoded to embeddings that replace + placeholder tokens in the text sequence before LLM decoding. + """ @classmethod def get_placeholder_str(cls, modality: str, i: int) -> str | None: @@ -265,16 +273,19 @@ def get_placeholder_str(cls, modality: str, i: int) -> str | None: @classmethod def get_mamba_state_dtype_from_config(cls, vllm_config): from vllm.model_executor.models.nemotron_h import NemotronHForCausalLM + return NemotronHForCausalLM.get_mamba_state_dtype_from_config(vllm_config) @classmethod def get_mamba_state_shape_from_config(cls, vllm_config): from vllm.model_executor.models.nemotron_h import NemotronHForCausalLM + return NemotronHForCausalLM.get_mamba_state_shape_from_config(vllm_config) @classmethod def get_mamba_state_copy_func(cls): from vllm.model_executor.models.nemotron_h import NemotronHForCausalLM + return NemotronHForCausalLM.get_mamba_state_copy_func() def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): @@ -293,18 +304,12 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): llm_hidden = config.text_config.hidden_size with self._mark_tower_model(vllm_config, {"audio"}): - self.perception = _load_nemo_perception( - config.perception, output_dim=llm_hidden - ) + self.perception = _load_nemo_perception(config.perception, output_dim=llm_hidden) self.perception = self.perception.to(torch.float32) - self.make_empty_intermediate_tensors = ( - self.language_model.make_empty_intermediate_tensors - ) + self.make_empty_intermediate_tensors = self.language_model.make_empty_intermediate_tensors - def _parse_audio_input( - self, **kwargs - ) -> NeMoSpeechLMAudioInputs | None: + def _parse_audio_input(self, **kwargs) -> NeMoSpeechLMAudioInputs | None: audio_signal = kwargs.pop("audio_signal", None) if audio_signal is None: return None @@ -312,16 +317,11 @@ def _parse_audio_input( if isinstance(audio_signal, list): max_len = max(a.shape[-1] for a in audio_signal) - padded = [ - torch.nn.functional.pad(a, (0, max_len - a.shape[-1])) - for a in audio_signal - ] + padded = [torch.nn.functional.pad(a, (0, max_len - a.shape[-1])) for a in audio_signal] audio_signal = torch.stack(padded, dim=0) if audio_signal_length is None: - audio_signal_length = torch.tensor( - [audio_signal.shape[-1]] * audio_signal.shape[0] - ) + audio_signal_length = torch.tensor([audio_signal.shape[-1]] * audio_signal.shape[0]) elif not isinstance(audio_signal_length, torch.Tensor): audio_signal_length = torch.tensor(audio_signal_length) @@ -330,9 +330,7 @@ def _parse_audio_input( audio_signal_length=audio_signal_length, ) - def _process_audio( - self, audio_input: NeMoSpeechLMAudioInputs - ) -> tuple[torch.Tensor, ...]: + def _process_audio(self, audio_input: NeMoSpeechLMAudioInputs) -> tuple[torch.Tensor, ...]: device = next(self.perception.parameters()).device self.perception = self.perception.to(device) @@ -350,10 +348,7 @@ def _process_audio( audio_embeds = audio_embeds.to(torch.bfloat16) - return tuple( - audio_embeds[i, : audio_embed_lens[i]] - for i in range(audio_embeds.shape[0]) - ) + return tuple(audio_embeds[i, : audio_embed_lens[i]] for i in range(audio_embeds.shape[0])) def embed_multimodal(self, **kwargs) -> MultiModalEmbeddings: audio_input = self._parse_audio_input(**kwargs) @@ -371,13 +366,9 @@ def forward( ) -> torch.Tensor | IntermediateTensors: if intermediate_tensors is not None: inputs_embeds = None - return self.language_model( - input_ids, positions, intermediate_tensors, inputs_embeds - ) + return self.language_model(input_ids, positions, intermediate_tensors, inputs_embeds) - def compute_logits( - self, hidden_states: torch.Tensor - ) -> torch.Tensor | None: + def compute_logits(self, hidden_states: torch.Tensor) -> torch.Tensor | None: return self.language_model.compute_logits(hidden_states) def get_mm_mapping(self) -> MultiModelKeys: @@ -407,9 +398,7 @@ def _nemo_to_hf_llm_weights( tensor[i].t(), ) elif hf_name.endswith(".experts.gate_and_up_projs"): - prefix = hf_name.replace( - ".experts.gate_and_up_projs", "" - ) + prefix = hf_name.replace(".experts.gate_and_up_projs", "") n_experts = tensor.shape[0] for i in range(n_experts): yield ( @@ -417,9 +406,7 @@ def _nemo_to_hf_llm_weights( tensor[i].t(), ) elif hf_name in ("backbone.embed_tokens.weight", "lm_head.weight"): - target_vocab = getattr( - self.config.text_config, "vocab_size", tensor.shape[0] - ) + target_vocab = getattr(self.config.text_config, "vocab_size", tensor.shape[0]) if tensor.shape[0] < target_vocab: pad = torch.zeros( target_vocab - tensor.shape[0], @@ -431,9 +418,7 @@ def _nemo_to_hf_llm_weights( else: yield (hf_name, tensor) - def load_weights( - self, weights: Iterable[tuple[str, torch.Tensor]] - ) -> set[str]: + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: perception_weights = {} perception_prefix = "perception." llm_raw: list[tuple[str, torch.Tensor]] = [] @@ -442,24 +427,18 @@ def load_weights( if "._extra_state" in name: continue if name.startswith(perception_prefix): - key = name[len(perception_prefix):] + key = name[len(perception_prefix) :] perception_weights[key] = tensor else: llm_raw.append((name, tensor)) - float32_weights = { - k: v.float() for k, v in perception_weights.items() - } + float32_weights = {k: v.float() for k, v in perception_weights.items()} self.perception.load_state_dict(float32_weights, strict=False) self.perception = self.perception.to(torch.float32) - loaded_perception = { - perception_prefix + k for k in perception_weights - } + loaded_perception = {perception_prefix + k for k in perception_weights} hf_weights = self._nemo_to_hf_llm_weights(llm_raw) - combined = ( - ("language_model." + n, t) for n, t in hf_weights - ) + combined = (("language_model." + n, t) for n, t in hf_weights) loader = AutoWeightsLoader(self) loaded_llm = loader.load_weights(combined) From b260d275af62537990626d1643cdac43fa0939ca Mon Sep 17 00:00:00 2001 From: Dongji Gao Date: Wed, 18 Mar 2026 12:01:59 -0700 Subject: [PATCH 3/3] Add unit tests for vLLM NeMo Speech LM plugin Tests config creation, special token handling, plugin registration, and audio encoder forward pass with dummy audio. No model weights required; GPU needed only for the perception forward test. All 10 tests pass. Signed-off-by: Dongji Gao --- .../collections/speechlm2/test_vllm_plugin.py | 184 ++++++++++++++++++ 1 file changed, 184 insertions(+) create mode 100644 tests/collections/speechlm2/test_vllm_plugin.py diff --git a/tests/collections/speechlm2/test_vllm_plugin.py b/tests/collections/speechlm2/test_vllm_plugin.py new file mode 100644 index 000000000000..423ae95883a7 --- /dev/null +++ b/tests/collections/speechlm2/test_vllm_plugin.py @@ -0,0 +1,184 @@ +# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Unit tests for the vLLM NeMo Speech LM plugin. + +Tests plugin registration, config loading, and special token handling +without requiring GPU or model weights. +""" + +import pytest + +try: + from nemo.collections.speechlm2.vllm.nemotron_v3.config import NeMoSpeechLMConfig + + _HAS_CONFIG = True +except (ImportError, RuntimeError): + _HAS_CONFIG = False + +try: + import vllm # noqa: F401 + + _HAS_VLLM = True +except (ImportError, RuntimeError): + _HAS_VLLM = False + + +@pytest.mark.skipif(not _HAS_CONFIG, reason="NeMoSpeechLMConfig not available") +class TestNeMoSpeechLMConfig: + """Tests for NeMoSpeechLMConfig.""" + + def test_model_type(self): + assert NeMoSpeechLMConfig.model_type == "nemo_speechlm" + + def test_loads_text_config(self): + """Config should load a text_config from the pretrained LLM.""" + cfg = NeMoSpeechLMConfig() + assert cfg.text_config is not None + assert hasattr(cfg.text_config, "hidden_size") + assert cfg.get_text_config() is cfg.text_config + + def test_custom_pretrained_llm(self): + """Config should accept different LLM backbones.""" + cfg = NeMoSpeechLMConfig(pretrained_llm="Qwen/Qwen3-1.7B") + assert cfg.pretrained_llm == "Qwen/Qwen3-1.7B" + assert cfg.text_config is not None + + def test_audio_locator_tag_configurable(self): + cfg = NeMoSpeechLMConfig(audio_locator_tag="<|custom_audio|>") + assert cfg.audio_locator_tag == "<|custom_audio|>" + + def test_unknown_attr_raises(self): + cfg = NeMoSpeechLMConfig() + with pytest.raises(AttributeError): + _ = cfg.nonexistent_attribute_xyz + + +@pytest.mark.skipif(not _HAS_VLLM, reason="vLLM not installed") +class TestSpecialTokens: + """Tests for special token handling.""" + + def test_adds_missing_token(self): + from unittest.mock import MagicMock + + from nemo.collections.speechlm2.vllm.nemotron_v3.model import _ensure_special_tokens + + tokenizer = MagicMock() + tokenizer.get_vocab.return_value = {} + _ensure_special_tokens(tokenizer) + tokenizer.add_special_tokens.assert_called_once() + + def test_skips_existing_token(self): + from unittest.mock import MagicMock + + from nemo.collections.speechlm2.vllm.nemotron_v3.model import _ensure_special_tokens + + tokenizer = MagicMock() + tokenizer.get_vocab.return_value = {"<|audio|>": 99} + _ensure_special_tokens(tokenizer) + tokenizer.add_special_tokens.assert_not_called() + + +@pytest.mark.skipif(not _HAS_VLLM, reason="vLLM not installed") +class TestAudioProcessing: + """Tests for audio encoding with a tiny perception module.""" + + def test_perception_forward(self): + import torch + + if not torch.cuda.is_available(): + pytest.skip("CUDA required") + """A small NeMo perception module should encode dummy audio to embeddings.""" + from nemo.collections.speechlm2.vllm.nemotron_v3.model import _load_nemo_perception + + perception_cfg = { + "output_dim": 256, + "encoder": { + "_target_": "nemo.collections.asr.modules.ConformerEncoder", + "feat_in": 128, + "feat_out": -1, + "n_layers": 2, + "d_model": 256, + "subsampling": "dw_striding", + "subsampling_factor": 8, + "subsampling_conv_channels": 64, + "ff_expansion_factor": 4, + "self_attention_model": "rel_pos", + "n_heads": 4, + "conv_kernel_size": 9, + "conv_norm_type": "batch_norm", + "dropout": 0.0, + "dropout_pre_encoder": 0.0, + "dropout_emb": 0.0, + "dropout_att": 0.0, + }, + "modality_adapter": { + "_target_": "nemo.collections.speechlm2.modules.perception.IdentityConnector", + "d_model": 256, + }, + "preprocessor": { + "_target_": "nemo.collections.asr.modules.AudioToMelSpectrogramPreprocessor", + "sample_rate": 16000, + "normalize": "per_feature", + "window_size": 0.025, + "window_stride": 0.01, + "window": "hann", + "features": 128, + "n_fft": 512, + "log": True, + "frame_splicing": 1, + "dither": 0.0, + "pad_to": 0, + "pad_value": 0.0, + }, + } + + perception = _load_nemo_perception(perception_cfg, output_dim=256) + perception = perception.to("cuda", dtype=torch.float32) + + dummy_audio = torch.randn(1, 16000, device="cuda") + audio_len = torch.tensor([16000], device="cuda") + + with torch.no_grad(): + embeds, embed_lens = perception(input_signal=dummy_audio, input_signal_length=audio_len) + + assert embeds.ndim == 3 + assert embeds.shape[0] == 1 + assert embeds.shape[2] == 256 + assert embed_lens[0] > 0 + + +@pytest.mark.skipif(not _HAS_VLLM, reason="vLLM not installed") +class TestPluginRegistration: + """Tests for plugin registration with vLLM.""" + + def test_register_config(self): + """register() should add nemo_speechlm to vLLM's config registry.""" + from nemo.collections.speechlm2.vllm.nemotron_v3 import register + + register() + + from vllm.transformers_utils.config import _CONFIG_REGISTRY + + assert "nemo_speechlm" in _CONFIG_REGISTRY + + def test_register_model(self): + """register() should make NeMoSpeechLMForConditionalGeneration importable.""" + from nemo.collections.speechlm2.vllm.nemotron_v3 import register + + register() + + from nemo.collections.speechlm2.vllm.nemotron_v3.model import NeMoSpeechLMForConditionalGeneration + + assert NeMoSpeechLMForConditionalGeneration is not None