diff --git a/py/torch_tensorrt/dynamo/conversion/_ConversionContext.py b/py/torch_tensorrt/dynamo/conversion/_ConversionContext.py
index f5ffdafda2..e22b8798eb 100644
--- a/py/torch_tensorrt/dynamo/conversion/_ConversionContext.py
+++ b/py/torch_tensorrt/dynamo/conversion/_ConversionContext.py
@@ -1,6 +1,8 @@
 from dataclasses import dataclass, field
+from typing import Optional
 
 import torch
+import torch.fx
 from torch_tensorrt.dynamo._settings import CompilationSettings
 from torch_tensorrt.dynamo.types import TRTNetwork
 
@@ -15,6 +17,7 @@ class ConversionContext:
         requires_output_allocator: Boolean flag indicating if the converter creates operators which require an Output Allocator to run (e.g. data dependent operators)
         weight_refit_map: Dictionary mapping weight names to their corresponding np.array
         cpu_weights_reference_holder: Dictionary mapping weight names to their corresponding torch.Tensor
+        current_node: The FX node currently being converted, used by converters that need access to graph-level metadata (e.g. annotations set by lowering passes)
     """
 
     net: TRTNetwork
@@ -25,6 +28,7 @@ class ConversionContext:
     requires_native_multidevice: bool = False
     weight_refit_map: dict[str, torch.Tensor] = field(default_factory=dict)
     cpu_weights_reference_holder: list[torch.Tensor] = field(default_factory=list)
+    current_node: Optional[torch.fx.Node] = None
 
     def record_weight(self, name: str, weight: torch.Tensor) -> None:
         """
diff --git a/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py b/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py
index 1b7982f074..d8cff2e317 100644
--- a/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py
+++ b/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py
@@ -791,6 +791,7 @@ def call_function(self, target: str, args: Any, kwargs: Any) -> Any:
             self.ctx.requires_native_multidevice = True
             _LOGGER.debug(f"{target} requires native multi-device support")
 
+        self.ctx.current_node = self._cur_node
         if calling_convention is CallingConvention.LEGACY:
             return converter(self.ctx.net, target, args, kwargs, self._cur_node_name)
         else:
diff --git a/py/torch_tensorrt/dynamo/conversion/impl/attention.py b/py/torch_tensorrt/dynamo/conversion/impl/attention.py
index af9c2c7519..40ef5ff4db 100644
--- a/py/torch_tensorrt/dynamo/conversion/impl/attention.py
+++ b/py/torch_tensorrt/dynamo/conversion/impl/attention.py
@@ -1,10 +1,13 @@
 import logging
+import math
 from typing import Optional, Tuple, Union
 
 import tensorrt as trt
 from tensorrt import ITensor as TRTTensor
+import torch
 from torch.fx.node import Target
 from torch_tensorrt._utils import is_tensorrt_version_supported
+from torch_tensorrt import _enums
 from torch_tensorrt.dynamo._SourceIR import SourceIR
 from torch_tensorrt.dynamo.conversion import impl
 from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext
@@ -16,6 +19,81 @@
 
 _LOGGER: logging.Logger = logging.getLogger(__name__)
 
