
Conversation

@tscholak (Collaborator) commented Jan 12, 2026

Summary

Adds a vLLM-optimized Apriel2 model implementation to fast_llm_external_models.

  • Uses vLLM's ModelRegistry.register_model() for runtime registration (no vLLM patching required)
  • Supports hybrid architectures: attention, mamba, GDN, and KDA mixers

Attribution

Model implementation based on work by @nandahkrishna from the apriel2-vllm branch. This PR adapts that implementation for plugin-based registration as an alternative to patching vLLM directly.

Goal

Evaluate whether vLLM's plugin/registration mechanism can work for us as a short-term solution, avoiding the need to maintain a patched vLLM fork.

Usage

from fast_llm_external_models.apriel2.vllm import register
from vllm import LLM

register()
llm = LLM(model="path/to/apriel2/checkpoint")

vLLM vs Transformers Alignment Verification

Statistical comparison using the test_apriel2.py stats command (a minimal sketch of the per-position statistics follows the list below), with:

  • 64 prompts from the C4 dataset (deterministic sampling with seed=42)
  • 128-token prompts
  • 16 decode steps
  • Identical token IDs sent to both backends (controlled tokenization)
  • Per-position logprob comparison with percentile statistics
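At its core, the comparison feeds the same token IDs to both backends and compares the chosen-token logprobs and greedy tokens position by position. A minimal sketch of the statistics reported below (array names and shapes are illustrative, not the exact code in test_apriel2.py):

import numpy as np

def compare_backends(tf_logprobs, vllm_logprobs, tf_tokens, vllm_tokens):
    """Per-position comparison of two backends.

    tf_logprobs / vllm_logprobs: [num_prompts, num_positions] chosen-token logprobs.
    tf_tokens / vllm_tokens:     [num_prompts, num_positions] greedy token ids.
    """
    diff = np.abs(tf_logprobs - vllm_logprobs)
    return {
        "match%": float((tf_tokens == vllm_tokens).mean() * 100),
        "mean_diff": float(diff.mean()),
        "p95_diff": float(np.percentile(diff, 95)),
        "max_diff": float(diff.max()),
        # Per-position match rate: index 0 is prefill, then decode1, decode2, ...
        "per_position_match%": (tf_tokens == vllm_tokens).mean(axis=0) * 100,
    }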

Models Tested

Model         Description
pure-gdn      100% GDN layers
attn-swa      100% attention (sliding window)
every5th-kda  80% attention + 20% KDA

Results Summary

Model  Mode        Match%  Mean Diff  p95 Diff  Max Diff  Outliers
GDN    no-compile  84.4%   1.05       7.83      22.80     142 (14.0%)
GDN    compiled    83.4%   1.07       7.87      13.99     155 (15.4%)
SWA    no-compile  87.5%   0.83       7.55      18.59     111 (10.8%)
SWA    compiled    84.7%   1.14       8.85      23.38     140 (13.7%)
KDA    no-compile  87.1%   0.83       7.30      15.70     120 (11.7%)
KDA    compiled    84.8%   1.03       8.39      15.15     141 (13.8%)

Per-Position Token Match Rate (no-compile mode)

Position  GDN    SWA    KDA
prefill   95.3%  96.9%  98.4%
decode1   96.9%  93.8%  95.3%
decode2   92.2%  93.8%  89.1%
decode3   92.2%  92.2%  90.6%
decode4   90.6%  92.2%  90.6%
decode5   90.6%  90.6%  87.5%
decode6   85.7%  93.8%  89.1%
decode7   85.7%  89.1%  89.1%
decode8   84.1%  87.5%  89.1%
decode9   79.4%  82.8%  85.9%
decode10  81.0%  85.9%  84.4%
decode11  79.4%  84.4%  82.8%
decode12  73.0%  81.2%  82.8%
decode13  74.6%  79.7%  82.8%
decode14  74.2%  79.7%  78.1%
decode15  73.8%  76.6%  78.1%

Key Findings

1. Divergence is NOT mixer-specific

All models (GDN, SWA, KDA) show similar divergence patterns between vLLM and Transformers. This indicates the issue is in shared model code (RMSNorm, MLP, embeddings) rather than mixer implementations.

2. torch.compile has minimal impact

Compile vs no-compile produces nearly identical results:

  • GDN: 84.4% vs 83.4% match
  • SWA: 87.5% vs 84.7% match
  • KDA: 87.1% vs 84.8% match

Previous reports of GDN torch.compile issues appear to have been measurement artifacts.

3. Divergence accumulates over decode steps

  • Prefill: 95-98% token match rate
  • Decode15: 73-78% token match rate

Small numerical differences compound during autoregressive generation, causing progressive divergence.

4. Prefill is well-aligned

All models show excellent prefill alignment (95-98% match, avg diff ~0.04), making them reliable for likelihood-based evaluation (MMLU, etc.).


Implications

For likelihood-based evaluation (MMLU)

All models are reliable: prefill-only evaluation shows 95-98% alignment

For generative evaluation (GSM8K)

⚠️ All models show accumulating divergence - vLLM and Transformers will produce different outputs over long generations, regardless of mixer type or compilation mode

Root Cause Investigation Needed

The divergence affects all model types equally, suggesting the issue lies in one of the shared components below (a minimal isolation sketch follows the list):

  • RMSNorm implementation differences
  • MLP/SwiGLU numerical precision
  • Embedding layer handling
  • KV cache management differences
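To isolate the RMSNorm hypothesis, one option is to run a reference fp32-accumulated RMSNorm against a pure-bf16 one on identical inputs and look at the worst-case difference. The sketch below is generic (shapes, epsilon, and both implementations are assumptions, not code from either library):

import torch

def rmsnorm_ref(x: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    # Reference RMSNorm: accumulate in fp32, cast back to the input dtype.
    dtype = x.dtype
    x = x.float()
    x = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps)
    return (x * weight.float()).to(dtype)

def rmsnorm_bf16(x: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    # Same math, but accumulated entirely in bf16: one plausible source of drift.
    return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps) * weight

x = torch.randn(4, 4096, dtype=torch.bfloat16)
w = torch.ones(4096, dtype=torch.bfloat16)
print((rmsnorm_ref(x, w) - rmsnorm_bf16(x, w)).abs().max())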

Test Configuration

# Run statistical comparison
python test_apriel2.py stats /path/to/model \
    --num-prompts 64 \
    --prompt-length 128 \
    --decode-length 16 \
    --batch-size 1 \
    --dtype bfloat16 \
    --tf-kernels vllm \
    [--no-compile]  # Add for no-compile mode

Test plan

  • Test registration mechanism with vLLM
  • Verify model loads correctly
  • Statistical comparison of vLLM vs Transformers (GDN, SWA, KDA)
  • Tested compile vs no-compile modes
  • Per-position analysis of divergence patterns
  • Investigate shared code divergence (RMSNorm, MLP, embeddings)

🤖 Generated with Claude Code

tscholak and others added 17 commits January 10, 2026 12:38
- Add README.md documenting the algebraic structure of the conversion system
  (surgery monoid, action law, plan composition, total vs partial operations)
- Add prune_supernet_step1.yaml and prune_supernet_step2.yaml examples
  demonstrating the two-step workflow for pruning a homogeneous supernet
  to a heterogeneous network with different mixer types per layer

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
- Add modeling_apriel2.py with full vLLM-optimized implementation
  supporting attention, mamba, GDN, and KDA mixer types
- Add register() function for runtime model registration via
  vLLM's ModelRegistry (no patching required)
- Based on Nanda's vllm_diff.patch, adapted for external package use

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
- Refactor weight loading: each mixer module (Attention, MLP, GDN, KDA)
  now handles its own weight structure via load_weights() methods
- Fix KDA mamba_type to use "gdn_attention" for vLLM backend registration
- Add KDA op registration import for custom op support
- Remove unused positions parameter from KDA forward
- Add config_convertor.py for Apriel2TextConfig to vLLM config mapping
- Add test_apriel2.py for coherence and logit comparison testing
  between vLLM and Transformers implementations

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
Remove all PyTorch fallback implementations to ensure fast CUDA kernels
are always used. The module now fails loudly at import/instantiation
if required kernels are missing.

Changes:
- Remove torch_causal_conv1d_fn and torch_causal_conv1d_update fallbacks
- Remove torch_selective_scan_fn and torch_selective_state_update stubs
- Remove torch_chunk_gated_delta_rule function
- Remove _recurrent_gated_delta_rule method from Apriel2GatedDeltaNet
- Remove _forward_local method from GatedRMSNormalization
- Remove TestFastVsSlowPath test class (no longer needed)
- Handle CausalConv1d seq_len==1 edge case via update() instead of fallback
- Add ImportError at module load for missing causal_conv1d/mamba_ssm
- Add ImportError at class init for missing FLA kernels

Required packages:
- causal_conv1d (for CausalConv1d)
- mamba_ssm (for Mamba/SSM operations)
- fla (for GDN, KDA, GatedRMSNormalization)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
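For reference, the fail-loud guard for these required packages looks roughly like the following sketch (the imported names are the ones referenced above; the error message wording is illustrative):

try:
    from causal_conv1d import causal_conv1d_fn, causal_conv1d_update
except ImportError as e:
    raise ImportError(
        "causal_conv1d is required for CausalConv1d; no PyTorch fallback is "
        "provided, so install the package to use the fast CUDA kernels."
    ) from e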
The chunk_gated_delta_rule call was always passing initial_state=None,
ignoring any existing recurrent state from previous decode cycles.
This broke continued generation scenarios (prefill -> decode -> prefill).

Changed initial_state=None to initial_state=recurrent_state to match
the correct behavior already present in KDA's chunk_kda call.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
Add test_vs_qwen3next_with_cache and test_vs_fla_with_cache tests that
verify mixer implementations through all inference phases:
- Phase 1: Initial prefill with cache population
- Phase 2: Single-token decode using cached states
- Phase 3: Prefill again (decode→prefill transition)

Tests compare outputs and recurrent states at each phase. Convolution
states are not compared due to different storage formats between
implementations (Apriel2 stores kernel_size-1, references store
kernel_size).

For GDN, Phase 3 documents expected divergence from Qwen3Next due to
its bug where chunk mode ignores initial_state.

For KDA, all phases should match since FLA correctly passes
initial_state in chunk mode.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
- Merge test_vs_qwen3next and test_vs_qwen3next_with_cache into single
  parameterized test with use_cache fixture
- Merge test_vs_fla and test_vs_fla_with_cache similarly
- Add use_cache (False/True) and decode_steps (4) fixtures
- Use proper Apriel2Cache from cache.py instead of ad-hoc SimpleCache
- Use same total sequence length for both cache and non-cache modes
- Skip cache tests when seq_len < decode_steps + 2 (too small for 3 phases)
- Split sequence as: prefill=2/3, decode=4, prefill2=1/3 of remaining

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
- Fix KDA mode selection to match FLA: use fused_recurrent only when
  seq_len <= 64 AND not training (single expression instead of override)
- Replace use_cache fixture with explicit phase fixtures (prefill_len,
  decode_steps, prefill2_len) for clearer test parameterization
- Update test_chunked_vs_recurrent to use Apriel2Cache and fixtures
- Rename config_dict to mixer_config for consistency across all tests
- Remove unused qwen3_config fixture (recreated inline where needed)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
CausalConv1d is now tested through KDA equivalence tests which use
CausalConv1d for q_conv, k_conv, v_conv. The isolated tests were also
obsolete since CPU fallback was removed.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
Move all cache classes (_AttentionCache, _SSMCache, _DummyCacheLayer,
Apriel2Cache, _LayerListAccessor) into modeling_apriel2.py for better
tooling compatibility - modeling code is expected to be together.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
- Enable "fast" mode (bf16/sdpa) tests that were previously skipped
- Add test_dtype fixture parameter to all tests that create models
- Convert models to correct dtype with .to(device="cuda", dtype=test_dtype)
- Create input tensors with explicit dtype parameter
- Fix assert_close to cast tensors to same dtype before comparison

All 1718 mixer equivalence tests now pass in both fp32 and bf16 modes.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
- Add gdn_mixer_config and kda_mixer_config fixtures to centralize
  mixer config dict construction (eliminates 6 duplicate dicts)
- Add kda_hidden_size fixture for derived hidden_size calculation
- Add make_apriel2_config() helper for minimal Apriel2TextConfig
  construction (eliminates 4 duplicate config blocks)
- Update all GDN and KDA tests to use new fixtures
- Consolidate duplicate imports within test methods

Net reduction: 47 lines (-125/+78)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
@tscholak tscholak changed the base branch from main to fix/require-cuda-kernels-no-fallbacks January 18, 2026 22:59
oleksost and others added 3 commits January 19, 2026 14:15
- Fix rope_theta parameter: use 'rope_theta' key instead of 'base' in
  get_rope() call. This fixes attention alignment (0.002 fp32 / 0.05 bf16)
- Switch GDN from qwen3_fused_gdn_gating to fused_gdn_gating
- Add commented-out GQA head expansion code for GDN (WIP)
- Add dtype parameter to test_apriel2.py for bf16/fp32 comparison
- Use flash_attention_2 for bf16 transformers to match vLLM backend

Current alignment status:
- attn-swa: ✅ MATCH (0.002 fp32 / 0.05 bf16)
- KDA: ✅ MATCH (0.003 fp32 / 0.07 bf16)
- GDN: ❌ MISMATCH (14.6 - investigation ongoing)

Co-Authored-By: Claude <noreply@anthropic.com>
Co-authored-by: Claude Opus 4.5 <noreply@anthropic.com>
Base automatically changed from fix/require-cuda-kernels-no-fallbacks to oo/apriel_modeling_bug January 19, 2026 14:42
Base automatically changed from oo/apriel_modeling_bug to main January 19, 2026 14:50
tscholak and others added 6 commits January 19, 2026 14:51
…sigmoid

The vLLM KDA implementation was hardcoding activation="sigmoid" for the
output normalization, while the transformers implementation defaults to
"silu" when not specified in config. This caused significant logprob
differences (avg 1.1) between vLLM and transformers.

Now reads norm_activation from mixer_config.normalization.activation
with default "silu" to match transformers behavior.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
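The fix amounts to reading the activation from the mixer config instead of hardcoding it; a small sketch (the dict-style access below is an assumption about the config layout):

def get_norm_activation(mixer_config: dict) -> str:
    # Default to "silu" to match the transformers implementation when the
    # config does not specify normalization.activation.
    return mixer_config.get("normalization", {}).get("activation", "silu")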
Changes to transformers model (modeling_apriel2.py):
- Add USE_VLLM_CONV, USE_VLLM_GDN_OPS, USE_VLLM_GATED_NORM flags
- Restructure kernel imports to use vLLM ops when flags enabled
- Add _debug_enabled, _debug_layer, _debug_final flags for debugging
- Handle vLLM vs FLA signature differences for fused_recurrent_gated_delta_rule

Changes to vLLM model (vllm/modeling_apriel2.py):
- Add _debug_enabled, _debug_layer flags for GDN mixer
- Add _debug_final, _debug_lm_head flags for final norm and LM head
- Gate debug prints with boolean flags instead of num_tokens checks

Changes to test script (vllm/test_apriel2.py):
- Add comprehensive comparison command for vLLM vs TF logprob testing
- Test across prompt sizes, decode lengths, and batch sizes

Results: Prefill logprobs now match perfectly between vLLM and TF
when using vLLM kernels (USE_VLLM_GDN_OPS=True, USE_VLLM_GATED_NORM=True).
Some divergence remains during multi-token decode for certain prompt lengths.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
Add _debug_state flag and _debug_state_stats() method to both TF and
vLLM GDN mixer classes to track recurrent state evolution during
prefill and decode phases.

Key additions:
- TF: Debug state after prefill and during decode for layer 1
- vLLM: Debug state with correct slot indexing for decode phase
- Print state statistics (mean, std, min, max, first8 values)

This helps investigate the decode divergence at specific prompt lengths
(50, 51, 59, 60, 70 tokens) where vLLM and TF produce different results.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
- Add pure_gdn_step1.yaml: converts fixed -> pattern with all GDN blocks
- Add pure_gdn_step2.yaml: unwraps stochastic -> pure GDN mixer
- Improve TF GDN debug logging with try/except for tensor access
- Add vLLM GDN debug output logging during decode phase
- Add first mismatch details in test_apriel2.py compare output

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
tscholak and others added 15 commits January 21, 2026 16:10
Replace scattered class-level and function-local debug flags with
top-level DEBUG_* constants for easier control:

- DEBUG_GDN_LAYER: GDN layer forward pass (tensors, shapes)
- DEBUG_GDN_STATE: GDN recurrent state during decode
- DEBUG_GDN_OUTPUT: GDN output hidden states during decode
- DEBUG_KDA_LAYER: KDA layer outputs
- DEBUG_DECODER_LAYER: Decoder layer outputs (residual, norm)
- DEBUG_FINAL_NORM: Final norm before LM head
- DEBUG_LM_HEAD: LM head input/output

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
Always call repeat_interleave for K→V head expansion (no-op when
value_heads_per_key == 1) to avoid conditional branches that confuse
torch.compile's shape inference.

Also temporarily comment out compilation_config in test script while
investigating hybrid model compilation issues.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
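The repeat_interleave change above amounts to applying the K→V expansion unconditionally; a minimal sketch (the head-dimension index is an assumption):

import torch

def expand_kv_heads(k: torch.Tensor, value_heads_per_key: int) -> torch.Tensor:
    # repeat_interleave with a factor of 1 is a functional no-op, so the call
    # can always be made; dropping the conditional branch keeps torch.compile's
    # shape inference stable.
    return k.repeat_interleave(value_heads_per_key, dim=1)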
Also keep USE_VLLM_* flags at False for upstream kernel testing.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
- Change AttentionDecoderLayer.forward signature: move positions to optional kwarg
- All layers now accept (hidden_states, residual, positions=None, **kwargs)
- Remove isinstance dispatch in Apriel2Model.forward loop
- Call all layer types uniformly with same arguments

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
Match Llama's approach: use torch._check to assert relationship between
positions and input_ids sizes without hardcoding values. This helps the
compiler understand dynamic shapes during chunked prefill warmup.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
Debug code with f-strings (e.g., f"num_tokens={num_tokens}") caused
torch.compile to fail with ConstraintViolationError because f-strings
are evaluated before the function call, causing tensor.size() calls
to be traced even when debug flags are False.

Also commented out debug-related code that converts tensor values to
Python integers (e.g., int(tensor[0])) which breaks CUDA graph capture.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
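The f-string pitfall above and the safe guard pattern, as a sketch (the flag name follows the DEBUG_* constants introduced earlier; the helper shown in the comment is illustrative):

import torch

DEBUG_GDN_LAYER = False  # top-level debug flag, as in the earlier refactor

def log_num_tokens(hidden_states: torch.Tensor) -> None:
    # Broken: the f-string argument (and hidden_states.size(0)) is evaluated
    # before the call, so torch.compile traces the size() access even when the
    # debug helper discards the message:
    #     debug_print(f"num_tokens={hidden_states.size(0)}")
    # Safe: guard the whole statement so nothing is evaluated when disabled.
    if DEBUG_GDN_LAYER:
        print(f"num_tokens={hidden_states.size(0)}")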
- Add 'stats' command for rigorous vLLM vs Transformers comparison
- Use C4 dataset for reproducible, diverse prompts
- Controlled tokenization: same token IDs to both backends via TokensPrompt
- Per-position statistics (prefill + each decode step)
- Percentile-based analysis (p10, p50, p90, p95, p99)
- Outlier detection and reporting
- Configurable: num_prompts, prompt_length, decode_length, tf_kernels, seed
- Fix --no-compile argparse bug in compare command

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
Implements support for loading stochastic mixer models directly in vLLM
without conversion. Key changes:

- Add Apriel2StochasticMixer class that contains all sub-mixers and
  routes inputs to the active mixer at runtime
- Add Apriel2StochasticDecoderLayer for stochastic decoder blocks
- Implement "convex hull" page size computation that considers ALL
  sub-mixer types to ensure unified page size fits any mixer
- Use virtual layer indices (Falcon H1 style) to give each sub-mixer
  type its own cache allocation without conflicts
- Add test_loading.py for testing model loading without generation

The stochastic mixer allocates caches for all mixer types, enabling
future runtime mixer switching capability.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
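The "convex hull" page size above is simply the maximum over each sub-mixer type's requirement, so one unified page can hold the state of whichever mixer is active; a sketch with made-up numbers (the helper name and byte values are illustrative, not the PR's actual code):

def unified_page_size(sub_mixer_page_sizes: dict[str, int]) -> int:
    # Take the max over every sub-mixer's per-page requirement so the unified
    # page fits any of them.
    return max(sub_mixer_page_sizes.values())

# Example: attention KV pages vs. GDN/KDA recurrent-state pages (bytes, made up).
print(unified_page_size({"attention": 262144, "gdn": 147456, "kda": 196608}))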
- Extract _create_mixer_params helper to eliminate ~90 lines of duplication
  in get_block_params for stochastic mixer handling
- Fix MIXER_TYPE_OFFSETS bug: use mixer index instead of type to prevent
  collisions when multiple mixers share the same type (e.g., attention and
  sliding_window both have type "attention")
- Remove dead class-level get_kv_cache_spec method (vLLM calls instance
  methods on each layer, not the class-level method)
- Remove unused get_block_specs and get_block_name_for_layer functions

Net reduction of ~200 lines.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
Cache get_unified_page_size_for_config results by object identity.
This avoids redundant computation when vLLM calls each layer's
get_kv_cache_spec independently (96 calls → 1 for 24-layer model
with 4 stochastic sub-mixers).

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
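The caching pattern is ordinary memoization keyed on id(config); a sketch (the compute callback stands in for the real page-size computation, which is not shown here):

_page_size_cache: dict[int, int] = {}

def unified_page_size_cached(config, compute) -> int:
    # vLLM calls get_kv_cache_spec on every layer independently, so key the
    # result on the config object's identity and reuse it across those calls.
    key = id(config)
    if key not in _page_size_cache:
        _page_size_cache[key] = compute(config)
    return _page_size_cache[key]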
…tching

All mixers now use the vLLM-standard signature:
  forward(hidden_states, output, positions=None, **kwargs) -> None

This enables runtime placement switching between mixer types (attention,
gdn, kda, mamba) via collective_rpc without signature mismatches.

Changes:
- Apriel2Attention: write to output buffer instead of returning
- Apriel2MambaMixer/GDN/KDA: add positions parameter for uniformity
- Apriel2AttentionDecoderLayer: allocate buffer and pass to mixer
- Apriel2StochasticMixer: delegate to active mixer with unified signature
- Add worker monkey-patching for collective_rpc placement methods
- Add test_placement_comparison.py to validate output equivalence

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
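For illustration, a mixer following the unified signature looks like this minimal sketch (not the PR's actual classes):

import torch
from torch import nn

class MixerSketch(nn.Module):
    def forward(
        self,
        hidden_states: torch.Tensor,
        output: torch.Tensor,
        positions: torch.Tensor | None = None,
        **kwargs,
    ) -> None:
        # Results are written into the caller-allocated output buffer rather
        # than returned, so attention, GDN, KDA, and mamba mixers can all be
        # invoked with identical arguments. A real mixer would compute
        # attention or a linear-RNN update here.
        output.copy_(hidden_states)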
- Add entry_points in setup.cfg for automatic vLLM plugin registration
- Consolidate model registration to config_convertor.py as single source
- Add --placement option to test_apriel2.py for testing different mixer
  configurations (all-attention, all-gdn, every2nd-gdn, etc.)
- Remove redundant test_loading.py and test_placement_comparison.py
- Remove manual sys.path manipulation and explicit register() calls

The vLLM plugin system uses Python's entry_points mechanism to ensure
model registration happens in all processes (parent and subprocesses).

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
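For reference, the entry point declaration looks roughly like the following, written here in setup() form for illustration (the PR declares it in setup.cfg, and the exact group and target names are assumptions on my part):

from setuptools import setup

setup(
    name="fast_llm_external_models",
    entry_points={
        # vLLM scans this group in every process (parent and workers) and calls
        # the target, which can then invoke ModelRegistry.register_model().
        "vllm.general_plugins": [
            "apriel2 = fast_llm_external_models.apriel2.vllm:register",
        ],
    },
)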
These functions were defined but never called. The use_qk_l2norm_in_kernel
parameter in FLA kernels handles L2 normalization internally.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
Resolved conflicts in modeling_apriel2.py by keeping our structured
import system with USE_VLLM_* flags instead of main's redundant imports.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
The apriel2 multimodal config uses Apriel2Config (with vision encoder),
which is not registered with AutoModelForCausalLM. Use
AutoModelForImageTextToText instead, matching the llava config.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
@tscholak tscholak changed the title [Prototype] Add vLLM Apriel2 model with plugin-based registration [EXTERNAL] Add vLLM Apriel2 model with plugin-based registration Jan 23, 2026
@tscholak tscholak marked this pull request as ready for review January 23, 2026 19:12
@tscholak tscholak requested a review from Copilot January 23, 2026 19:12

Copilot AI left a comment


Pull request overview

This PR adds a vLLM-optimized Apriel2 model implementation using vLLM's plugin-based registration mechanism, avoiding the need to maintain a patched vLLM fork. The implementation supports hybrid architectures with attention, mamba, GDN, and KDA mixers, as well as runtime mixer switching for stochastic layers.

Changes:

  • Adds vLLM model implementation with plugin registration via entry points
  • Implements comprehensive test suite comparing vLLM vs Transformers outputs
  • Adds debug flags for kernel comparison between vLLM and FLA implementations

Reviewed changes

Copilot reviewed 9 out of 10 changed files in this pull request and generated 34 comments.

Summary per file:

setup.cfg: Adds vLLM plugin entry point for automatic model registration
tests/utils/model_configs.py: Updates test config with AutoModelForImageTextToText class
fast_llm_external_models/apriel2/vllm/__init__.py: Module initialization exporting Apriel2ForCausalLM
fast_llm_external_models/apriel2/vllm/README.md: Documentation for vLLM usage
fast_llm_external_models/apriel2/vllm/config_convertor.py: Config converter for nested Apriel2 decoder format
fast_llm_external_models/apriel2/vllm/modeling_apriel2.py: Main vLLM model implementation (~2863 lines)
fast_llm_external_models/apriel2/vllm/test_apriel2.py: Manual test script for vLLM vs Transformers comparison
fast_llm_external_models/apriel2/modeling_apriel2.py: Transformers model with debug flags for kernel comparison
fast_llm_external_models/apriel2/examples/pure_gdn_step*.yaml: Example configs for GDN model conversion


Comment on lines +580 to +626
if USE_VLLM_CONV:
    # vLLM expects x as [dim, total_tokens]
    # x shape: [batch, dim, seq]
    # x_flat[:, t] should equal x[batch_for_t, :, seq_for_t]
    # permute to [dim, batch, seq], then reshape to [dim, batch*seq]
    x_flat = x.permute(1, 0, 2).reshape(dim, batch_size * seq_len).contiguous()

    # Create conv_states buffer: [batch, dim, state_len]
    # vLLM requires stride(1) == 1 (dim dimension contiguous)
    # Create as [batch, state_len, dim] contiguous, then transpose to get right strides
    conv_states = x.new_zeros(batch_size, state_len, dim).transpose(1, 2)

    # Create query_start_loc: cumulative sequence lengths
    # For batch_size sequences each of length seq_len
    query_start_loc = torch.arange(
        0, batch_size * seq_len + 1, seq_len,
        dtype=torch.int32, device=x.device
    )

    # has_initial_state: all False (no prior state)
    has_initial_state = torch.zeros(batch_size, dtype=torch.bool, device=x.device)

    # cache_indices: identity mapping
    cache_indices = torch.arange(batch_size, dtype=torch.int32, device=x.device)

    # Call vLLM's causal_conv1d_fn
    out_flat = causal_conv1d_fn(
        x_flat,
        self._weight,
        self.bias,
        conv_states,
        query_start_loc,
        cache_indices=cache_indices,
        has_initial_state=has_initial_state,
        activation=self._activation,
    )

    # Convert back: [dim, total_tokens] -> [batch, dim, seq]
    # out_flat shape: [dim, batch*seq]
    # reshape to [dim, batch, seq], then permute to [batch, dim, seq]
    out = out_flat.reshape(dim, batch_size, seq_len).permute(1, 0, 2)

    if return_final_state:
        # conv_states was updated in-place by vLLM's implementation
        # Return it in the expected format: [batch, dim, state_len]
        return out, conv_states
    return out
Copilot AI commented Jan 23, 2026

Lines 580-626 implement the vLLM convolution path, which is never executed because USE_VLLM_CONV is hardcoded to False. This is 47 lines of dead code that add unnecessary complexity and maintenance burden. Either remove this code or make the flag configurable if the functionality is actually needed.

Comment on lines +1465 to +1497
        if not self._debug_enabled:
            return
        if t is None:
            print(f"[TF-GDN layer={self.layer_idx}] {name}: None")
            return
        try:
            flat = t.flatten()[:8]
            vals = ", ".join(f"{v:.6f}" for v in flat.float().tolist())
            print(f"[TF-GDN layer={self.layer_idx}] {name}: shape={t.shape}, dtype={t.dtype}, "
                  f"mean={t.float().mean().item():.6f}, std={t.float().std().item():.6f}, "
                  f"first8=[{vals}]")
        except Exception as e:
            print(f"[TF-GDN layer={self.layer_idx}] {name}: ERROR accessing tensor: {e}")

    def _debug_print(self, msg: str):
        if not self._debug_enabled:
            return
        print(f"[TF-GDN layer={self.layer_idx}] {msg}")

    def _debug_state_stats(self, name: str, state: torch.Tensor, seq_len: int):
        """Debug recurrent state with statistics."""
        if not self._debug_state or state is None:
            return
        try:
            flat = state.flatten()
            first8 = ", ".join(f"{v:.6f}" for v in flat[:8].float().tolist())
            print(f"[TF-GDN L{self.layer_idx}] {name} (seq_len={seq_len}): shape={state.shape}, "
                  f"mean={state.float().mean().item():.6f}, std={state.float().std().item():.6f}, "
                  f"min={state.float().min().item():.6f}, max={state.float().max().item():.6f}, "
                  f"first8=[{first8}]")
        except Exception as e:
            print(f"[TF-GDN L{self.layer_idx}] {name}: ERROR accessing state: {e}")

Copilot AI commented Jan 23, 2026

Instance-level debug flags (_debug_enabled, _debug_layer, _debug_state, _debug_output) are defined within the class but hardcoded to False. These add dead debug code throughout the forward method. The debug methods (_debug_tensor, _debug_print, _debug_state_stats) check these flags at runtime, creating unnecessary overhead. Consider removing the debug infrastructure or making it opt-in through a proper logging framework.

Suggested change (replace the quoted debug-method bodies above with no-op stubs):

        """
        No-op debug hook for tensor inspection.
        This method is intentionally empty to avoid runtime overhead from
        disabled debug printing while preserving the public interface.
        """
        return

    def _debug_print(self, msg: str):
        """
        No-op debug hook for printing diagnostic messages.
        This method is intentionally empty to avoid runtime overhead from
        disabled debug printing while preserving the public interface.
        """
        return

    def _debug_state_stats(self, name: str, state: torch.Tensor, seq_len: int):
        """
        No-op debug hook for recurrent state statistics.
        This method is intentionally empty to avoid runtime overhead from
        disabled debug printing while preserving the public interface.
        """
        return

Comment on lines +2320 to +2363
self.mixers = nn.ModuleDict()
for mixer_index, (name, sub_mixer_config) in enumerate(mixers_config.items()):
    sub_mixer_type = sub_mixer_config.get("type", "attention")

    if sub_mixer_type not in self.MIXER_REGISTRY:
        raise ValueError(f"Unknown sub-mixer type '{sub_mixer_type}' in stochastic mixer")

    mixer_class, needs_model_config, needs_spec_config = self.MIXER_REGISTRY[sub_mixer_type]

    # Compute virtual layer index using mixer's position index (Falcon H1 style)
    # Each sub-mixer gets its own "virtual layer" range: layer_idx + (index+1) * num_layers
    # This ensures unique indices even when multiple mixers have the same type
    virtual_layer_idx = layer_idx + (mixer_index + 1) * num_layers

    # Build prefix with virtual layer index for cache registration
    # This only affects static_forward_context registration, not weight loading
    virtual_prefix = f"{layers_base}.{virtual_layer_idx}.stochastic_{name}"

    # Build kwargs based on what each mixer type needs
    kwargs = {
        "config": config,
        "mixer_config": sub_mixer_config,
        "layer_idx": layer_idx,  # Keep real layer_idx for any internal use
        "cache_config": cache_config,
        "quant_config": quant_config,
        "prefix": virtual_prefix,
    }
    if needs_model_config:
        kwargs["model_config"] = model_config
    if needs_spec_config:
        kwargs["speculative_config"] = speculative_config

    self.mixers[name] = mixer_class(**kwargs)
    logger.debug(
        f"Created sub-mixer '{name}' (type={sub_mixer_type}) at virtual layer {virtual_layer_idx} "
        f"(real layer {layer_idx}, prefix={virtual_prefix})"
    )

self._mixer_names = list(self.mixers.keys())
logger.info(
    f"Initialized Apriel2StochasticMixer at layer {layer_idx} with {len(self.mixers)} mixers: "
    f"{', '.join(self._mixer_names)} (active={self.active_mixer_name})"
)

Copilot AI commented Jan 23, 2026

The Apriel2StochasticMixer loads ALL sub-mixer weights into memory even though only one mixer is active at a time (line 2352). For models with multiple mixer types per layer, this could significantly increase memory usage. While the PR description mentions this is for "runtime switching", consider documenting the memory implications or providing a mode to only load the active mixer weights if memory efficiency is a concern.

Comment on lines +2843 to +2863
def _patch_worker_for_placement_switching():
    """Add placement switching methods to the vLLM GPU worker."""
    try:
        from vllm.v1.worker.gpu_worker import Worker
    except ImportError:
        return  # vLLM not available or different version

    if hasattr(Worker, "get_layer_placements"):
        return  # Already patched

    def _get_layer_placements(self) -> dict[int, str]:
        return self.get_model().get_layer_placements()

    def _set_layer_placements(self, placement: list[str]) -> dict[int, str]:
        return self.get_model().set_layer_placements(placement)

    Worker.get_layer_placements = _get_layer_placements
    Worker.set_layer_placements = _set_layer_placements


_patch_worker_for_placement_switching()
Copilot AI commented Jan 23, 2026

Monkey-patching the vLLM Worker class (lines 2843-2863) at module import time is fragile and could cause issues with vLLM version changes or multi-process scenarios. The patch is applied globally when the module is imported (line 2863), which happens automatically when vLLM loads the model. Consider:

  1. Adding a version check to ensure compatibility
  2. Providing a way to disable this patching if it causes issues
  3. Documenting this behavior clearly as it modifies vLLM's internal classes

This pattern may break with vLLM updates and makes debugging more difficult.

# TP excluded because no gradient reductions implemented for TP norm in GDN (use STP instead).
skip_tests=("sdp", "ms", GRAD_ACC, TP_NO_STP),
requires_cuda=True,
auto_model_class=transformers.AutoModelForImageTextToText,
Copilot AI commented Jan 23, 2026

The auto_model_class is set to AutoModelForImageTextToText, but based on the PR description and code, this appears to be a text-only model (Apriel2TextConfig). The configuration change suggests this might support multimodal variants, but there's no indication in the PR that image-text functionality is being added. This could cause incorrect model loading behavior.

Suggested change
-auto_model_class=transformers.AutoModelForImageTextToText,
+auto_model_class=transformers.AutoModelForCausalLM,

Comment on lines +75 to +79
if USE_VLLM_MAMBA_OPS:
    raise ImportError("vLLM mamba ops not yet wrapped")
else:
    from mamba_ssm.ops.selective_scan_interface import selective_scan_fn
    from mamba_ssm.ops.triton.selective_state_update import selective_state_update
Copilot AI commented Jan 23, 2026

This statement is unreachable.

Suggested change
-if USE_VLLM_MAMBA_OPS:
-    raise ImportError("vLLM mamba ops not yet wrapped")
-else:
-    from mamba_ssm.ops.selective_scan_interface import selective_scan_fn
-    from mamba_ssm.ops.triton.selective_state_update import selective_state_update
+# vLLM mamba ops are not yet wrapped; use mamba_ssm implementation when available.
+from mamba_ssm.ops.selective_scan_interface import selective_scan_fn
+from mamba_ssm.ops.triton.selective_state_update import selective_state_update

_debug_final = False  # seq_len <= 10
if _debug_final:
    # Show LAST token (to match vLLM)
    last_token = hidden_states[0, -1, :8]
Copilot AI commented Jan 23, 2026

This statement is unreachable.

Comment on lines +2647 to +2664
# Debug final norm
batch_size, seq_len = hidden_states.shape[:2]
_debug_final = False  # seq_len <= 10
if _debug_final:
    # Show LAST token (to match vLLM)
    last_token = hidden_states[0, -1, :8]
    vals = ", ".join(f"{v:.6f}" for v in last_token.float().tolist())
    print(f"[TF Final] hidden_states (before norm): shape={hidden_states.shape}, last_token_first8=[{vals}]")
    print(f"[TF Final] norm.weight: first8=[{', '.join(f'{v:.6f}' for v in self.norm.weight.flatten()[:8].float().tolist())}]")
    print(f"[TF Final] norm.variance_epsilon={self.norm.variance_epsilon}")

hidden_states = self.norm(hidden_states)

if _debug_final:
    last_token = hidden_states[0, -1, :8]
    vals = ", ".join(f"{v:.6f}" for v in last_token.float().tolist())
    print(f"[TF Final] hidden_states (after norm): shape={hidden_states.shape}, last_token_first8=[{vals}]")

Copilot AI commented Jan 23, 2026

This statement is unreachable.

Suggested change (drop the dead _debug_final branches quoted above, keeping only):
+batch_size, seq_len = hidden_states.shape[:2]
+hidden_states = self.norm(hidden_states)

_debug_lm_head = False  # seq_len <= 10
if _debug_lm_head:
    # Show LAST token's first 8 features (to match vLLM which only passes last token)
    last_token_hs = hidden_states[0, -1, :8]
Copilot AI commented Jan 23, 2026

This statement is unreachable.


if _debug_lm_head:
    # Get last token logits
    last_logits = logits[0, -1]
Copilot AI commented Jan 23, 2026

This statement is unreachable.
