
[Converter] IAttention converter bypasses TRT native IAttention layer due to HuggingFace causal mask (attn_bias) #4129

@chohk88

Description


Summary

When using the --backend iattention option in tools/llm/run_llm.py, the TensorRT native IAttention layer (ctx.net.add_attention()) is never used. Instead, the scaled_dot_product_efficient_attention converter falls back to a manual matmul + softmax + matmul decomposition for every attention layer. The result is no speedup over PyTorch eager for autoregressive generation (and, in the benchmarks below, a slight slowdown).

Root Cause

HuggingFace's SDPA attention implementation always generates a causal mask tensor and passes it as attn_bias to aten._scaled_dot_product_efficient_attention. The converter in py/torch_tensorrt/dynamo/conversion/impl/attention.py has two code paths:

def scaled_dot_product_efficient_attention(..., attn_bias=None, ...):
    if attn_bias is not None:
        # SLOW PATH: manual matmul → add bias → softmax → matmul
        attn_weight = matmul(scaled_query, transpose(key))
        attn_weight = add(attn_weight, attn_bias)
        attn_weight = softmax(attn_weight)
        out = matmul(attn_weight, value)
        return out, None, None, None
    else:
        # FAST PATH: TRT native IAttention layer
        attention_layer = ctx.net.add_attention(scaled_query, key, value, SOFTMAX, is_causal)
        return attention_layer.get_output(0), None, None, None

Since HuggingFace always passes attn_bias (a causal mask tensor), the fast path is never taken.
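For reference, the additive causal bias that HuggingFace materializes is zero on and below the diagonal and -inf above it. A minimal pure-Python sketch of its contents (illustrative only; the real attn_bias is a [batch, heads, q_len, kv_len] fp16 tensor, and the diagonal alignment shifts when kv_len > q_len during decode):

```python
# Additive causal bias: 0.0 where attention is allowed (j <= i),
# -inf where the position is masked out. Pure-Python illustration of
# the tensor HF passes as attn_bias.
NEG_INF = float("-inf")

def make_causal_bias(q_len, kv_len):
    return [[0.0 if j <= i else NEG_INF for j in range(kv_len)]
            for i in range(q_len)]

bias = make_causal_bias(3, 3)
# row 0 can attend only to position 0; row 2 can attend to positions 0..2
```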

Evidence (debug log analysis)

Running with --debug on Qwen2.5-0.5B-Instruct (24 layers):

$ grep -c "attn_bias_add" debug_stderr.log
24                    # ← manual decomposition used for ALL 24 layers

$ grep -c "add_attention\|IAttention" debug_stderr.log
0                     # ← TRT IAttention layer NEVER used

Benchmark Results

Tested on NVIDIA A100 80GB PCIe, FP16, ISL=2048, OSL=128, Batch=1:

| Model | PyTorch (ms) | sdpa no cache (ms) | iattention (ms) | sdpa static_v1 (ms) | plugin (ms) |
|---|---|---|---|---|---|
| Qwen2.5-0.5B | 4751 | 3271 | 5421 | 1238 | 421 |
| Qwen3-0.6B | 6875 | 4031 | 6792 | 1708 | 569 |
| Llama-3.2-1B | 7053 | 5466 | 8283 | 1379 | 465 |
  • iattention is slower than sdpa no cache despite both lacking KV cache
  • sdpa no cache is faster because the SDPA lowering pass replaces attn_bias with is_causal=True and dynamically generates the mask, which is more efficient
  • iattention keeps the HF-provided attn_bias tensor, taking the slow manual decomposition path
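To put the mask overhead in rough numbers: even a single [1, 1, ISL, ISL] fp16 mask at the benchmarked ISL=2048 is non-trivial to materialize and add at every one of the 24 attention layers. A back-of-envelope calculation (assuming the mask is broadcast over batch and heads; a per-head mask would be num_heads times larger):

```python
# Size of one materialized [1, 1, ISL, ISL] fp16 causal mask at ISL=2048.
# Assumption: broadcast over batch and heads (one copy per layer input).
isl = 2048
bytes_per_fp16 = 2
mask_bytes = isl * isl * bytes_per_fp16
mask_mib = mask_bytes / (1024 * 1024)
# 2048 * 2048 * 2 bytes = 8 MiB per mask
```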

Why the sdpa backend doesn't have this problem

The SDPA lowering pass in tools/llm/torchtrt_ext/register_sdpa.py transforms the graph:

BEFORE (HF output):
  aten._scaled_dot_product_efficient_attention(Q, K, V, attn_bias=<mask_tensor>, is_causal=False)

AFTER (lowering pass):
  F.scaled_dot_product_attention(Q, K, V, attn_mask=None, is_causal=True)

The lowering pass discards attn_bias and sets is_causal=True, so the custom SDPA converter generates the causal mask dynamically inside TensorRT — avoiding the large mask tensor overhead entirely.
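The rewrite is semantically safe because adding a -inf/0 causal bias before softmax and masking via is_causal=True zero out exactly the same attention weights. A tiny single-head reference implementation in pure Python (a sketch with hypothetical helper names; scaling by 1/sqrt(d) is omitted since it does not affect the equivalence) makes this concrete:

```python
import math

def softmax(scores):
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]  # exp(-inf) -> 0.0
    total = sum(exps)
    return [e / total for e in exps]

def attention(q, k, v, bias=None, is_causal=False):
    # Single-head reference attention over [seq][dim] nested lists.
    n, d = len(q), len(v[0])
    out = []
    for i in range(n):
        scores = [sum(a * b for a, b in zip(q[i], k[j])) for j in range(n)]
        if bias is not None:                      # explicit additive mask path
            scores = [s + bias[i][j] for j, s in enumerate(scores)]
        if is_causal:                             # dynamic masking path
            scores = [s if j <= i else float("-inf") for j, s in enumerate(scores)]
        w = softmax(scores)
        out.append([sum(w[j] * v[j][c] for j in range(n)) for c in range(d)])
    return out

q = [[0.1, 0.2], [0.3, 0.4], [0.5, 0.6]]
k = [[0.2, 0.1], [0.4, 0.3], [0.6, 0.5]]
v = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
causal_bias = [[0.0 if j <= i else float("-inf") for j in range(3)]
               for i in range(3)]

with_bias = attention(q, k, v, bias=causal_bias)
with_flag = attention(q, k, v, is_causal=True)
# the two paths produce numerically identical outputs
```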

Suggested Fix

Option A: In the efficient_attention converter, detect when attn_bias is a standard causal mask and convert it to is_causal=True + add_attention():

if attn_bias is not None and _is_causal_mask(attn_bias):
    # Use IAttention with is_causal=True
    attention_layer = ctx.net.add_attention(scaled_query, key, value, SOFTMAX, is_causal=True)
else:
    # Keep manual decomposition for non-standard masks
    ...

Option B: Use the IAttention layer's mask input directly:

attention_layer = ctx.net.add_attention(scaled_query, key, value, SOFTMAX, is_causal=False)
attention_layer.mask = attn_bias  # pass mask to IAttention layer

Option C: Add a lowering pass (similar to register_sdpa) that normalizes HF's attn_bias to is_causal=True before the converter runs.

Environment

Related

Labels

component: converters, feature request, story: LLM & Generative AI