+# FP8 E4M3 max representable magnitude. Softmax output is bounded to [0, 1],
+# so 1/448 saturates exactly at 1.0 and is data-independent (no calibration needed).
+_FP8_E4M3_MAX = 448.0
+
+
+def _maybe_set_fp8_softmax(
+    ctx: ConversionContext,
+    name: str,
+    attention_layer: trt.IAttention,
+) -> bool:
+    """Set FP8 softmax normalization quantization on the IAttention layer if the current
+    node was annotated with a softmax FP8 scale by the ``annotate_fp8_sdpa`` lowering pass.
+
+    Args:
+        ctx: Conversion context; ``ctx.current_node.meta`` is checked for the
+            ``_fp8_softmax_scale`` annotation.
+        name: Prefix used to name the scale constant tensor.
+        attention_layer: IAttention layer that is configured in place.
+
+    Returns True if FP8 normalization was configured (caller must set decomposable=False).
+    """
+    if ctx.current_node is None:
+        return False
+    scale_val = ctx.current_node.meta.get("_fp8_softmax_scale")
+    if scale_val is None:
+        return False
+    # Scale dtype must match the IAttention output (= pre-quant Q/K/V) dtype;
+    # using float32 unconditionally fails TRT compilation on some platforms.
+    output_dtype = _enums.dtype._from(attention_layer.get_output(0).dtype).to(
+        torch.dtype
+    )
+    scale_tensor = get_trt_tensor(
+        ctx,
+        torch.tensor(scale_val, dtype=output_dtype),
+        name + "_softmax_fp8_scale",
+        dtype=output_dtype,
+    )
+    attention_layer.normalization_quantize_to_type = trt.DataType.FP8
+    attention_layer.normalization_quantize_scale = scale_tensor
+    return True
+
+
+def _static_fp8_sdpa_scale(
+    ctx: ConversionContext,
+    query: TRTTensor,
+    scale: Optional[float],
+) -> Optional[float]:
+    """Pick a static ``1/sqrt(head_dim)`` scale for FP8-annotated SDPA nodes.
+
+    When FP8 softmax normalization is active (modelopt FP8 MHA pattern) TRT's
+    FP8 MHA fusion requires the Q/DQ output to feed IAttention via a single
+    same-dtype Mul; any HALF<->FLOAT cast inserted by the default dynamic
+    1/sqrt(D) computation breaks the fusion, so a Python float is used instead.
+
+    Args:
+        ctx: Conversion context whose ``current_node`` may carry the
+            ``_fp8_softmax_scale`` annotation.
+        query: Query tensor; its last dimension is the head dimension.
+        scale: Caller-supplied softmax scale, or None to derive one.
+
+    Returns:
+        ``scale`` unchanged when it is already set, when the node is not
+        annotated, or when head_dim is dynamic; ``1/sqrt(head_dim)`` otherwise.
+    """
+    if scale is not None or ctx.current_node is None:
+        return scale
+    if ctx.current_node.meta.get("_fp8_softmax_scale") is None:
+        return scale
+    head_dim = query.shape[-1]
+    # TensorRT reports a dynamic dimension as -1; math.sqrt(-1) would raise
+    # ValueError, so only a concrete positive head_dim yields a static scale.
+    if isinstance(head_dim, int) and head_dim > 0:
+        return 1.0 / math.sqrt(head_dim)
+    return scale
+
 
 def _normalize_attention_mask_rank(
     ctx: ConversionContext,
@@ -178,6 +256,9 @@ def scaled_dot_product_attention(
     Returns:
         TRTTensor: Attention output tensor with shape [batch, heads, seq_len, head_dim]
     """
+    # FP8-annotated nodes get a static scale so TRT's FP8 MHA fusion is kept.
+    scale = _static_fp8_sdpa_scale(ctx, query, scale)
+
     if scale is None:
         # 1 / math.sqrt(query.size(-1))
         q_dim = impl.shape.shape(ctx, target, source_ir, f"{name}_shape_q", query, -1)
@@ -291,7 +372,8 @@ def scaled_dot_product_attention(
     if mask_tensor is not None:
         attention_layer.mask = mask_tensor
 
-    attention_layer.decomposable = True
+    fp8_norm = _maybe_set_fp8_softmax(ctx, name, attention_layer)
+    attention_layer.decomposable = not fp8_norm
 
     attention_output = attention_layer.get_output(0)
     return attention_output
@@ -319,6 +401,9 @@ def scaled_dot_product_flash_attention(
     Optional[TRTTensor],
     Optional[TRTTensor],
 ]:
+    # FP8-annotated nodes get a static scale so TRT's FP8 MHA fusion is kept.
+    scale = _static_fp8_sdpa_scale(ctx, query, scale)
+
     if scale is None:
         # 1 / math.sqrt(query.size(-1))
         q_dim = impl.shape.shape(ctx, target, source_ir, f"{name}_shape_q", query, -1)
@@ -367,7 +452,8 @@ def scaled_dot_product_flash_attention(
     )
     assert attention_layer is not None, "attention layer is None"
 
-    attention_layer.decomposable = True
+    fp8_norm = _maybe_set_fp8_softmax(ctx, name, attention_layer)
+    attention_layer.decomposable = not fp8_norm
 
     attention_output = attention_layer.get_output(0)
     return attention_output, None, None, None, 0.0, 0.0, None, None, None
@@ -387,6 +473,9 @@ def scaled_dot_product_efficient_attention(
     is_causal: bool = False,
     scale: Optional[float] = None,
 ) -> Tuple[TRTTensor, Optional[TRTTensor], Optional[TRTTensor], Optional[TRTTensor]]:
+    # FP8-annotated nodes get a static scale so TRT's FP8 MHA fusion is kept.
+    scale = _static_fp8_sdpa_scale(ctx, query, scale)
+
     if scale is None:
         # 1 / math.sqrt(query.size(-1))
         q_dim = impl.shape.shape(ctx, target, source_ir, f"{name}_shape_q", query, -1)
@@ -523,7 +612,8 @@ def scaled_dot_product_efficient_attention(
     if mask_tensor is not None:
         attention_layer.mask = mask_tensor
 
-    attention_layer.decomposable = True
+    fp8_norm = _maybe_set_fp8_softmax(ctx, name, attention_layer)
+    attention_layer.decomposable = not fp8_norm
 
     attention_output = attention_layer.get_output(0)
     return attention_output, None, None, None
diff --git a/py/torch_tensorrt/dynamo/lowering/passes/_aten_lowering_pass.py b/py/torch_tensorrt/dynamo/lowering/passes/_aten_lowering_pass.py
index 7b770ab68b..06ef44248a 100644
--- a/py/torch_tensorrt/dynamo/lowering/passes/_aten_lowering_pass.py
+++ b/py/torch_tensorrt/dynamo/lowering/passes/_aten_lowering_pass.py
@@ -10,10 +10,12 @@
     trace_intermediate_node_outputs,
 )
 
+from .annotate_fp8_sdpa import annotate_fp8_sdpa
 from .complex_graph_rewrite import complex_graph_detection
 from .constant_folding import constant_fold
 from .force_causal_efficient_attention import force_causal_efficient_attention
 from .fuse_prims_broadcast import fuse_prims_broadcast
+from .insert_fp8_softmax_qdq import insert_fp8_softmax_qdq
 from .pass_manager import DynamoPassManager
 from .remove_assert_nodes import remove_assert_nodes
 from .remove_detach import remove_detach
@@ -41,6 +43,8 @@
     remove_num_users_is_0_nodes,
     complex_graph_detection,
     force_causal_efficient_attention,
+    annotate_fp8_sdpa,
+    insert_fp8_softmax_qdq,
 ]
 
 if not is_tegra_platform():
diff --git a/py/torch_tensorrt/dynamo/lowering/passes/annotate_fp8_sdpa.py b/py/torch_tensorrt/dynamo/lowering/passes/annotate_fp8_sdpa.py
new file mode 100644
index 0000000000..257c47974c
--- /dev/null
+++ b/py/torch_tensorrt/dynamo/lowering/passes/annotate_fp8_sdpa.py
@@ -0,0 +1,76 @@
+import logging
+
+import torch
+from torch_tensorrt.dynamo._settings import CompilationSettings
+
+logger = logging.getLogger(__name__)
+
+# E4M3 saturates at 448 and softmax output lies in [0, 1], so the fixed scale 1/448
+# maps 1.0 onto the FP8 maximum; being data-independent it needs no calibration.
+_FP8_E4M3_SOFTMAX_SCALE = 1.0 / 448.0
+
+_SDPA_TARGETS = {
+    torch.ops.aten.scaled_dot_product_attention.default,
+    torch.ops.aten._scaled_dot_product_flash_attention.default,
+    torch.ops.aten._scaled_dot_product_efficient_attention.default,
+    torch.ops.aten._scaled_dot_product_cudnn_attention.default,
+}
+
+
+def _is_fp8_quantize_op(node: torch.fx.Node) -> bool:
+    """Check that ``node`` calls tensorrt.quantize_op with num_bits=8, exponent_bits=4."""
+    try:
+        quantize_target = torch.ops.tensorrt.quantize_op.default
+    except AttributeError:
+        # The op cannot be resolved in this process, so no node can match it.
+        return False
+    if node.op != "call_function" or node.target != quantize_target:
+        return False
+    # Positional layout: (input, amax, num_bits, exponent_bits, ...); a shorter
+    # args tuple yields a shorter slice and therefore compares unequal.
+    return node.args[2:4] == (8, 4)
+
+
+def annotate_fp8_sdpa(
+    gm: torch.fx.GraphModule, settings: CompilationSettings
+) -> torch.fx.GraphModule:
+    """Tag every SDPA node whose Q, K and V operands are FP8 quantize ops.
+
+    The targeted pattern is the one modelopt emits for an attention module
+    registered via ``register_attention_for_kv_quant``: the Q, K, V arguments
+    of ``F.scaled_dot_product_attention`` are wrapped by ``q_bmm_quantizer``,
+    ``k_bmm_quantizer`` and ``v_bmm_quantizer``, giving::
+
+        q_fp8 = quantize_op(q, amax_q, num_bits=8, exponent_bits=4, ...)
+        k_fp8 = quantize_op(k, amax_k, num_bits=8, exponent_bits=4, ...)
+        v_fp8 = quantize_op(v, amax_v, num_bits=8, exponent_bits=4, ...)
+        out = scaled_dot_product_attention(q_fp8, k_fp8, v_fp8, ...)
+
+    A matching node receives ``node.meta["_fp8_softmax_scale"] = 1/448``. The
+    attention converter reads that key to set
+    ``IAttention.normalization_quantize_to_type = FP8`` and
+    ``IAttention.normalization_quantize_scale``, which TRT requires to fuse
+    into the ``_gemm_mha_v2`` FP8 MHA kernel. ``settings`` is not used; the
+    graph module is returned after being annotated in place.
+    """
+    annotated_any = False
+    for node in gm.graph.nodes:
+        if node.op != "call_function" or node.target not in _SDPA_TARGETS:
+            continue
+        qkv = node.args[:3]
+        # Fewer than three positional operands, or any operand that is not an
+        # FP8 quantize node, leaves the SDPA node untouched.
+        if len(qkv) < 3 or not all(
+            isinstance(operand, torch.fx.Node) and _is_fp8_quantize_op(operand)
+            for operand in qkv
+        ):
+            continue
+        node.meta["_fp8_softmax_scale"] = _FP8_E4M3_SOFTMAX_SCALE
+        annotated_any = True
+        logger.debug(
+            f"Annotated SDPA node {node.name} with FP8 softmax scale "
+            f"{_FP8_E4M3_SOFTMAX_SCALE} (Q/K/V inputs are FP8-quantized)"
+        )
+    if annotated_any:
+        logger.debug("FP8 SDPA softmax annotation complete")
+    return gm
diff --git a/py/torch_tensorrt/dynamo/lowering/passes/insert_fp8_softmax_qdq.py b/py/torch_tensorrt/dynamo/lowering/passes/insert_fp8_softmax_qdq.py
new file mode 100644
index 0000000000..fd3a24beaf
--- /dev/null
+++ b/py/torch_tensorrt/dynamo/lowering/passes/insert_fp8_softmax_qdq.py
@@ -0,0 +1,167 @@
+import logging
+from typing import Optional
+
+import torch
+from torch_tensorrt.dynamo._settings import CompilationSettings
+
+from .annotate_fp8_sdpa import _is_fp8_quantize_op
+
+logger = logging.getLogger(__name__)
+
+_FP8_E4M3_SOFTMAX_AMAX = 1.0
+_SOFTMAX_TARGETS = {
+    torch.ops.aten._softmax.default,
+    torch.ops.aten.softmax.int,
+}
+_MATMUL_TARGETS = {
+    torch.ops.aten.matmul,
+    torch.ops.aten.matmul.default,
+    torch.ops.aten.dot.default,
+    torch.ops.aten.mm.default,
+    torch.ops.aten.mv.default,
+    torch.ops.aten.bmm.default,
+}
+# Shape-only ops that may sit between a quantize_op output and a matmul input.
+_TRANSPARENT_TARGETS = {
+    torch.ops.aten.permute.default,
+    torch.ops.aten.transpose.int,
+    torch.ops.aten.reshape.default,
+    torch.ops.aten._reshape_copy.default,
+    torch.ops.aten.view.default,
+    torch.ops.aten.expand.default,
+    torch.ops.aten.clone.default,
+    torch.ops.aten.contiguous.default,
+}
+# Elementwise ops tolerated between BMM1 and softmax (scaling, mask/bias add).
+_SCORE_ARITHMETIC_TARGETS = {
+    torch.ops.aten.mul.Tensor,
+    torch.ops.aten.div.Tensor,
+    torch.ops.aten.add.Tensor,
+    torch.ops.aten.sub.Tensor,
+}
+
+
+def _source_is_fp8_quantize(node: Optional[torch.fx.Node]) -> bool:
+    """Walk through shape-transparent ops to find the producer; True if FP8 quantize_op."""
+    seen: set[int] = set()
+    cur = node
+    while isinstance(cur, torch.fx.Node) and id(cur) not in seen:
+        seen.add(id(cur))
+        if _is_fp8_quantize_op(cur):
+            return True
+        if cur.op == "call_function" and cur.target in _TRANSPARENT_TARGETS:
+            cur = cur.args[0] if cur.args else None
+            continue
+        return False
+    return False
+
+
+def _single_matmul_user(node: torch.fx.Node) -> Optional[torch.fx.Node]:
+    """Return the matmul user of ``node`` if it has exactly one and it is a matmul."""
+    users = list(node.users)
+    if len(users) != 1:
+        return None
+    user = users[0]
+    if user.op != "call_function" or user.target not in _MATMUL_TARGETS:
+        return None
+    return user
+
+
+def insert_fp8_softmax_qdq(
+    gm: torch.fx.GraphModule, settings: CompilationSettings
+) -> torch.fx.GraphModule:
+    """Insert an FP8 Q/DQ on softmax output in the decomposed FP8 MHA pattern.
+
+    TRT's Method 2 FP8 MHA fusion requires FP8 Q/DQ on Q, K, V **and** on the
+    softmax output. In practice modelopt's SDPA-based attention wrappers
+    quantize only Q/K/V (no ``softmax_quantizer`` runs on that path), so the
+    FX graph exported through ``torch.export`` has::
+
+        matmul(q_fp8, k_fp8.T) → mul(1/sqrt(D)) → softmax → matmul(·, v_fp8)
+
+    and TRT keeps the two matmuls and the softmax as separate kernels instead
+    of producing ``_gemm_mha_v2``. This pass recovers the fusion by inserting
+    a ``tensorrt.quantize_op`` with ``num_bits=8, exponent_bits=4, amax=1.0``
+    (→ scale = 1/448, data-independent because softmax output ∈ [0, 1]) on the
+    softmax output. It is conservative: it fires only when *all three* of Q,
+    K, V on the two matmuls trace back to FP8 ``tensorrt.quantize_op`` nodes;
+    if the graph is not a quantized MHA, nothing changes.
+
+    Args:
+        gm: Graph module to rewrite in place.
+        settings: Compilation settings; not read by this pass.
+
+    Returns:
+        ``gm`` itself; it is linted and recompiled only if a node was inserted.
+    """
+    inserted = 0
+    for node in list(gm.graph.nodes):
+        if node.op != "call_function" or node.target not in _SOFTMAX_TARGETS:
+            continue
+        # The softmax must feed a single matmul (BMM2 = softmax_out @ V).
+        bmm2 = _single_matmul_user(node)
+        if bmm2 is None or len(bmm2.args) < 2:
+            continue
+        if not _source_is_fp8_quantize(bmm2.args[1]):
+            continue
+
+        # Trace back from softmax to BMM1 through scale/mask arithmetic; an op
+        # without positional args ends the walk instead of raising IndexError.
+        attn_src = node.args[0] if node.args else None
+        while (
+            isinstance(attn_src, torch.fx.Node)
+            and attn_src.op == "call_function"
+            and attn_src.target in _SCORE_ARITHMETIC_TARGETS
+        ):
+            attn_src = attn_src.args[0] if attn_src.args else None
+        if not isinstance(attn_src, torch.fx.Node):
+            continue
+        if attn_src.op != "call_function" or attn_src.target not in _MATMUL_TARGETS:
+            continue
+        if len(attn_src.args) < 2:
+            continue
+        q_source, k_source = attn_src.args[0], attn_src.args[1]
+        if not (
+            _source_is_fp8_quantize(q_source) and _source_is_fp8_quantize(k_source)
+        ):
+            continue
+
+        # Register a per-insertion amax buffer (1.0).
+        amax_name = f"_fp8_softmax_qdq_amax_{inserted}"
+        gm.register_buffer(
+            amax_name,
+            torch.tensor(_FP8_E4M3_SOFTMAX_AMAX, dtype=torch.float32),
+            persistent=False,
+        )
+
+        with gm.graph.inserting_after(node):
+            amax_node = gm.graph.create_node(
+                "get_attr", amax_name, (), {}, name=amax_name
+            )
+        with gm.graph.inserting_after(amax_node):
+            q_op = gm.graph.create_node(
+                "call_function",
+                torch.ops.tensorrt.quantize_op.default,
+                (node, amax_node, 8, 4, False, False),
+                {},
+                name=f"fp8_softmax_quantize_{inserted}",
+            )
+        # New FX nodes start with empty ``meta``; copy the softmax shape/dtype
+        # entries, assuming quantize_op preserves its input's shape and dtype.
+        for key in ("val", "tensor_meta"):
+            if key in node.meta:
+                q_op.meta[key] = node.meta[key]
+
+        # Re-route downstream matmul to read from the new quantize_op output.
+        bmm2.replace_input_with(node, q_op)
+        inserted += 1
+        logger.debug(
+            f"Inserted FP8 softmax Q/DQ after {node.name} "
+            f"(scale=1/448, pattern=matmul→...→softmax→matmul with FP8 Q/K/V)"
+        )
+
+    if inserted:
+        gm.graph.lint()
+        gm.recompile()
+        logger.debug("FP8 decomposed-MHA softmax Q/DQ insertion complete")
+    return gm
diff --git a/tests/py/dynamo/models/test_models_export.py b/tests/py/dynamo/models/test_models_export.py
index ec625a59f2..c2681ef47f 100644
--- a/tests/py/dynamo/models/test_models_export.py
+++ b/tests/py/dynamo/models/test_models_export.py
@@ -636,3 +636,300 @@ def calibrate_loop(model):
     )
     outputs_trt = trt_model(input_tensor)
     assert torch.allclose(output_pyt, outputs_trt, rtol=5e-2, atol=5e-2)
+
+
+@unittest.skipIf(
+    not importlib.util.find_spec("modelopt"),
+    "ModelOpt is required to run this test",
+)
+@pytest.mark.unit
+def test_fp8_mha_softmax_quantizer_annotation(ir):
+    """Regression test for #4200: annotate_fp8_sdpa must tag an SDPA node whose
+    Q, K, V inputs are all FP8-quantized via ``tensorrt.quantize_op``.
+
+    This matches the FX pattern emitted by modelopt's
+    ``register_attention_for_kv_quant`` when ``NVFP4_FP8_MHA_CONFIG`` is applied:
+    the attention module's ``F.scaled_dot_product_attention`` call has its Q,
+    K, V arguments wrapped by ``q_bmm_quantizer``, ``k_bmm_quantizer``,
+    ``v_bmm_quantizer`` (all FP8).
+
+    The annotated ``_fp8_softmax_scale = 1/448`` on the SDPA node lets the
+    attention converter set ``IAttention.normalization_quantize_to_type = FP8``
+    and ``IAttention.normalization_quantize_scale`` so TRT can fuse the full
+    ``_gemm_mha_v2`` FP8 MHA kernel.
+
+    Also verifies that INT8 Q/K/V (exponent_bits=0) or a partially-FP8 input
+    (one of Q/K/V not quantized) do NOT trigger the annotation.
+    """
+    import torch.fx as fx
+    from torch_tensorrt.dynamo._settings import CompilationSettings
+    from torch_tensorrt.dynamo.lowering.passes.annotate_fp8_sdpa import (
+        _SDPA_TARGETS,
+        annotate_fp8_sdpa,
+    )
+
+    def _build_sdpa_input_quant_graph(
+        exponent_bits: int, quantize_v: bool = True
+    ) -> fx.GraphModule:
+        """Build FX graph where Q, K, V flow into SDPA through quantize_op nodes."""
+        graph = fx.Graph()
+        q = graph.placeholder("q")
+        k = graph.placeholder("k")
+        v = graph.placeholder("v")
+        amax = graph.placeholder("amax")
+        q_q = graph.call_function(
+            torch.ops.tensorrt.quantize_op.default,
+            args=(q, amax, 8, exponent_bits, False, False),
+        )
+        k_q = graph.call_function(
+            torch.ops.tensorrt.quantize_op.default,
+            args=(k, amax, 8, exponent_bits, False, False),
+        )
+        v_q = (
+            graph.call_function(
+                torch.ops.tensorrt.quantize_op.default,
+                args=(v, amax, 8, exponent_bits, False, False),
+            )
+            if quantize_v
+            else v
+        )
+        out = graph.call_function(
+            torch.ops.aten.scaled_dot_product_attention.default, args=(q_q, k_q, v_q)
+        )
+        graph.output(out)
+        return fx.GraphModule({}, graph)
+
+    settings = CompilationSettings()
+
+    # FP8 Q/K/V inputs (exponent_bits=4): SDPA node must be annotated with 1/448.
+    gm_fp8 = _build_sdpa_input_quant_graph(exponent_bits=4)
+    annotate_fp8_sdpa(gm_fp8, settings)
+    sdpa_nodes = [n for n in gm_fp8.graph.nodes if n.target in _SDPA_TARGETS]
+    assert sdpa_nodes, "No SDPA node found in graph"
+    assert all(
+        "_fp8_softmax_scale" in n.meta for n in sdpa_nodes
+    ), "annotate_fp8_sdpa did not annotate SDPA when Q/K/V inputs are FP8"
+    expected_scale = 1.0 / 448.0
+    for n in sdpa_nodes:
+        assert (
+            abs(n.meta["_fp8_softmax_scale"] - expected_scale) < 1e-12
+        ), f"Wrong softmax scale: {n.meta['_fp8_softmax_scale']}"
+
+    # INT8 Q/K/V inputs (exponent_bits=0): SDPA node must NOT be annotated.
+    gm_int8 = _build_sdpa_input_quant_graph(exponent_bits=0)
+    annotate_fp8_sdpa(gm_int8, settings)
+    sdpa_int8 = [n for n in gm_int8.graph.nodes if n.target in _SDPA_TARGETS]
+    assert sdpa_int8 and all(
+        "_fp8_softmax_scale" not in n.meta for n in sdpa_int8
+    ), "annotate_fp8_sdpa incorrectly annotated SDPA when Q/K/V are INT8"
+
+    # Only Q and K are FP8-quantized, V is raw: SDPA must NOT be annotated.
+    gm_partial = _build_sdpa_input_quant_graph(exponent_bits=4, quantize_v=False)
+    annotate_fp8_sdpa(gm_partial, settings)
+    sdpa_partial = [n for n in gm_partial.graph.nodes if n.target in _SDPA_TARGETS]
+    assert sdpa_partial and all(
+        "_fp8_softmax_scale" not in n.meta for n in sdpa_partial
+    ), "annotate_fp8_sdpa incorrectly annotated SDPA when V input is not FP8"
+
+
+@unittest.skipIf(
+    not importlib.util.find_spec("modelopt")
+    or torch.cuda.get_device_capability() < (8, 9),
+    "FP8 MHA tests require ModelOpt and compute capability 8.9 or later",
+)
+@pytest.mark.unit
+def test_fp8_mha_fused_kernel(ir):
+    """Regression test for #4200: FP8 MHA with FP8 Q/K/V inputs must produce a
+    fused ``_gemm_mha_v2`` MHA kernel with normalization_quantize_to_type set.
+
+    Hand-constructs the FX pattern that a future modelopt PyTorch-backend
+    version will emit for FP8 MHA (mirrors PR NVIDIA/Model-Optimizer#1289):
+
+        quantize_op(Q) ─┐
+        quantize_op(K) ─┤─ scaled_dot_product_attention
+        quantize_op(V) ─┘
+
+    Built directly via ``torch.ops.tensorrt.quantize_op`` so we do not depend
+    on modelopt actually supporting this pattern in its PyTorch backend today —
+    if/when it does, torch-tensorrt will compile that graph to the fused kernel.
+
+    Verifies:
+      1. Engine inspector shows a layer name containing ``mha`` (i.e.
+         ``_gemm_mha_v2``), confirming the FP8 MHA fusion triggered.
+      2. Numerics match PyTorch SDPA within FP8 tolerance (cosine_similarity > 0.99).
+
+    D=64 meets TRT's head_dim >= 32 requirement for the
+    normalization_quantize FP8 kernel.
+    """
+    import json
+
+    import torch_tensorrt
+
+    import tensorrt as trt
+
+    B, H, S, D = 1, 2, 32, 64
+    torch.manual_seed(0)
+
+    class FP8MHAModel(torch.nn.Module):
+        """Mirror of what a modelopt FP8 MHA PyTorch export will look like:
+        tensorrt.quantize_op on Q, K, V feeding F.scaled_dot_product_attention."""
+
+        def __init__(self, amax_val: float = 6.0):
+            super().__init__()
+            self.register_buffer("amax_q", torch.tensor(amax_val, dtype=torch.float32))
+            self.register_buffer("amax_k", torch.tensor(amax_val, dtype=torch.float32))
+            self.register_buffer("amax_v", torch.tensor(amax_val, dtype=torch.float32))
+
+        def forward(self, q, k, v):
+            q_fp8 = torch.ops.tensorrt.quantize_op(q, self.amax_q, 8, 4, False, False)
+            k_fp8 = torch.ops.tensorrt.quantize_op(k, self.amax_k, 8, 4, False, False)
+            v_fp8 = torch.ops.tensorrt.quantize_op(v, self.amax_v, 8, 4, False, False)
+            return torch.nn.functional.scaled_dot_product_attention(q_fp8, k_fp8, v_fp8)
+
+    q = torch.randn(B, H, S, D, dtype=torch.float16).cuda()
+    k = torch.randn(B, H, S, D, dtype=torch.float16).cuda()
+    v = torch.randn(B, H, S, D, dtype=torch.float16).cuda()
+
+    model = FP8MHAModel().eval().cuda()
+    ref_out = torch.nn.functional.scaled_dot_product_attention(q, k, v)
+
+    exp_program = torch.export.export(model, (q, k, v), strict=False)
+    serialized_engine = (
+        torch_tensorrt.dynamo.convert_exported_program_to_serialized_trt_engine(
+            exp_program,
+            inputs=[q, k, v],
+            use_explicit_typing=True,
+            min_block_size=1,
+        )
+    )
+
+    runtime = trt.Runtime(trt.Logger(trt.Logger.WARNING))
+    engine = runtime.deserialize_cuda_engine(serialized_engine)
+    inspector = engine.create_engine_inspector()
+    engine_json = json.loads(
+        inspector.get_engine_information(trt.LayerInformationFormat.JSON)
+    )
+    layers = engine_json.get("Layers", [])
+    layer_names = [
+        layer if isinstance(layer, str) else layer.get("Name", "") for layer in layers
+    ]
+    assert any("mha" in name.lower() for name in layer_names), (
+        f"No fused MHA kernel found in compiled engine. Expected a layer "
+        f"containing 'mha' (e.g. _gemm_mha_v2) — TRT fuses FP8 Q/K/V + "
+        f"normalization_quantize_to_type into a single MHA kernel. "
+        f"Layer names present: {layer_names}"
+    )
+
+    # Numerical sanity: FP8-quantized MHA should agree with PyTorch SDPA.
+    compiled = torch_tensorrt.compile(
+        model,
+        ir="dynamo",
+        inputs=[q, k, v],
+        use_explicit_typing=True,
+        min_block_size=1,
+    )
+    with torch.no_grad():
+        trt_out = compiled(q, k, v)
+    cos = torch.nn.functional.cosine_similarity(
+        ref_out.flatten().float().unsqueeze(0),
+        trt_out.flatten().float().unsqueeze(0),
+    ).item()
+    assert (
+        cos > 0.99
+    ), f"FP8 MHA output deviates from PyTorch reference: cosine_similarity={cos}"
+
+
+@unittest.skipIf(
+    not importlib.util.find_spec("modelopt")
+    or torch.cuda.get_device_capability() < (8, 9),
+    "FP8 MHA tests require ModelOpt and compute capability 8.9 or later",
+)
+@pytest.mark.unit
+def test_fp8_mha_fused_kernel_decomposed(ir):
+    """Regression test for the decomposed FP8 MHA path (TRT Method 2).
+
+    With ``decompose_attention=True`` the SDPA op is expanded into explicit
+    ``matmul → mul(1/sqrt(D)) → softmax → matmul`` primitives (no
+    ``IAttention``).  TRT fuses this into ``_gemm_mha_v2`` only when FP8
+    Q/DQ is present on Q, K, V **and** on the softmax output.
+
+    The standard FX graph lacks a softmax Q/DQ (modelopt's SDPA path does not
+    apply a ``softmax_quantizer``), so the ``insert_fp8_softmax_qdq`` lowering
+    pass adds one (scale = 1/448). This test builds the pattern by hand and
+    compiles with ``decompose_attention=True`` to verify the fusion triggers.
+    """
+    import json
+
+    import torch_tensorrt
+
+    import tensorrt as trt
+
+    B, H, S, D = 1, 2, 32, 64
+    torch.manual_seed(0)
+
+    class FP8MHAModel(torch.nn.Module):
+        def __init__(self, amax_val: float = 6.0):
+            super().__init__()
+            self.register_buffer("amax_q", torch.tensor(amax_val, dtype=torch.float32))
+            self.register_buffer("amax_k", torch.tensor(amax_val, dtype=torch.float32))
+            self.register_buffer("amax_v", torch.tensor(amax_val, dtype=torch.float32))
+
+        def forward(self, q, k, v):
+            q_fp8 = torch.ops.tensorrt.quantize_op(q, self.amax_q, 8, 4, False, False)
+            k_fp8 = torch.ops.tensorrt.quantize_op(k, self.amax_k, 8, 4, False, False)
+            v_fp8 = torch.ops.tensorrt.quantize_op(v, self.amax_v, 8, 4, False, False)
+            return torch.nn.functional.scaled_dot_product_attention(q_fp8, k_fp8, v_fp8)
+
+    q = torch.randn(B, H, S, D, dtype=torch.float16).cuda()
+    k = torch.randn(B, H, S, D, dtype=torch.float16).cuda()
+    v = torch.randn(B, H, S, D, dtype=torch.float16).cuda()
+
+    model = FP8MHAModel().eval().cuda()
+    ref_out = torch.nn.functional.scaled_dot_product_attention(q, k, v)
+
+    exp_program = torch.export.export(model, (q, k, v), strict=False)
+    serialized_engine = (
+        torch_tensorrt.dynamo.convert_exported_program_to_serialized_trt_engine(
+            exp_program,
+            inputs=[q, k, v],
+            use_explicit_typing=True,
+            min_block_size=1,
+            decompose_attention=True,
+        )
+    )
+
+    runtime = trt.Runtime(trt.Logger(trt.Logger.WARNING))
+    engine = runtime.deserialize_cuda_engine(serialized_engine)
+    inspector = engine.create_engine_inspector()
+    engine_json = json.loads(
+        inspector.get_engine_information(trt.LayerInformationFormat.JSON)
+    )
+    layers = engine_json.get("Layers", [])
+    layer_names = [
+        layer if isinstance(layer, str) else layer.get("Name", "") for layer in layers
+    ]
+    assert any("mha" in name.lower() for name in layer_names), (
+        f"No fused MHA kernel found on decomposed path. Expected a layer "
+        f"containing 'mha' (e.g. _gemm_mha_v2) — TRT fuses FP8 Q/K/V + "
+        f"softmax-output Q/DQ into _gemm_mha_v2 on Method 2 path. "
+        f"Layer names: {layer_names}"
+    )
+
+    # Numerical sanity
+    compiled = torch_tensorrt.compile(
+        model,
+        ir="dynamo",
+        inputs=[q, k, v],
+        use_explicit_typing=True,
+        min_block_size=1,
+        decompose_attention=True,
+    )
+    with torch.no_grad():
+        trt_out = compiled(q, k, v)
+    cos = torch.nn.functional.cosine_similarity(
+        ref_out.flatten().float().unsqueeze(0),
+        trt_out.flatten().float().unsqueeze(0),
+    ).item()
+    assert (
+        cos > 0.99
+    ), f"Decomposed FP8 MHA output deviates from PyTorch reference: cos={cos}"