From b79fd5f203fc65a5e42c12019dcb5ac67892a0ac Mon Sep 17 00:00:00 2001 From: Bo Wang Date: Tue, 19 May 2026 12:24:19 +0000 Subject: [PATCH 1/2] feat: support in-place plugins --- examples/dynamo/aot_plugin.py | 238 +++++++++++++++++- .../dynamo/conversion/_ConversionContext.py | 2 + .../dynamo/conversion/_TRTInterpreter.py | 14 ++ .../conversion/plugins/_generate_plugin.py | 51 +++- .../plugins/_generate_plugin_converter.py | 120 +++++++++ .../lowering/passes/_aten_lowering_pass.py | 6 + .../passes/unfunctionalize_qdp_inplace.py | 153 +++++++++++ .../runtime/_PythonTorchTensorRTModule.py | 37 ++- .../test_automatic_plugin_inplace.py | 106 ++++++++ 9 files changed, 719 insertions(+), 8 deletions(-) create mode 100644 py/torch_tensorrt/dynamo/lowering/passes/unfunctionalize_qdp_inplace.py create mode 100644 tests/py/dynamo/automatic_plugin/test_automatic_plugin_inplace.py diff --git a/examples/dynamo/aot_plugin.py b/examples/dynamo/aot_plugin.py index 234b2b4204..0aabd71bda 100644 --- a/examples/dynamo/aot_plugin.py +++ b/examples/dynamo/aot_plugin.py @@ -34,10 +34,11 @@ import tensorrt as trt import tensorrt.plugin as trtp import torch -import torch_tensorrt import triton import triton.language as tl +import torch_tensorrt + trt_logger = trt.Logger(trt.Logger.VERBOSE) @@ -51,7 +52,14 @@ @triton.jit -def add_one_kernel(x_ptr, n_elements, y_ptr, BLOCK_SIZE: tl.constexpr): +def add_one_kernel(x_ptr, y_ptr, n_elements, BLOCK_SIZE: tl.constexpr): + # Arg order matters for the AOT path: TRT launches the embedded PTX with + # arguments in (input_ptrs, output_ptrs, extra_args) order — inputs first, + # then outputs, then anything from ``extra_args`` in ``@trtp.aot_impl``. + # If this kernel declared ``(x_ptr, n_elements, y_ptr, ...)`` then TRT + # would feed the output pointer into ``n_elements`` and ``n_elements`` + # into ``y_ptr`` at launch, which is a wild pointer dereference (engine + # builds fine, ``enqueueV3`` returns -1 and the process segfaults). 
pid = tl.program_id(0) block_start = pid * BLOCK_SIZE offsets = block_start + tl.arange(0, BLOCK_SIZE) @@ -61,6 +69,23 @@ def add_one_kernel(x_ptr, n_elements, y_ptr, BLOCK_SIZE: tl.constexpr): tl.store(y_ptr + offsets, output, mask=mask) +@triton.jit +def add_one_inplace_kernel(x_ptr, n_elements, BLOCK_SIZE: tl.constexpr): + # Distinct kernel for the aliased-I/O variant. The plugin descriptor + # declares its output as ``X.aliased()`` — at runtime TRT passes a + # *single* pointer for the aliased I/O pair (the shared buffer), not two. + # If we re-used ``add_one_kernel`` here, TRT would supply: pointer, + # n_elements, padding... and the kernel's ``y_ptr`` slot would absorb + # ``n_elements`` while ``n_elements`` would read the padding zero — the + # mask would be all-false and the kernel would do nothing. + pid = tl.program_id(0) + block_start = pid * BLOCK_SIZE + offsets = block_start + tl.arange(0, BLOCK_SIZE) + mask = offsets < n_elements + x = tl.load(x_ptr + offsets, mask=mask) + tl.store(x_ptr + offsets, x + 1, mask=mask) + + # %% # Step 2: Register the PyTorch op # ----------------------------------------- @@ -77,7 +102,7 @@ def add_one(X: torch.Tensor) -> torch.Tensor: Y = torch.empty_like(X) BLOCK_SIZE = 256 grid = lambda meta: (triton.cdiv(X.numel(), meta["BLOCK_SIZE"]),) - add_one_kernel[grid](X, X.numel(), Y, BLOCK_SIZE=BLOCK_SIZE) + add_one_kernel[grid](X, Y, X.numel(), BLOCK_SIZE=BLOCK_SIZE) return Y @@ -148,8 +173,8 @@ def add_plugin_aot_impl( fn=add_one_kernel, signature={ "x_ptr": f"*{type_str}", - "n_elements": "i32", "y_ptr": f"*{type_str}", + "n_elements": "i32", }, constexprs={ "BLOCK_SIZE": block_size, @@ -167,10 +192,24 @@ def add_plugin_aot_impl( launch_params.shared_mem = compiled_kernel.metadata.shared # bytes of shared mem # ``extra_args`` are scalar arguments appended to the kernel's argument list at - # launch. Here ``n_elements`` is passed as a 32-bit symbolic integer so TRT + # launch. 
``n_elements`` is passed as a 32-bit symbolic integer so TRT # evaluates it from the actual tensor size at runtime. - extra_args = trtp.SymIntExprs(1) + # + # Triton >= 3.x always emits two trailing ``.param .u64 .ptr`` slots in + # the compiled PTX for ``global_scratch`` and ``profile_scratch`` — even + # when their sizes (``compiled_kernel.metadata.global_scratch_size`` / + # ``profile_scratch_size``) are 0. Triton's own runtime allocates + # zero-sized scratch buffers and passes those pointers at launch; TRT's + # AOT plugin path doesn't know about them and would otherwise leave the + # two trailing slots filled with stale register state — symptom: + # ``Failed to enqueue status -1`` and a segfault on the first call. + # We pad ``extra_args`` with four ``SymInt32(0)`` (two per u64 slot) so + # the kernel sees null pointers for both scratch params; since their + # sizes are 0 the kernel never dereferences them. + extra_args = trtp.SymIntExprs(1 + 4) extra_args[0] = trtp.SymInt32(N) + for _i in range(1, 5): + extra_args[_i] = trtp.SymInt32(0) return ( compiled_kernel.metadata.name, # kernel function name in PTX @@ -201,6 +240,123 @@ def add_plugin_aot_impl( ) +# %% +# In-place variant: aliased plugin I/O +# ----------------------------------------- +# +# This second registration shows the same kernel exposed as an *in-place* plugin — +# the engine mutates the input buffer directly instead of allocating a separate +# output. Useful for KV-cache updates and any pattern where only a small slice of +# a large state changes per call. +# +# Three things change vs. ``my::add_one`` above: +# +# 1. ``mutates_args=("X",)`` on the torch op. This is the load-bearing signal — +# it tells the QDP descriptor in ``_generate_plugin.py`` that input ``X`` is +# a candidate for aliasing, and it also tells PyTorch's autograd and +# functionalization machinery that the op has side effects on ``X``. +# +# 2. The registered fake returns ``X`` by identity (``return X``). 
Combined with +# the schema's ``mutates_args``, this is what makes the descriptor emit +# ``X.aliased()`` (see ``_generate_plugin._generic_plugin_desc``) instead of +# building a fresh output ``TensorDesc``. +# +# 3. The descriptor itself uses ``X.aliased()``. ``aliased()`` returns a +# ``TensorDesc`` that shares its data buffer with ``X`` — TRT will route the +# same pointer to both the input and output binding at runtime. +# +# The eager torch impl has to mutate ``X`` itself (so the semantics match what +# the engine will do) and return ``X.clone()``. ``torch.library`` forbids +# returning an input by identity from a custom op, hence the clone. +# +# Note on the AOT kernel: we do *not* re-use ``add_one_kernel`` here. With +# aliased I/O declared, TRT passes a single pointer for the aliased pair, so +# both the eager op and the AOT impl use ``add_one_inplace_kernel``, which +# takes only ``x_ptr`` and reads from and writes to that one buffer in place. + + +@torch.library.custom_op("my::add_one_inplace", mutates_args=("X",)) # type: ignore[misc] +def add_one_inplace(X: torch.Tensor) -> torch.Tensor: + assert X.is_cuda + BLOCK_SIZE = 256 + grid = lambda meta: (triton.cdiv(X.numel(), meta["BLOCK_SIZE"]),) + add_one_inplace_kernel[grid](X, X.numel(), BLOCK_SIZE=BLOCK_SIZE) + # Must not return X by identity — torch.library's no-alias constraint + # rejects that. The TRT path doesn't observe this clone (aliasing is + # declared at the descriptor level), it's purely for the eager impl. + return X.clone() + + +@torch.library.register_fake("my::add_one_inplace") +def _(X: torch.Tensor) -> torch.Tensor: + # Identity return is the secondary signal the descriptor uses to detect + # aliasing. Combined with ``mutates_args=("X",)`` above, this is what + # makes ``_generic_plugin_desc`` emit ``X.aliased()`` for the output. 
+ return X + + +@trtp.register("my::add_one_inplace") +def add_one_inplace_desc(X: trtp.TensorDesc) -> Tuple[trtp.TensorDesc]: + # ``aliased()`` is the QDP API that declares output-shares-storage-with-input. + # Engine build will fail with "PreviewFeature::kALIASED_PLUGIN_IO_10_03 not + # enabled" unless the build config enables that preview feature; the + # converter wires this on for you when it sees a non-empty aliased_map. + return X.aliased() + + +@trtp.aot_impl("my::add_one_inplace") +def add_one_inplace_aot_impl( + X: trtp.TensorDesc, outputs: Tuple[trtp.TensorDesc], tactic: int +) -> Tuple[ + Union[str, bytes], Union[str, bytes], trtp.KernelLaunchParams, trtp.SymExprs +]: + type_str = "fp32" if X.dtype == trt.float32 else "fp16" + + block_size = 256 + # Single-pointer signature — see the comment on ``add_one_inplace_kernel``. + src = triton.compiler.ASTSource( + fn=add_one_inplace_kernel, + signature={ + "x_ptr": f"*{type_str}", + "n_elements": "i32", + }, + constexprs={ + "BLOCK_SIZE": block_size, + }, + ) + compiled_kernel = triton.compile(src) + + N = X.shape_expr.numel() + launch_params = trtp.KernelLaunchParams() + launch_params.grid_x = trtp.cdiv(N, block_size) + launch_params.block_x = compiled_kernel.metadata.num_warps * 32 + launch_params.shared_mem = compiled_kernel.metadata.shared + + # See the matching note on the non-in-place ``add_plugin_aot_impl``: + # Triton 3.x emits two trailing ``.param .u64 .ptr`` slots for the + # global/profile scratch buffers, and TRT's AOT path needs them zeroed + # explicitly via ``extra_args`` so the kernel doesn't read stale state. 
+ extra_args = trtp.SymIntExprs(1 + 4) + extra_args[0] = trtp.SymInt32(N) + for _i in range(1, 5): + extra_args[_i] = trtp.SymInt32(0) + + return ( + compiled_kernel.metadata.name, + compiled_kernel.asm["ptx"], + launch_params, + extra_args, + ) + + +torch_tensorrt.dynamo.conversion.plugins.generate_plugin_converter( + "my::add_one_inplace", + supports_dynamic_shapes=False, + requires_output_allocator=False, + use_aot_if_available=True, +) + + # %% # Step 6: Compile and Run # ----------------------------------------- @@ -219,6 +375,15 @@ def forward(self, X: torch.Tensor) -> torch.Tensor: return res +class MyInplaceModel(torch.nn.Module): + """Drives the in-place plugin. The op mutates ``X`` in place; the returned + tensor carries the post-mutation value (a clone, only to satisfy + torch.library's no-alias rule).""" + + def forward(self, X: torch.Tensor) -> torch.Tensor: + return torch.ops.my.add_one_inplace.default(X) + + if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument( @@ -246,3 +411,64 @@ def forward(self, X: torch.Tensor) -> torch.Tensor: assert torch.allclose(res, my_model(m)), "Results do not match!" print("Inference successful!") + + # %% + # In-place plugin demo + # --------------------- + # + # The standard "compile once, run many times" comparison pattern doesn't + # work for an in-place op because each call mutates the input — running + # eager and TRT on the same buffer double-applies the mutation. We work + # off a base tensor and clone for each call instead. + # + # Three things to verify, beyond "it ran": + # 1. The compiled module contains a TRT engine (not a PyTorch fallback — + # a regression here would silently pass on value because the eager + # kernel mutates the input the same way). + # 2. The engine's return value matches the expected post-mutation tensor. + # 3. The user's input buffer was mutated in place by the engine — the + # actual reason to use aliased plugin I/O in the first place. 
+ + print("\nIn-place plugin demo:") + inplace_model = MyInplaceModel().to("cuda").eval() + base = torch.full((64, 64), 2, device="cuda", dtype=torch.float) + expected_post = base + 1 + + model_trt_inplace = torch_tensorrt.compile( + inplace_model, + inputs=[base.clone()], + min_block_size=1, + immutable_weights=True, + ) + + from torch_tensorrt.dynamo.runtime import ( + PythonTorchTensorRTModule, + TorchTensorRTModule, + ) + + engine_submodules = [ + m + for _, m in model_trt_inplace.named_modules() + if isinstance(m, (PythonTorchTensorRTModule, TorchTensorRTModule)) + ] + assert engine_submodules, ( + "Expected a TRT engine submodule for the in-place plugin path, but the" + " compiled module is pure PyTorch — check that the un-functionalize" + " lowering pass restored the mutating op before partitioning. Graph:\n" + f"{model_trt_inplace.graph}" + ) + print(f" TRT engine submodule(s) present: {len(engine_submodules)}") + + with torch.no_grad(): + trt_input = base.clone() + trt_out = model_trt_inplace(trt_input) + assert torch.allclose(trt_out, expected_post), "TRT output mismatch" + assert torch.allclose(trt_input, expected_post), ( + "Engine did not mutate the input buffer — aliased plugin I/O is" + " not active. Check that PreviewFeature.ALIASED_PLUGIN_IO_10_03" + " was enabled and that the descriptor emitted X.aliased()." 
+ ) + + print(" Output matches expected post-mutation value.") + print(" Input buffer was mutated in place by the TRT engine.") + print("In-place inference successful!") diff --git a/py/torch_tensorrt/dynamo/conversion/_ConversionContext.py b/py/torch_tensorrt/dynamo/conversion/_ConversionContext.py index f5ffdafda2..3de546404a 100644 --- a/py/torch_tensorrt/dynamo/conversion/_ConversionContext.py +++ b/py/torch_tensorrt/dynamo/conversion/_ConversionContext.py @@ -1,6 +1,7 @@ from dataclasses import dataclass, field import torch + from torch_tensorrt.dynamo._settings import CompilationSettings from torch_tensorrt.dynamo.types import TRTNetwork @@ -23,6 +24,7 @@ class ConversionContext: ) requires_output_allocator: bool = False requires_native_multidevice: bool = False + requires_aliased_plugin_io: bool = False weight_refit_map: dict[str, torch.Tensor] = field(default_factory=dict) cpu_weights_reference_holder: list[torch.Tensor] = field(default_factory=list) diff --git a/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py b/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py index e1f4d8bafb..9b7bdd1196 100644 --- a/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py +++ b/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py @@ -24,6 +24,7 @@ from torch.fx.node import _get_qualified_name from torch.fx.passes.shape_prop import TensorMetadata from torch.utils._python_dispatch import _disable_current_modes + from torch_tensorrt import ENABLED_FEATURES from torch_tensorrt._enums import dtype from torch_tensorrt._features import needs_refit @@ -310,6 +311,19 @@ def _populate_trt_builder_config( if self.compilation_settings.enable_weight_streaming: builder_config.set_flag(trt.BuilderFlag.WEIGHT_STREAMING) + if self.ctx.requires_aliased_plugin_io: + aliased_io_feature = getattr( + trt.PreviewFeature, "ALIASED_PLUGIN_IO_10_03", None + ) + if aliased_io_feature is None: + raise RuntimeError( + "An in-place QDP plugin declared aliased I/O, but this TensorRT" + " 
version does not expose PreviewFeature.ALIASED_PLUGIN_IO_10_03." + " TensorRT 10.3+ is required for aliased plugin I/O." + ) + builder_config.set_preview_feature(aliased_io_feature, True) + _LOGGER.info("Enabling preview feature ALIASED_PLUGIN_IO_10_03") + if is_tensorrt_version_supported("10.8"): TilingOptimizationLevel = { "none": trt.TilingOptimizationLevel.NONE, diff --git a/py/torch_tensorrt/dynamo/conversion/plugins/_generate_plugin.py b/py/torch_tensorrt/dynamo/conversion/plugins/_generate_plugin.py index bf087d01cc..4f82954dd7 100644 --- a/py/torch_tensorrt/dynamo/conversion/plugins/_generate_plugin.py +++ b/py/torch_tensorrt/dynamo/conversion/plugins/_generate_plugin.py @@ -72,6 +72,25 @@ def _generate_plugin(plugin_name: str) -> None: # retrieve the corresponding torch operation using the passed in string torch_op = getattr(getattr(torch.ops, namespace), name) + # Positional indices of tensor inputs marked as mutated in the schema + # (Tensor(aN!) ... -> ...). These are candidates for in-place QDP outputs + # via `TensorDesc.aliased()` — the actual alias map (output_idx -> input_idx) + # is determined at fake-run time by checking output-to-input identity, since + # torch.library does not stamp aliasing onto the return alias_info. + _schema = torch_op._schemas[""] + _tensor_arg_positions = [ + i + for i, a in enumerate(_schema.arguments) + if a.type.isSubtypeOf(torch._C.TensorType.get()) + ] + _mutated_tensor_arg_positions = { + _tensor_arg_positions.index(i) + for i, a in enumerate(_schema.arguments) + if a.type.isSubtypeOf(torch._C.TensorType.get()) + and a.alias_info is not None + and a.alias_info.is_write + } + # helper function that generates the required signature based on the torch operation def generate_signature( torch_op: Callable[[Any], Any], @@ -187,6 +206,18 @@ def _generic_plugin_desc(*args: Any, **kwargs: Any) -> Tuple[trtp.TensorDesc, .. # a tuple; single-output ops return a bare Tensor. 
outputs_list = list(output) if isinstance(output, (tuple, list)) else [output] + # Build output_idx -> input_idx alias map. An output is aliased to a + # mutated input when both (a) the input is declared mutates_args in the + # schema and (b) the fake kernel returned that exact tensor by identity. + # The schema gate prevents accidental aliasing when a fake kernel + # incidentally returns an input (e.g. an identity op without mutation). + alias_map: dict[int, int] = {} + for out_idx, fake_out in enumerate(outputs_list): + for in_idx in _mutated_tensor_arg_positions: + if in_idx < len(fake_args) and fake_out is fake_args[in_idx]: + alias_map[out_idx] = in_idx + break + input_node_expr = list( itertools.chain.from_iterable( [sym.node.expr for sym in syms_arg] for syms_arg in syms_args @@ -208,6 +239,13 @@ def _generic_plugin_desc(*args: Any, **kwargs: Any) -> Tuple[trtp.TensorDesc, .. out_descs = [] for out_idx, fake_out in enumerate(outputs_list): + # In-place output: tell TRT this output shares its buffer with the + # mutated input. `aliased()` carries the input's shape/dtype so we + # don't rebuild the descriptor from the symbolic expressions. + if out_idx in alias_map: + out_descs.append(tensor_args[alias_map[out_idx]].aliased()) + continue + shape_calc_fns: list[Any] = [None] * fake_out.ndim for i in range(fake_out.ndim): out_dim = fake_out.shape[i] @@ -274,12 +312,23 @@ def _generic_plugin_impl( dest_tensors = [torch.as_tensor(o, device="cuda") for o in outputs] + # Outputs declared `.aliased()` share storage with their input — the + # mutation already lands in the destination buffer, so copying onto it + # would either be a no-op self-copy or, worse, clobber the in-place + # result if the eager kernel returned a fresh tensor. 
+ aliased_outputs = { + i for i, o in enumerate(outputs) if o.get_aliased() is not None + } + stream = torch.cuda.ExternalStream(stream) with torch.cuda.stream(stream): out_tensors = torch_op(*in_tensors, *non_tensor_args, **torch_kwargs) if isinstance(out_tensors, torch.Tensor): out_tensors = (out_tensors,) - [d.copy_(o) for (d, o) in zip(dest_tensors, out_tensors)] + for i, (d, o) in enumerate(zip(dest_tensors, out_tensors)): + if i in aliased_outputs: + continue + d.copy_(o) plugin_impl_func = f""" {plugin_impl_signature} diff --git a/py/torch_tensorrt/dynamo/conversion/plugins/_generate_plugin_converter.py b/py/torch_tensorrt/dynamo/conversion/plugins/_generate_plugin_converter.py index a2c39e6f7e..dfec5ca2b0 100644 --- a/py/torch_tensorrt/dynamo/conversion/plugins/_generate_plugin_converter.py +++ b/py/torch_tensorrt/dynamo/conversion/plugins/_generate_plugin_converter.py @@ -4,6 +4,7 @@ from typing import Any, Callable, Dict, Optional, Sequence, Tuple, Union import numpy as np +import numpy.typing as npt import tensorrt as trt import torch from torch.fx.node import Argument, Node, Target @@ -44,6 +45,102 @@ def _coerce_plugin_attr_for_qdp(value: Any, attr_annotation: Any) -> Any: return value +_PYTHON_SCALAR_TO_NUMPY_DTYPE = { + float: np.float64, + int: np.int64, + bool: np.bool_, +} + + +def _patch_trtp_scalar_attr_roundtrip() -> None: + """Patch ``tensorrt.plugin``'s scalar-attribute reconstruction. + + ``_TemplatePluginCreator.create_plugin`` rebuilds a Python scalar attr via + ``attr_type_annot(f.data)`` (e.g. ``float(f.data)``). The serialize path + stores Python scalars as ``np.array([value])`` (1-d size-1) and TRT's + C++ PluginField construction further promotes any input ndarray to 1-d, + so ``f.data`` is always 1-d on read. ``float(np.array([0.2]))`` raises + "only 0-dimensional arrays can be converted to Python scalars", which + surfaces as a converter failure on any plugin that declares scalar attrs + via ``@trtp.register`` (e.g. 
``b: float, a: int``). + + PluginField data is also immutable on the Python side, so we can't fix + the shape before the unpatched code reads it. Instead: route around the + broken ``attr_type_annot(f.data)`` line by temporarily promoting the + scalar annotation to ``npt.NDArray[X]`` (X = the matching numpy dtype) — that + branches into ``.astype()``, which handles 1-d arrays fine — then + restore the annotation and unwrap the resulting 1-d arrays back to the + Python scalar types the descriptor/impl expects. Applied once, no-op if + the upstream bindings are ever fixed. + """ + try: + from tensorrt_bindings.plugin import _lib as _trtp_lib + from tensorrt_bindings.plugin._utils import _is_numpy_array + except ImportError: + return + + creator_cls = getattr(_trtp_lib, "_TemplatePluginCreator", None) + if creator_cls is None or getattr(creator_cls, "_torch_trt_scalar_patched", False): + return + + orig_create_plugin = creator_cls.create_plugin + + def _patched_create_plugin( + self: Any, + name: str, + namespace: str, + fc: Any, + phase: Any, + qpcr: Any = None, + ) -> Any: + from tensorrt_bindings.plugin._lib import QDP_REGISTRY + + desc = QDP_REGISTRY.get(f"{namespace}::{name}") + if desc is None: + return orig_create_plugin(self, name, namespace, fc, phase, qpcr) + + scalar_attrs: dict[str, type] = {} + for f in fc: + ann = desc.input_attrs.get(f.name) + if ann is None or _is_numpy_array(ann): + continue + if not isinstance(ann, type): + continue + if ann in _PYTHON_SCALAR_TO_NUMPY_DTYPE: + scalar_attrs[f.name] = ann + + if not scalar_attrs: + return orig_create_plugin(self, name, namespace, fc, phase, qpcr) + + saved_annotations = {n: desc.input_attrs[n] for n in scalar_attrs} + for n, ann in scalar_attrs.items(): + # mypy reads ``npt.NDArray[X]`` as a static type form, but X here + # is a runtime value pulled from the dtype lookup table. 
+ desc.input_attrs[n] = npt.NDArray[_PYTHON_SCALAR_TO_NUMPY_DTYPE[ann]] # type: ignore[valid-type] + try: + plg = orig_create_plugin(self, name, namespace, fc, phase, qpcr) + finally: + for n, ann in saved_annotations.items(): + desc.input_attrs[n] = ann + + # Unwrap the 1-d size-1 ndarrays the promoted path produced back to + # the Python scalar types the descriptor's annotations declared, so + # the user's `@trtp.register` / `@trtp.impl` bodies receive what + # they signed up for. + for n, ann in scalar_attrs.items(): + value = plg.attrs.get(n) + if isinstance(value, np.ndarray) and value.size == 1: + plg.attrs[n] = ann(value.reshape(()).item()) + + return plg + + creator_cls.create_plugin = _patched_create_plugin + creator_cls._torch_trt_scalar_patched = True + + +_patch_trtp_scalar_attr_roundtrip() + + def _is_numpy_attr_annotation(annotation: Any) -> bool: return annotation is np.ndarray or typing.get_origin(annotation) is np.ndarray @@ -159,6 +256,29 @@ def custom_kernel_converter( f"Adding generated plugin for {namespace}::{name} to tensorrt network" ) layer.name = f"[{target}]-[{name}]" + + # The QDP plugin populates `aliased_map` (output_idx -> input_idx, with + # -1 meaning no alias) during `add_plugin` when TRT invokes the + # descriptor's `get_output_data_types`. Any non-negative entry means + # the engine build needs the aliased plugin I/O preview feature + # enabled. `plugin(*args)` itself is just a creation closure — the + # populated `aliased_map` lives on the layer's `plugin` attribute. + # JIT plugins: `layer.plugin` returns the Python `_TemplateJITPlugin` + # instance, whose `aliased_map` (output_idx -> input_idx, -1 means + # none) is populated when TRT invokes the descriptor during + # `add_plugin`. AOT plugins: `layer.plugin` returns a bare + # `trt.IPluginV3` C++ wrapper that doesn't expose the Python + # attribute. 
We read the JIT map when we can, and otherwise enable + # the aliased-I/O preview feature unconditionally for AOT — it's + # dormant in TRT when no plugin actually declares `.aliased()`, so + # this only adds the flag (which is needed when aliasing *is* + # declared) without changing semantics when it isn't. + layer_plugin = getattr(layer, "plugin", None) + aliased_map = getattr(layer_plugin, "aliased_map", None) + if aliased_map and any(v != -1 for v in aliased_map.values()): + ctx.requires_aliased_plugin_io = True + elif use_aot_plugin: + ctx.requires_aliased_plugin_io = True # Single-output ops expect a bare ITensor; multi-output ops expect a # tuple so the downstream ``getitem`` converter can unpack it. num_outputs = len(torch_schema.returns) diff --git a/py/torch_tensorrt/dynamo/lowering/passes/_aten_lowering_pass.py b/py/torch_tensorrt/dynamo/lowering/passes/_aten_lowering_pass.py index b25219bc82..dc7a955065 100644 --- a/py/torch_tensorrt/dynamo/lowering/passes/_aten_lowering_pass.py +++ b/py/torch_tensorrt/dynamo/lowering/passes/_aten_lowering_pass.py @@ -3,6 +3,7 @@ from typing import Any, Callable, Optional, Sequence, Union import torch + from torch_tensorrt._utils import is_tegra_platform from torch_tensorrt.dynamo._settings import CompilationSettings from torch_tensorrt.dynamo.lowering.passes._FakeTensorUpdater import FakeTensorUpdater @@ -23,6 +24,7 @@ from .replace_fused_rms_norm import replace_fused_rms_norm from .replace_max_pool_with_indices import replace_max_pool_with_indices from .rule_based_autocast import rule_based_autocast +from .unfunctionalize_qdp_inplace import unfunctionalize_qdp_inplace pre_lowering_pass_list = [ remove_detach, @@ -31,6 +33,10 @@ ] post_lowering_pass_list = [ + # Must run before remove_num_users_is_0_nodes and any pass that walks the + # converter registry, so the underlying mutating op is restored ahead of + # partitioning and the QDP `.aliased()` descriptor can take effect. 
+ unfunctionalize_qdp_inplace, replace_fused_rms_norm, remove_input_alias_fixing_clones, constant_fold, diff --git a/py/torch_tensorrt/dynamo/lowering/passes/unfunctionalize_qdp_inplace.py b/py/torch_tensorrt/dynamo/lowering/passes/unfunctionalize_qdp_inplace.py new file mode 100644 index 0000000000..6c206391c0 --- /dev/null +++ b/py/torch_tensorrt/dynamo/lowering/passes/unfunctionalize_qdp_inplace.py @@ -0,0 +1,153 @@ +"""Reverse ``run_decompositions``' functionalization of mutating custom ops +that have a registered Dynamo converter (QDP in-place plugins). + +``run_decompositions`` rewrites every call ``my_inplace_op(x, ...)`` (declared +with ``mutates_args=("x",)``) into:: + + %af = auto_functionalized_v2(my_inplace_op, _x_base_index=N, _all_bases=[%x], ...) + %getitem_0 = af[0] # the op's actual return + %getitem_k = af[k] # for k in 1..len(_all_bases): the post-mutation base + %copy_ = aten.copy_.default(%x, %getitem_k) # propagate the mutation back + +The result is correct for PyTorch eager, but our converter is registered against +the original mutating overload and the partitioner sees the HOP wrapper as +unsupported — so the whole subgraph is bailed out. This pass restores the +direct mutating call when the underlying op has a converter, lets the QDP +``.aliased()`` descriptor declared in `_generate_plugin.py` actually reach the +engine builder, and drops the now-redundant copy_/getitem nodes. +""" + +import logging +from typing import Any, Dict, List + +import torch + +from torch_tensorrt.dynamo._settings import CompilationSettings +from torch_tensorrt.dynamo.conversion._ConverterRegistry import DYNAMO_CONVERTERS + +logger = logging.getLogger(__name__) + + +def _auto_functionalized_targets() -> List[Any]: + # Both HOPs exist across torch versions; v2 is what current export emits. 
+ targets: List[Any] = [] + higher_order = getattr(torch.ops, "higher_order", None) + if higher_order is None: + return targets + for name in ("auto_functionalized_v2", "auto_functionalized"): + op = getattr(higher_order, name, None) + if op is not None: + targets.append(op) + return targets + + +def _reconstruct_op_args(op_overload: Any, node_kwargs: Dict[str, Any]) -> List[Any]: + """Rebuild positional args for a direct call to ``op_overload`` from the + HOP kwargs. The QDP-generated converter reads tensor inputs positionally + (``args[0 : len(tensor_inputs)]``), so we must place args in schema order + rather than passing them by name. + + ``auto_functionalized_v2`` packs mutated tensor arguments via + ``_{arg_name}_base_index: N`` plus ``_all_bases: [t0, t1, ...]`` instead of + inlining them. Non-mutated args are passed by name as-is. + """ + bases = node_kwargs.get("_all_bases", []) + out: List[Any] = [] + schema = op_overload._schema + for arg in schema.arguments: + base_key = f"_{arg.name}_base_index" + if base_key in node_kwargs: + out.append(bases[node_kwargs[base_key]]) + elif arg.name in node_kwargs: + out.append(node_kwargs[arg.name]) + elif arg.has_default_value(): + out.append(arg.default_value) + else: + raise RuntimeError( + f"auto_functionalized_v2 missing argument '{arg.name}' for" + f" {op_overload} (no value and no default)" + ) + return out + + +def unfunctionalize_qdp_inplace( + gm: torch.fx.GraphModule, settings: CompilationSettings +) -> torch.fx.GraphModule: + af_targets = _auto_functionalized_targets() + if not af_targets: + return gm + + modified = False + for node in list(gm.graph.nodes): + if node.op != "call_function" or node.target not in af_targets: + continue + if not node.args or not hasattr(node.args[0], "_schema"): + continue + op_overload = node.args[0] + if op_overload not in DYNAMO_CONVERTERS: + # No converter for the underlying op — leave it functionalized. 
+ continue + + op_args = _reconstruct_op_args(op_overload, dict(node.kwargs)) + bases = node.kwargs.get("_all_bases", []) + + with gm.graph.inserting_before(node): + new_call = gm.graph.call_function(op_overload, args=tuple(op_args)) + # Propagate the op's own meta["val"] (first element of the HOP's + # tuple) so downstream shape extraction has what it needs. + hop_val = node.meta.get("val") + if isinstance(hop_val, tuple) and len(hop_val) >= 1: + new_call.meta["val"] = hop_val[0] + if "tensor_meta" in node.meta: + new_call.meta["tensor_meta"] = node.meta["tensor_meta"] + + # Rewrite getitem users of the HOP. Both tuple slots — index 0 (the + # op's actual return) and indices >= 1 (post-mutation bases) — get + # routed to the direct op call. Routing index-k users to the original + # base placeholder would point downstream nodes at the pre-mutation + # FX value (and would also strand `new_call` as dead code), so we + # anchor everything to the new call instead. + getitem_users = [u for u in list(node.users) if u.target is _operator_getitem()] + for user in getitem_users: + user.replace_all_uses_with(new_call) + gm.graph.erase_node(user) + + # If the HOP has any remaining users (unusual — would mean someone + # consumed the tuple directly), bail rather than leave a dangling + # reference to the now-stale HOP node. + if list(node.users): + raise RuntimeError( + f"auto_functionalized_v2 node {node.name} has non-getitem users" + f" {list(node.users)}; cannot un-functionalize safely." + ) + gm.graph.erase_node(node) + modified = True + + if not modified: + return gm + + # `copy_(base, op_return)` was synthesized by functionalization to write + # the mutation back through the base placeholder. With the direct mutating + # call restored, the buffer is already mutated by the kernel — keeping the + # copy_ would leave an unsupported mutating op blocking partitioning. Drop + # it and route its users straight to the op return. 
+ for node in list(gm.graph.nodes): + if ( + node.op == "call_function" + and node.target is torch.ops.aten.copy_.default + and len(node.args) >= 2 + and node.args[0].op == "placeholder" + ): + node.replace_all_uses_with(node.args[1]) + gm.graph.erase_node(node) + + gm.graph.lint() + gm.recompile() + logger.debug(f"Un-functionalized QDP in-place ops:\n{gm.graph}") + return gm + + +def _operator_getitem() -> Any: + import operator + + return operator.getitem diff --git a/py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py b/py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py index 3c454933bb..98e9ce36d3 100644 --- a/py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py +++ b/py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py @@ -7,8 +7,9 @@ import torch import torch.distributed as dist -import torch_tensorrt from torch.nn import Module + +import torch_tensorrt from torch_tensorrt._Device import Device from torch_tensorrt._enums import Platform, dtype from torch_tensorrt._features import ENABLED_FEATURES @@ -436,6 +437,24 @@ def setup_engine(self) -> None: for input_name in self.input_names } + # Engines containing QDP plugins with aliased I/O (declared via + # ``TensorDesc.aliased()`` and built with the + # ``ALIASED_PLUGIN_IO_10_03`` preview feature) expect the output + # binding for an aliased output and its corresponding input binding + # to point at the same buffer at runtime. ``create_output_tensors`` + # otherwise allocates a fresh output buffer, which both breaks the + # in-place semantics (user's input is not mutated) and produces + # unspecified output values. Query the engine once at setup so the + # forward pass can reuse the user's input tensor for any aliased + # output. ``get_aliased_input_tensor`` returns the input tensor's + # name when aliased, ``None`` otherwise. 
+ self.aliased_output_to_input = {} + if hasattr(self.engine, "get_aliased_input_tensor"): + for output_name in self.output_names: + aliased_input_name = self.engine.get_aliased_input_tensor(output_name) + if aliased_input_name: + self.aliased_output_to_input[output_name] = aliased_input_name + def _setup_runtime_config(self) -> None: """Create a RuntimeConfig with runtime cache for TensorRT-RTX. @@ -727,6 +746,22 @@ def run_standard_execution() -> torch.Tensor | Tuple[torch.Tensor, ...]: ) outputs = self.create_output_tensors() + # Aliased outputs (QDP in-place plugins) must share the + # input's storage rather than the freshly-allocated + # buffer ``create_output_tensors`` produced. Swap the + # corresponding entry in ``outputs`` to point at the + # user's input tensor so the in-place mutation surfaces + # back to the caller and so ``set_tensor_address`` below + # binds the same buffer to both bindings. + if self.aliased_output_to_input: + input_by_name = dict(zip(self.input_names, contiguous_inputs)) + for o, output_name in enumerate(self.output_names): + aliased_input_name = self.aliased_output_to_input.get( + output_name + ) + if aliased_input_name is not None: + outputs[o] = input_by_name[aliased_input_name] + for o, output_name in enumerate(self.output_names): if need_cudagraphs_record: self._output_buffers[o] = outputs[o].clone() diff --git a/tests/py/dynamo/automatic_plugin/test_automatic_plugin_inplace.py b/tests/py/dynamo/automatic_plugin/test_automatic_plugin_inplace.py new file mode 100644 index 0000000000..03dfa1f795 --- /dev/null +++ b/tests/py/dynamo/automatic_plugin/test_automatic_plugin_inplace.py @@ -0,0 +1,106 @@ +import platform +import unittest + +import torch +import torch.nn as nn +from parameterized import parameterized +from torch.testing._internal.common_utils import run_tests + +import torch_tensorrt + + +@torch.library.custom_op("torchtrt_ex::add_one_inplace", mutates_args=("X",)) # type: ignore[misc] +def add_one_inplace(X: 
torch.Tensor) -> torch.Tensor: + assert X.is_cuda + X.add_(1) + # torch.library forbids returning an input directly; clone to satisfy the + # no-alias constraint while still letting the registered fake (which + # returns X by identity) signal aliasing for the TRT plugin descriptor. + return X.clone() + + +@torch.library.register_fake("torchtrt_ex::add_one_inplace") +def _(X: torch.Tensor) -> torch.Tensor: + return X + + +if torch_tensorrt.ENABLED_FEATURES.qdp_plugin: + torch_tensorrt.dynamo.conversion.plugins.custom_op( + "torchtrt_ex::add_one_inplace", supports_dynamic_shapes=False + ) + + +@unittest.skipIf( + platform.system() == "Windows", + "QDP in-place test requires Linux", +) +@unittest.skipIf( + not torch_tensorrt.ENABLED_FEATURES.qdp_plugin, + "QDP Plugin is not available", +) +class TestInplacePlugin(unittest.TestCase): + """The standard DispatchTestCase.run_test passes the same input tensor to + eager and TRT, which double-applies the mutation for in-place ops. Use a + bespoke flow with cloned inputs and verify both that the output matches the + expected post-mutation value AND that the input buffer was mutated in + place.""" + + @parameterized.expand( + [ + ((64, 64), torch.float), + ((128, 32), torch.float), + ] + ) + def test_add_one_inplace(self, input_shape, dtype): + class Model(nn.Module): + def forward(self, x): + return torch.ops.torchtrt_ex.add_one_inplace.default(x) + + base = torch.randn(input_shape, device="cuda", dtype=dtype) + + eager_input = base.clone() + eager_out = Model()(eager_input) + expected_post = base + 1 + torch.testing.assert_close(eager_input, expected_post, rtol=1e-5, atol=1e-5) + torch.testing.assert_close(eager_out, expected_post, rtol=1e-5, atol=1e-5) + + trt_input = base.clone() + compiled = torch_tensorrt.compile( + Model(), + inputs=[trt_input.clone()], + ir="dynamo", + min_block_size=1, + immutable_weights=True, + ) + + # Guard against regressing to PyTorch fallback: if the un-functionalize + # pass stops restoring 
the mutating op, the partitioner finds 0 + # supported ops and the "compiled" module is just PyTorch eager — the + # test would still pass on value because the eager op mutates the + # input itself, masking the real failure. + from torch_tensorrt.dynamo.runtime import ( + PythonTorchTensorRTModule, + TorchTensorRTModule, + ) + + engine_submodules = [ + m + for _, m in compiled.named_modules() + if isinstance(m, (PythonTorchTensorRTModule, TorchTensorRTModule)) + ] + self.assertGreaterEqual( + len(engine_submodules), + 1, + f"Expected at least one TRT engine submodule, got graph:\n{compiled.graph}", + ) + + trt_out = compiled(trt_input) + torch.testing.assert_close(trt_out, expected_post, rtol=1e-5, atol=1e-5) + # The whole point of aliased plugin I/O is that the input buffer is + # mutated in place by the engine. If the engine had allocated a fresh + # output buffer, `trt_input` would still hold the pre-call values. + torch.testing.assert_close(trt_input, expected_post, rtol=1e-5, atol=1e-5) + + +if __name__ == "__main__": + run_tests() From 526e5558f6edf004554fcd00d0d80e7e064df740 Mon Sep 17 00:00:00 2001 From: Bo Wang Date: Wed, 20 May 2026 14:17:15 +0000 Subject: [PATCH 2/2] update --- examples/dynamo/aot_plugin.py | 134 +++------------ .../dynamo/conversion/plugins/__init__.py | 1 + .../dynamo/conversion/plugins/_aot_utils.py | 37 +++++ .../conversion/plugins/_generate_plugin.py | 57 ++++--- .../plugins/_generate_plugin_converter.py | 71 +++----- .../passes/unfunctionalize_qdp_inplace.py | 157 ++++++++++-------- .../runtime/_PythonTorchTensorRTModule.py | 46 ++--- .../test_automatic_plugin_inplace.py | 25 +-- .../test_automatic_plugin_inplace_consumed.py | 87 ++++++++++ .../test_automatic_plugin_inplace_dynamic.py | 90 ++++++++++ .../test_automatic_plugin_inplace_multi.py | 106 ++++++++++++ 11 files changed, 515 insertions(+), 296 deletions(-) create mode 100644 py/torch_tensorrt/dynamo/conversion/plugins/_aot_utils.py create mode 100644 
tests/py/dynamo/automatic_plugin/test_automatic_plugin_inplace_consumed.py create mode 100644 tests/py/dynamo/automatic_plugin/test_automatic_plugin_inplace_dynamic.py create mode 100644 tests/py/dynamo/automatic_plugin/test_automatic_plugin_inplace_multi.py diff --git a/examples/dynamo/aot_plugin.py b/examples/dynamo/aot_plugin.py index 0aabd71bda..810d6f4d98 100644 --- a/examples/dynamo/aot_plugin.py +++ b/examples/dynamo/aot_plugin.py @@ -53,13 +53,8 @@ @triton.jit def add_one_kernel(x_ptr, y_ptr, n_elements, BLOCK_SIZE: tl.constexpr): - # Arg order matters for the AOT path: TRT launches the embedded PTX with - # arguments in (input_ptrs, output_ptrs, extra_args) order — inputs first, - # then outputs, then anything from ``extra_args`` in ``@trtp.aot_impl``. - # If this kernel declared ``(x_ptr, n_elements, y_ptr, ...)`` then TRT - # would feed the output pointer into ``n_elements`` and ``n_elements`` - # into ``y_ptr`` at launch, which is a wild pointer dereference (engine - # builds fine, ``enqueueV3`` returns -1 and the process segfaults). + # AOT path requires (inputs, outputs, extra_args) order — swapping any + # two slots feeds the wrong value into the kernel and segfaults. pid = tl.program_id(0) block_start = pid * BLOCK_SIZE offsets = block_start + tl.arange(0, BLOCK_SIZE) @@ -71,13 +66,8 @@ def add_one_kernel(x_ptr, y_ptr, n_elements, BLOCK_SIZE: tl.constexpr): @triton.jit def add_one_inplace_kernel(x_ptr, n_elements, BLOCK_SIZE: tl.constexpr): - # Distinct kernel for the aliased-I/O variant. The plugin descriptor - # declares its output as ``X.aliased()`` — at runtime TRT passes a - # *single* pointer for the aliased I/O pair (the shared buffer), not two. - # If we re-used ``add_one_kernel`` here, TRT would supply: pointer, - # n_elements, padding... and the kernel's ``y_ptr`` slot would absorb - # ``n_elements`` while ``n_elements`` would read the padding zero — the - # mask would be all-false and the kernel would do nothing. 
+ # Aliased-I/O variant: TRT routes a single pointer for the shared + # input/output buffer, so the kernel takes one pointer, not two. pid = tl.program_id(0) block_start = pid * BLOCK_SIZE offsets = block_start + tl.arange(0, BLOCK_SIZE) @@ -191,29 +181,13 @@ def add_plugin_aot_impl( launch_params.block_x = compiled_kernel.metadata.num_warps * 32 # threads per block launch_params.shared_mem = compiled_kernel.metadata.shared # bytes of shared mem - # ``extra_args`` are scalar arguments appended to the kernel's argument list at - # launch. ``n_elements`` is passed as a 32-bit symbolic integer so TRT - # evaluates it from the actual tensor size at runtime. - # - # Triton >= 3.x always emits two trailing ``.param .u64 .ptr`` slots in - # the compiled PTX for ``global_scratch`` and ``profile_scratch`` — even - # when their sizes (``compiled_kernel.metadata.global_scratch_size`` / - # ``profile_scratch_size``) are 0. Triton's own runtime allocates - # zero-sized scratch buffers and passes those pointers at launch; TRT's - # AOT plugin path doesn't know about them and would otherwise leave the - # two trailing slots filled with stale register state — symptom: - # ``Failed to enqueue status -1`` and a segfault on the first call. - # We pad ``extra_args`` with four ``SymInt32(0)`` (two per u64 slot) so - # the kernel sees null pointers for both scratch params; since their - # sizes are 0 the kernel never dereferences them. 
- extra_args = trtp.SymIntExprs(1 + 4) - extra_args[0] = trtp.SymInt32(N) - for _i in range(1, 5): - extra_args[_i] = trtp.SymInt32(0) + extra_args = torch_tensorrt.dynamo.conversion.plugins.make_aot_extra_args( + [trtp.SymInt32(N)], compiled_kernel=compiled_kernel + ) return ( - compiled_kernel.metadata.name, # kernel function name in PTX - compiled_kernel.asm["ptx"], # PTX source — embedded in TRT engine + compiled_kernel.metadata.name, + compiled_kernel.asm["ptx"], launch_params, extra_args, ) @@ -244,35 +218,15 @@ def add_plugin_aot_impl( # In-place variant: aliased plugin I/O # ----------------------------------------- # -# This second registration shows the same kernel exposed as an *in-place* plugin — -# the engine mutates the input buffer directly instead of allocating a separate -# output. Useful for KV-cache updates and any pattern where only a small slice of -# a large state changes per call. -# -# Three things change vs. ``my::add_one`` above: -# -# 1. ``mutates_args=("X",)`` on the torch op. This is the load-bearing signal — -# it tells the QDP descriptor in ``_generate_plugin.py`` that input ``X`` is -# a candidate for aliasing, and it also tells PyTorch's autograd and -# functionalization machinery that the op has side effects on ``X``. -# -# 2. The registered fake returns ``X`` by identity (``return X``). Combined with -# the schema's ``mutates_args``, this is what makes the descriptor emit -# ``X.aliased()`` (see ``_generate_plugin._generic_plugin_desc``) instead of -# building a fresh output ``TensorDesc``. +# Same kernel exposed as an in-place plugin: the engine mutates the input +# buffer directly via QDP ``TensorDesc.aliased()`` instead of allocating a +# separate output. Useful for KV-cache updates and similar patterns. # -# 3. The descriptor itself uses ``X.aliased()``. ``aliased()`` returns a -# ``TensorDesc`` that shares its data buffer with ``X`` — TRT will route the -# same pointer to both the input and output binding at runtime. 
-# -# The eager torch impl has to mutate ``X`` itself (so the semantics match what -# the engine will do) and return ``X.clone()``. ``torch.library`` forbids -# returning an input by identity from a custom op, hence the clone. -# -# Note on the AOT kernel: we re-use ``add_one_kernel`` unchanged. Its signature -# takes two pointers (``x_ptr``, ``y_ptr``). With aliased I/O declared, TRT -# passes the same buffer for both — the kernel reads from ``x_ptr`` and writes -# to ``y_ptr``, which is the same memory, so the effect is in-place. +# Two signals together declare aliasing to the framework: +# * ``mutates_args=("X",)`` on the torch op +# * the registered fake returns ``X`` by identity +# The eager impl must mutate ``X`` itself and return a clone — ``torch.library`` +# rejects returning an input by identity. @torch.library.custom_op("my::add_one_inplace", mutates_args=("X",)) # type: ignore[misc] @@ -281,26 +235,16 @@ def add_one_inplace(X: torch.Tensor) -> torch.Tensor: BLOCK_SIZE = 256 grid = lambda meta: (triton.cdiv(X.numel(), meta["BLOCK_SIZE"]),) add_one_inplace_kernel[grid](X, X.numel(), BLOCK_SIZE=BLOCK_SIZE) - # Must not return X by identity — torch.library's no-alias constraint - # rejects that. The TRT path doesn't observe this clone (aliasing is - # declared at the descriptor level), it's purely for the eager impl. return X.clone() @torch.library.register_fake("my::add_one_inplace") def _(X: torch.Tensor) -> torch.Tensor: - # Identity return is the secondary signal the descriptor uses to detect - # aliasing. Combined with ``mutates_args=("X",)`` above, this is what - # makes ``_generic_plugin_desc`` emit ``X.aliased()`` for the output. return X @trtp.register("my::add_one_inplace") def add_one_inplace_desc(X: trtp.TensorDesc) -> Tuple[trtp.TensorDesc]: - # ``aliased()`` is the QDP API that declares output-shares-storage-with-input. 
- # Engine build will fail with "PreviewFeature::kALIASED_PLUGIN_IO_10_03 not - # enabled" unless the build config enables that preview feature; the - # converter wires this on for you when it sees a non-empty aliased_map. return X.aliased() @@ -313,7 +257,6 @@ def add_one_inplace_aot_impl( type_str = "fp32" if X.dtype == trt.float32 else "fp16" block_size = 256 - # Single-pointer signature — see the comment on ``add_one_inplace_kernel``. src = triton.compiler.ASTSource( fn=add_one_inplace_kernel, signature={ @@ -332,14 +275,9 @@ def add_one_inplace_aot_impl( launch_params.block_x = compiled_kernel.metadata.num_warps * 32 launch_params.shared_mem = compiled_kernel.metadata.shared - # See the matching note on the non-in-place ``add_plugin_aot_impl``: - # Triton 3.x emits two trailing ``.param .u64 .ptr`` slots for the - # global/profile scratch buffers, and TRT's AOT path needs them zeroed - # explicitly via ``extra_args`` so the kernel doesn't read stale state. - extra_args = trtp.SymIntExprs(1 + 4) - extra_args[0] = trtp.SymInt32(N) - for _i in range(1, 5): - extra_args[_i] = trtp.SymInt32(0) + extra_args = torch_tensorrt.dynamo.conversion.plugins.make_aot_extra_args( + [trtp.SymInt32(N)], compiled_kernel=compiled_kernel + ) return ( compiled_kernel.metadata.name, @@ -376,10 +314,6 @@ def forward(self, X: torch.Tensor) -> torch.Tensor: class MyInplaceModel(torch.nn.Module): - """Drives the in-place plugin. 
The op mutates ``X`` in place; the returned - tensor carries the post-mutation value (a clone, only to satisfy - torch.library's no-alias rule).""" - def forward(self, X: torch.Tensor) -> torch.Tensor: return torch.ops.my.add_one_inplace.default(X) @@ -416,18 +350,8 @@ def forward(self, X: torch.Tensor) -> torch.Tensor: # In-place plugin demo # --------------------- # - # The standard "compile once, run many times" comparison pattern doesn't - # work for an in-place op because each call mutates the input — running - # eager and TRT on the same buffer double-applies the mutation. We work - # off a base tensor and clone for each call instead. - # - # Three things to verify, beyond "it ran": - # 1. The compiled module contains a TRT engine (not a PyTorch fallback — - # a regression here would silently pass on value because the eager - # kernel mutates the input the same way). - # 2. The engine's return value matches the expected post-mutation tensor. - # 3. The user's input buffer was mutated in place by the engine — the - # actual reason to use aliased plugin I/O in the first place. + # In-place ops mutate their input, so eager and TRT must run on separate + # cloned buffers; otherwise each comparison double-applies the mutation. print("\nIn-place plugin demo:") inplace_model = MyInplaceModel().to("cuda").eval() @@ -452,10 +376,8 @@ def forward(self, X: torch.Tensor) -> torch.Tensor: if isinstance(m, (PythonTorchTensorRTModule, TorchTensorRTModule)) ] assert engine_submodules, ( - "Expected a TRT engine submodule for the in-place plugin path, but the" - " compiled module is pure PyTorch — check that the un-functionalize" - " lowering pass restored the mutating op before partitioning. Graph:\n" - f"{model_trt_inplace.graph}" + "Expected a TRT engine submodule for the in-place plugin path; got a" + f" pure-PyTorch fallback. 
Graph:\n{model_trt_inplace.graph}" ) print(f" TRT engine submodule(s) present: {len(engine_submodules)}") @@ -463,11 +385,9 @@ def forward(self, X: torch.Tensor) -> torch.Tensor: trt_input = base.clone() trt_out = model_trt_inplace(trt_input) assert torch.allclose(trt_out, expected_post), "TRT output mismatch" - assert torch.allclose(trt_input, expected_post), ( - "Engine did not mutate the input buffer — aliased plugin I/O is" - " not active. Check that PreviewFeature.ALIASED_PLUGIN_IO_10_03" - " was enabled and that the descriptor emitted X.aliased()." - ) + assert torch.allclose( + trt_input, expected_post + ), "Engine did not mutate the input buffer — aliased plugin I/O is not active." print(" Output matches expected post-mutation value.") print(" Input buffer was mutated in place by the TRT engine.") diff --git a/py/torch_tensorrt/dynamo/conversion/plugins/__init__.py b/py/torch_tensorrt/dynamo/conversion/plugins/__init__.py index fc5e973560..e71b3e0044 100644 --- a/py/torch_tensorrt/dynamo/conversion/plugins/__init__.py +++ b/py/torch_tensorrt/dynamo/conversion/plugins/__init__.py @@ -1,3 +1,4 @@ +from torch_tensorrt.dynamo.conversion.plugins._aot_utils import make_aot_extra_args from torch_tensorrt.dynamo.conversion.plugins._custom_op import custom_op from torch_tensorrt.dynamo.conversion.plugins._generate_plugin import generate_plugin from torch_tensorrt.dynamo.conversion.plugins._generate_plugin_converter import ( diff --git a/py/torch_tensorrt/dynamo/conversion/plugins/_aot_utils.py b/py/torch_tensorrt/dynamo/conversion/plugins/_aot_utils.py new file mode 100644 index 0000000000..03ec4fd9d1 --- /dev/null +++ b/py/torch_tensorrt/dynamo/conversion/plugins/_aot_utils.py @@ -0,0 +1,37 @@ +"""Helpers for writing AOT QDP plugins backed by Triton kernels.""" + +from typing import Any, Sequence + + +def _has_triton_scratch_params(compiled_kernel: Any) -> bool: + md = getattr(compiled_kernel, "metadata", None) + if md is None: + return False + return hasattr(md, 
"global_scratch_size") and hasattr(md, "profile_scratch_size") + + +def make_aot_extra_args( + user_args: Sequence[Any], + *, + compiled_kernel: Any = None, +) -> Any: + """Build a ``trtp.SymIntExprs`` for an AOT plugin's ``extra_args`` return. + + When ``compiled_kernel`` is a Triton-compiled kernel, four trailing + ``SymInt32(0)`` are appended to cover the two ``.param .u64 .ptr`` slots + (``global_scratch``, ``profile_scratch``) that Triton >= 3.x always emits + in PTX even when their sizes are zero. TRT's AOT plugin path does not + plumb those slots through, so without padding ``enqueueV3`` reads stale + register state for them and segfaults on the first call. + """ + import tensorrt.plugin as trtp + + pad = 4 if _has_triton_scratch_params(compiled_kernel) else 0 + total = len(user_args) + pad + out = trtp.SymIntExprs(total) + for i, arg in enumerate(user_args): + out[i] = arg + zero = trtp.SymInt32(0) + for i in range(len(user_args), total): + out[i] = zero + return out diff --git a/py/torch_tensorrt/dynamo/conversion/plugins/_generate_plugin.py b/py/torch_tensorrt/dynamo/conversion/plugins/_generate_plugin.py index 4f82954dd7..9475eab86e 100644 --- a/py/torch_tensorrt/dynamo/conversion/plugins/_generate_plugin.py +++ b/py/torch_tensorrt/dynamo/conversion/plugins/_generate_plugin.py @@ -71,31 +71,34 @@ def _generate_plugin(plugin_name: str) -> None: # retrieve the corresponding torch operation using the passed in string torch_op = getattr(getattr(torch.ops, namespace), name) + default_schema = torch_op.default._schema # Positional indices of tensor inputs marked as mutated in the schema - # (Tensor(aN!) ... -> ...). These are candidates for in-place QDP outputs - # via `TensorDesc.aliased()` — the actual alias map (output_idx -> input_idx) - # is determined at fake-run time by checking output-to-input identity, since - # torch.library does not stamp aliasing onto the return alias_info. - _schema = torch_op._schemas[""] + # (Tensor(aN!) ... -> ...) 
— candidates for in-place QDP outputs via + # `TensorDesc.aliased()`. The actual alias map (output_idx -> input_idx) + # is decided at fake-run time by checking output-to-input identity. _tensor_arg_positions = [ i - for i, a in enumerate(_schema.arguments) + for i, a in enumerate(default_schema.arguments) if a.type.isSubtypeOf(torch._C.TensorType.get()) ] _mutated_tensor_arg_positions = { _tensor_arg_positions.index(i) - for i, a in enumerate(_schema.arguments) + for i, a in enumerate(default_schema.arguments) if a.type.isSubtypeOf(torch._C.TensorType.get()) and a.alias_info is not None and a.alias_info.is_write } + # Cached frozenset of aliased output indices, populated by the descriptor + # on first build-time call so the runtime impl skips the per-call probe. + _aliased_indices_cache: list[Any] = [None] + # helper function that generates the required signature based on the torch operation def generate_signature( torch_op: Callable[[Any], Any], ) -> Tuple[str, str, str, dict[str, Any], dict[str, Any]]: - schema = torch_op._schemas[""] + schema = torch_op.default._schema arg_list = [] @@ -202,21 +205,19 @@ def _generic_plugin_desc(*args: Any, **kwargs: Any) -> Tuple[trtp.TensorDesc, .. output = torch_op(*fake_args, *non_tensor_args, **torch_kwargs) - # Normalize to a list of fake outputs. Multi-output torch ops return - # a tuple; single-output ops return a bare Tensor. outputs_list = list(output) if isinstance(output, (tuple, list)) else [output] - # Build output_idx -> input_idx alias map. An output is aliased to a - # mutated input when both (a) the input is declared mutates_args in the - # schema and (b) the fake kernel returned that exact tensor by identity. - # The schema gate prevents accidental aliasing when a fake kernel - # incidentally returns an input (e.g. an identity op without mutation). 
+ # output_idx -> input_idx alias map: an output aliases a mutated input + # iff the schema marks the input as mutated AND the fake returns that + # tensor by identity. The schema gate prevents accidental aliasing on + # incidental identity returns from non-mutating ops. alias_map: dict[int, int] = {} for out_idx, fake_out in enumerate(outputs_list): for in_idx in _mutated_tensor_arg_positions: if in_idx < len(fake_args) and fake_out is fake_args[in_idx]: alias_map[out_idx] = in_idx break + _aliased_indices_cache[0] = frozenset(alias_map.keys()) input_node_expr = list( itertools.chain.from_iterable( @@ -239,10 +240,14 @@ def _generic_plugin_desc(*args: Any, **kwargs: Any) -> Tuple[trtp.TensorDesc, .. out_descs = [] for out_idx, fake_out in enumerate(outputs_list): - # In-place output: tell TRT this output shares its buffer with the - # mutated input. `aliased()` carries the input's shape/dtype so we - # don't rebuild the descriptor from the symbolic expressions. if out_idx in alias_map: + # Aliased output shares its buffer with a mutated input; + # `aliased()` carries the input's shape/dtype. + # Limitation: TRT preview-feature ALIASED_PLUGIN_IO_10_03 + # inserts a defensive copy that breaks aliasing when a + # multi-output plugin's aliased output is consumed by + # another TRT layer in the same engine. Single-output + # plugins (the common KV-cache pattern) work end-to-end. out_descs.append(tensor_args[alias_map[out_idx]].aliased()) continue @@ -312,13 +317,15 @@ def _generic_plugin_impl( dest_tensors = [torch.as_tensor(o, device="cuda") for o in outputs] - # Outputs declared `.aliased()` share storage with their input — the - # mutation already lands in the destination buffer, so copying onto it - # would either be a no-op self-copy or, worse, clobber the in-place - # result if the eager kernel returned a fresh tensor. 
- aliased_outputs = { - i for i, o in enumerate(outputs) if o.get_aliased() is not None - } + # Skip copy_ for aliased outputs: storage is shared with the input + # the eager op already mutated, so copying would either no-op against + # itself or clobber the in-place result when the eager op returns a + # fresh tensor (as torch.library forces it to do). + aliased_outputs = _aliased_indices_cache[0] + if aliased_outputs is None: + aliased_outputs = frozenset( + i for i, o in enumerate(outputs) if o.get_aliased() is not None + ) stream = torch.cuda.ExternalStream(stream) with torch.cuda.stream(stream): diff --git a/py/torch_tensorrt/dynamo/conversion/plugins/_generate_plugin_converter.py b/py/torch_tensorrt/dynamo/conversion/plugins/_generate_plugin_converter.py index dfec5ca2b0..d43296a267 100644 --- a/py/torch_tensorrt/dynamo/conversion/plugins/_generate_plugin_converter.py +++ b/py/torch_tensorrt/dynamo/conversion/plugins/_generate_plugin_converter.py @@ -53,25 +53,12 @@ def _coerce_plugin_attr_for_qdp(value: Any, attr_annotation: Any) -> Any: def _patch_trtp_scalar_attr_roundtrip() -> None: - """Patch ``tensorrt.plugin``'s scalar-attribute reconstruction. - - ``_TemplatePluginCreator.create_plugin`` rebuilds a Python scalar attr via - ``attr_type_annot(f.data)`` (e.g. ``float(f.data)``). The serialize path - stores Python scalars as ``np.array([value])`` (1-d size-1) and TRT's - C++ PluginField construction further promotes any input ndarray to 1-d, - so ``f.data`` is always 1-d on read. ``float(np.array([0.2]))`` raises - "only 0-dimensional arrays can be converted to Python scalars", which - surfaces as a converter failure on any plugin that declares scalar attrs - via ``@trtp.register`` (e.g. ``b: float, a: int``). - - PluginField data is also immutable on the Python side, so we can't fix - the shape before the unpatched code reads it. 
Instead: route around the - broken ``attr_type_annot(f.data)`` line by temporarily promoting the - scalar annotation to ``npt.NDArray[]`` — that - branches into ``.astype()``, which handles 1-d arrays fine — then - restore the annotation and unwrap the resulting 1-d arrays back to the - Python scalar types the descriptor/impl expects. Applied once, no-op if - the upstream bindings are ever fixed. + """Work around ``_TemplatePluginCreator.create_plugin`` calling + ``float(np.array([v]))`` on scalar plugin attrs and crashing because + ``f.data`` is always 1-d after the C++ PluginField round-trip. We + temporarily promote the annotation to ``npt.NDArray[dtype]``, run the + upstream path, then unwrap back to the declared Python scalar type. + Idempotent; no-op once upstream ships a fix. """ try: from tensorrt_bindings.plugin import _lib as _trtp_lib @@ -114,8 +101,6 @@ def _patched_create_plugin( saved_annotations = {n: desc.input_attrs[n] for n in scalar_attrs} for n, ann in scalar_attrs.items(): - # mypy reads ``npt.NDArray[X]`` as a static type form, but X here - # is a runtime value pulled from the dtype lookup table. desc.input_attrs[n] = npt.NDArray[_PYTHON_SCALAR_TO_NUMPY_DTYPE[ann]] # type: ignore[valid-type] try: plg = orig_create_plugin(self, name, namespace, fc, phase, qpcr) @@ -123,10 +108,6 @@ def _patched_create_plugin( for n, ann in saved_annotations.items(): desc.input_attrs[n] = ann - # Unwrap the 1-d size-1 ndarrays the promoted path produced back to - # the Python scalar types the descriptor's annotations declared, so - # the user's `@trtp.register` / `@trtp.impl` bodies receive what - # they signed up for. 
for n, ann in scalar_attrs.items(): value = plg.attrs.get(n) if isinstance(value, np.ndarray) and value.size == 1: @@ -138,9 +119,6 @@ def _patched_create_plugin( creator_cls._torch_trt_scalar_patched = True -_patch_trtp_scalar_attr_roundtrip() - - def _is_numpy_attr_annotation(annotation: Any) -> bool: return annotation is np.ndarray or typing.get_origin(annotation) is np.ndarray @@ -188,6 +166,8 @@ def _generate_plugin_converter( ) from tensorrt.plugin._lib import QDP_REGISTRY + _patch_trtp_scalar_attr_roundtrip() + torch_target = getattr(getattr(torch.ops, namespace), op_name) overload_str = overload if overload else "" overload_name = overload_str if overload else "default" @@ -196,7 +176,14 @@ def _generate_plugin_converter( f"Could not find a tensorrt plugin registered for op {namespace}::{op_name}," " unable to generate converter" ) - torch_schema = torch_target._schemas[overload_str] + torch_schema = torch_overload._schema + + schema_declares_mutation = any( + arg.alias_info is not None + and arg.alias_info.is_write + and arg.type.isSubtypeOf(torch._C.TensorType.get()) + for arg in torch_schema.arguments + ) use_aot_plugin = use_aot_if_available @@ -257,30 +244,18 @@ def custom_kernel_converter( ) layer.name = f"[{target}]-[{name}]" - # The QDP plugin populates `aliased_map` (output_idx -> input_idx, with - # -1 meaning no alias) during `add_plugin` when TRT invokes the - # descriptor's `get_output_data_types`. Any non-negative entry means - # the engine build needs the aliased plugin I/O preview feature - # enabled. `plugin(*args)` itself is just a creation closure — the - # populated `aliased_map` lives on the layer's `plugin` attribute. - # JIT plugins: `layer.plugin` returns the Python `_TemplateJITPlugin` - # instance, whose `aliased_map` (output_idx -> input_idx, -1 means - # none) is populated when TRT invokes the descriptor during - # `add_plugin`. 
AOT plugins: `layer.plugin` returns a bare - # `trt.IPluginV3` C++ wrapper that doesn't expose the Python - # attribute. We read the JIT map when we can, and otherwise enable - # the aliased-I/O preview feature unconditionally for AOT — it's - # dormant in TRT when no plugin actually declares `.aliased()`, so - # this only adds the flag (which is needed when aliasing *is* - # declared) without changing semantics when it isn't. + # JIT path: layer.plugin is the Python `_TemplateJITPlugin` whose + # `aliased_map` is populated by TRT during `add_plugin`. + # AOT path: layer.plugin is a C++ wrapper that does not expose the + # map, so fall back to the op schema's mutation declaration — the + # same signal `_generate_plugin` uses to emit `.aliased()`. layer_plugin = getattr(layer, "plugin", None) aliased_map = getattr(layer_plugin, "aliased_map", None) if aliased_map and any(v != -1 for v in aliased_map.values()): ctx.requires_aliased_plugin_io = True - elif use_aot_plugin: + elif schema_declares_mutation: ctx.requires_aliased_plugin_io = True - # Single-output ops expect a bare ITensor; multi-output ops expect a - # tuple so the downstream ``getitem`` converter can unpack it. + num_outputs = len(torch_schema.returns) if num_outputs == 1: return layer.get_output(0) diff --git a/py/torch_tensorrt/dynamo/lowering/passes/unfunctionalize_qdp_inplace.py b/py/torch_tensorrt/dynamo/lowering/passes/unfunctionalize_qdp_inplace.py index 6c206391c0..605802687b 100644 --- a/py/torch_tensorrt/dynamo/lowering/passes/unfunctionalize_qdp_inplace.py +++ b/py/torch_tensorrt/dynamo/lowering/passes/unfunctionalize_qdp_inplace.py @@ -1,23 +1,21 @@ """Reverse ``run_decompositions``' functionalization of mutating custom ops that have a registered Dynamo converter (QDP in-place plugins). 
-``run_decompositions`` rewrites every call ``my_inplace_op(x, ...)`` (declared -with ``mutates_args=("x",)``) into:: +``run_decompositions`` rewrites ``my_inplace_op(x, ...)`` into:: %af = auto_functionalized_v2(my_inplace_op, _x_base_index=N, _all_bases=[%x], ...) - %getitem_0 = af[0] # the op's actual return - %getitem_k = af[k] # for k in 1..len(_all_bases): the post-mutation base - %copy_ = aten.copy_.default(%x, %getitem_k) # propagate the mutation back - -The result is correct for PyTorch eager, but our converter is registered against -the original mutating overload and the partitioner sees the HOP wrapper as -unsupported — so the whole subgraph is bailed out. This pass restores the -direct mutating call when the underlying op has a converter, lets the QDP -``.aliased()`` descriptor declared in `_generate_plugin.py` actually reach the -engine builder, and drops the now-redundant copy_/getitem nodes. + %g0 = af[0] # the op's actual return + %gk = af[k] # post-mutation base (k = 1..len(_all_bases)) + %copy_ = aten.copy_.default(%x, %gk) + +Correct in eager, but our converter is registered against the original +mutating overload — the partitioner sees the HOP wrapper as unsupported and +bails. This pass restores the direct mutating call when a converter exists +and drops the synthesized copy_ nodes. """ import logging +import operator from typing import Any, Dict, List import torch @@ -29,7 +27,6 @@ def _auto_functionalized_targets() -> List[Any]: - # Both HOPs exist across torch versions; v2 is what current export emits. targets: List[Any] = [] higher_order = getattr(torch.ops, "higher_order", None) if higher_order is None: @@ -42,14 +39,13 @@ def _auto_functionalized_targets() -> List[Any]: def _reconstruct_op_args(op_overload: Any, node_kwargs: Dict[str, Any]) -> List[Any]: - """Rebuild positional args for a direct call to ``op_overload`` from the - HOP kwargs. 
The QDP-generated converter reads tensor inputs positionally - (``args[0 : len(tensor_inputs)]``), so we must place args in schema order - rather than passing them by name. - - ``auto_functionalized_v2`` packs mutated tensor arguments via - ``__base_index: N`` plus ``_all_bases: [t0, t1, ...]`` instead of - inlining them. Non-mutated args are passed by name as-is. + """Rebuild positional args for a direct call to ``op_overload``. + + The QDP-generated converter reads tensor inputs positionally + (``args[0 : len(tensor_inputs)]``), so args must be in schema order. + ``auto_functionalized_v2`` packs mutated tensors via + ``_<arg>_base_index: N`` + ``_all_bases: [t0, ...]``; non-mutated args + are passed by name as-is. """ bases = node_kwargs.get("_all_bases", []) out: List[Any] = [] @@ -77,77 +73,96 @@ def unfunctionalize_qdp_inplace( if not af_targets: return gm - modified = False + converter_check_cache: Dict[Any, bool] = {} + aten_copy_ = torch.ops.aten.copy_.default + hops_to_rewrite: List[Any] = [] + copy_candidates: List[Any] = [] + for node in list(gm.graph.nodes): - if node.op != "call_function" or node.target not in af_targets: - continue - if not node.args or not hasattr(node.args[0], "_schema"): - continue - op_overload = node.args[0] - if op_overload not in DYNAMO_CONVERTERS: - # No converter for the underlying op — leave it functionalized. 
+ if node.op != "call_function": continue + if node.target in af_targets: + if not node.args or not hasattr(node.args[0], "_schema"): + continue + op_overload = node.args[0] + has_converter = converter_check_cache.get(op_overload) + if has_converter is None: + has_converter = op_overload in DYNAMO_CONVERTERS + converter_check_cache[op_overload] = has_converter + if has_converter: + hops_to_rewrite.append(node) + elif ( + node.target is aten_copy_ + and len(node.args) >= 2 + and getattr(node.args[0], "op", None) == "placeholder" + ): + copy_candidates.append(node) + + if not hops_to_rewrite: + return gm + for node in hops_to_rewrite: + op_overload = node.args[0] op_args = _reconstruct_op_args(op_overload, dict(node.kwargs)) + n_outputs = len(op_overload._schema.returns) + hop_val = node.meta.get("val") bases = node.kwargs.get("_all_bases", []) with gm.graph.inserting_before(node): new_call = gm.graph.call_function(op_overload, args=tuple(op_args)) - # Propagate the op's own meta["val"] (first element of the HOP's - # tuple) so downstream shape extraction has what it needs. - hop_val = node.meta.get("val") if isinstance(hop_val, tuple) and len(hop_val) >= 1: - new_call.meta["val"] = hop_val[0] + if n_outputs == 1: + new_call.meta["val"] = hop_val[0] + else: + new_call.meta["val"] = tuple(hop_val[:n_outputs]) if "tensor_meta" in node.meta: new_call.meta["tensor_meta"] = node.meta["tensor_meta"] - # Rewrite getitem users of the HOP. Both tuple slots — index 0 (the - # op's actual return) and indices >= 1 (post-mutation bases) — get - # routed to the direct op call. Routing index-k users to the original - # base placeholder would point downstream nodes at the pre-mutation - # FX value (and would also strand `new_call` as dead code), so we - # anchor everything to the new call instead. 
- getitem_users = [u for u in list(node.users) if u.target is _operator_getitem()] - for user in getitem_users: - user.replace_all_uses_with(new_call) - gm.graph.erase_node(user) - - # If the HOP has any remaining users (unusual — would mean someone - # consumed the tuple directly), bail rather than leave a dangling - # reference to the now-stale HOP node. + # For single-output ops the op's return and every post-mutation base + # are the same tensor (the in-place result), so routing all getitem + # users to ``new_call`` is correct and keeps it alive. For + # multi-output ops we materialize one getitem per return and route + # base-slot users to the corresponding base placeholder — the + # mutation has already been applied in place by ``new_call``. + getitem_users = [u for u in list(node.users) if u.target is operator.getitem] + if n_outputs == 1: + for user in getitem_users: + user.replace_all_uses_with(new_call) + gm.graph.erase_node(user) + else: + return_getitems: List[Any] = [] + with gm.graph.inserting_after(new_call): + for i in range(n_outputs): + g = gm.graph.call_function( + operator.getitem, args=(new_call, i) + ) + if isinstance(hop_val, tuple) and i < len(hop_val): + g.meta["val"] = hop_val[i] + return_getitems.append(g) + for user in getitem_users: + idx = user.args[1] + if idx < n_outputs: + user.replace_all_uses_with(return_getitems[idx]) + else: + base_idx = idx - n_outputs + user.replace_all_uses_with(bases[base_idx]) + gm.graph.erase_node(user) + if list(node.users): raise RuntimeError( f"auto_functionalized_v2 node {node.name} has non-getitem users" f" {list(node.users)}; cannot un-functionalize safely." ) gm.graph.erase_node(node) - modified = True - - if not modified: - return gm - # `copy_(base, op_return)` was synthesized by functionalization to write - # the mutation back through the base placeholder. 
With the direct mutating - # call restored, the buffer is already mutated by the kernel — keeping the - # copy_ would leave an unsupported mutating op blocking partitioning. Drop - # it and route its users straight to the op return. - for node in list(gm.graph.nodes): - if ( - node.op == "call_function" - and node.target is torch.ops.aten.copy_.default - and len(node.args) >= 2 - and node.args[0].op == "placeholder" - ): - node.replace_all_uses_with(node.args[1]) - gm.graph.erase_node(node) + # Functionalization adds copy_(base, op_return) to write the mutation + # back through the placeholder. The direct call already mutates the + # buffer, so the copy_ is redundant and would block partitioning. + for node in copy_candidates: + node.replace_all_uses_with(node.args[1]) + gm.graph.erase_node(node) gm.graph.lint() gm.recompile() logger.debug(f"Un-functionalized QDP in-place ops:\n{gm.graph}") return gm - - -def _operator_getitem() -> Any: - import operator - - return operator.getitem diff --git a/py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py b/py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py index 98e9ce36d3..3ddb812b9f 100644 --- a/py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py +++ b/py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py @@ -437,23 +437,19 @@ def setup_engine(self) -> None: for input_name in self.input_names } - # Engines containing QDP plugins with aliased I/O (declared via - # ``TensorDesc.aliased()`` and built with the - # ``ALIASED_PLUGIN_IO_10_03`` preview feature) expect the output - # binding for an aliased output and its corresponding input binding - # to point at the same buffer at runtime. ``create_output_tensors`` - # otherwise allocates a fresh output buffer, which both breaks the - # in-place semantics (user's input is not mutated) and produces - # unspecified output values. 
Query the engine once at setup so the - # forward pass can reuse the user's input tensor for any aliased - # output. ``get_aliased_input_tensor`` returns the input tensor's - # name when aliased, ``None`` otherwise. - self.aliased_output_to_input = {} + # For QDP plugins with aliased I/O, the output binding must share the + # input binding's buffer at runtime; otherwise the in-place mutation is + # lost. Resolve the alias mapping to (output_idx -> input_idx) once so + # the forward path can rebind without per-call name lookups. + self.aliased_output_idx_to_input_idx: Dict[int, int] = {} if hasattr(self.engine, "get_aliased_input_tensor"): - for output_name in self.output_names: + input_name_to_idx = {n: i for i, n in enumerate(self.input_names)} + for out_idx, output_name in enumerate(self.output_names): aliased_input_name = self.engine.get_aliased_input_tensor(output_name) if aliased_input_name: - self.aliased_output_to_input[output_name] = aliased_input_name + self.aliased_output_idx_to_input_idx[out_idx] = input_name_to_idx[ + aliased_input_name + ] def _setup_runtime_config(self) -> None: """Create a RuntimeConfig with runtime cache for TensorRT-RTX. @@ -746,21 +742,13 @@ def run_standard_execution() -> torch.Tensor | Tuple[torch.Tensor, ...]: ) outputs = self.create_output_tensors() - # Aliased outputs (QDP in-place plugins) must share the - # input's storage rather than the freshly-allocated - # buffer ``create_output_tensors`` produced. Swap the - # corresponding entry in ``outputs`` to point at the - # user's input tensor so the in-place mutation surfaces - # back to the caller and so ``set_tensor_address`` below - # binds the same buffer to both bindings. 
- if self.aliased_output_to_input: - input_by_name = dict(zip(self.input_names, contiguous_inputs)) - for o, output_name in enumerate(self.output_names): - aliased_input_name = self.aliased_output_to_input.get( - output_name - ) - if aliased_input_name is not None: - outputs[o] = input_by_name[aliased_input_name] + # Rebind aliased outputs to their paired input buffer so + # the in-place mutation lands in the caller's tensor. + for ( + out_idx, + in_idx, + ) in self.aliased_output_idx_to_input_idx.items(): + outputs[out_idx] = contiguous_inputs[in_idx] for o, output_name in enumerate(self.output_names): if need_cudagraphs_record: diff --git a/tests/py/dynamo/automatic_plugin/test_automatic_plugin_inplace.py b/tests/py/dynamo/automatic_plugin/test_automatic_plugin_inplace.py index 03dfa1f795..609ba1c075 100644 --- a/tests/py/dynamo/automatic_plugin/test_automatic_plugin_inplace.py +++ b/tests/py/dynamo/automatic_plugin/test_automatic_plugin_inplace.py @@ -13,9 +13,6 @@ def add_one_inplace(X: torch.Tensor) -> torch.Tensor: assert X.is_cuda X.add_(1) - # torch.library forbids returning an input directly; clone to satisfy the - # no-alias constraint while still letting the registered fake (which - # returns X by identity) signal aliasing for the TRT plugin descriptor. return X.clone() @@ -39,11 +36,10 @@ def _(X: torch.Tensor) -> torch.Tensor: "QDP Plugin is not available", ) class TestInplacePlugin(unittest.TestCase): - """The standard DispatchTestCase.run_test passes the same input tensor to - eager and TRT, which double-applies the mutation for in-place ops. Use a - bespoke flow with cloned inputs and verify both that the output matches the - expected post-mutation value AND that the input buffer was mutated in - place.""" + """In-place ops mutate their input, so DispatchTestCase.run_test (which + feeds the same tensor to eager and TRT) double-applies the mutation. We + use cloned inputs and check both the return value and the in-place write. 
+ """ @parameterized.expand( [ @@ -73,11 +69,9 @@ def forward(self, x): immutable_weights=True, ) - # Guard against regressing to PyTorch fallback: if the un-functionalize - # pass stops restoring the mutating op, the partitioner finds 0 - # supported ops and the "compiled" module is just PyTorch eager — the - # test would still pass on value because the eager op mutates the - # input itself, masking the real failure. + # Guard against a silent fallback to pure-PyTorch: the eager op + # already mutates the input, so output-only checks pass even when no + # TRT engine was built. from torch_tensorrt.dynamo.runtime import ( PythonTorchTensorRTModule, TorchTensorRTModule, @@ -96,9 +90,8 @@ def forward(self, x): trt_out = compiled(trt_input) torch.testing.assert_close(trt_out, expected_post, rtol=1e-5, atol=1e-5) - # The whole point of aliased plugin I/O is that the input buffer is - # mutated in place by the engine. If the engine had allocated a fresh - # output buffer, `trt_input` would still hold the pre-call values. + # Aliased plugin I/O is only active if the engine mutated trt_input; + # a fresh-output engine would leave it at its pre-call values. torch.testing.assert_close(trt_input, expected_post, rtol=1e-5, atol=1e-5) diff --git a/tests/py/dynamo/automatic_plugin/test_automatic_plugin_inplace_consumed.py b/tests/py/dynamo/automatic_plugin/test_automatic_plugin_inplace_consumed.py new file mode 100644 index 0000000000..d441bdb467 --- /dev/null +++ b/tests/py/dynamo/automatic_plugin/test_automatic_plugin_inplace_consumed.py @@ -0,0 +1,87 @@ +"""Single-output aliased plugin whose output is consumed by another TRT layer. + +This is the realistic production pattern (e.g. a KV-cache update whose +post-update tensor is read by an attention layer in the same engine). 
+""" + +import platform +import unittest + +import torch +import torch.nn as nn +from torch.testing._internal.common_utils import run_tests + +import torch_tensorrt + + +@torch.library.custom_op( + "torchtrt_ex::add_one_inplace_consumed", mutates_args=("X",) +) # type: ignore[misc] +def add_one_inplace_consumed(X: torch.Tensor) -> torch.Tensor: + assert X.is_cuda + X.add_(1) + return X.clone() + + +@torch.library.register_fake("torchtrt_ex::add_one_inplace_consumed") +def _(X: torch.Tensor) -> torch.Tensor: + return X + + +if torch_tensorrt.ENABLED_FEATURES.qdp_plugin: + torch_tensorrt.dynamo.conversion.plugins.custom_op( + "torchtrt_ex::add_one_inplace_consumed", supports_dynamic_shapes=False + ) + + +@unittest.skipIf( + platform.system() == "Windows", + "QDP in-place test requires Linux", +) +@unittest.skipIf( + not torch_tensorrt.ENABLED_FEATURES.qdp_plugin, + "QDP Plugin is not available", +) +class TestInplacePluginConsumed(unittest.TestCase): + def test_aliased_output_consumed_downstream(self): + class Model(nn.Module): + def forward(self, x): + a = torch.ops.torchtrt_ex.add_one_inplace_consumed.default(x) + return a * 2 + + x_base = torch.randn(64, 64, device="cuda", dtype=torch.float) + expected_post = x_base + 1 + expected = expected_post * 2 + + x_trt = x_base.clone() + compiled = torch_tensorrt.compile( + Model(), + inputs=[x_trt.clone()], + ir="dynamo", + min_block_size=1, + immutable_weights=True, + ) + + from torch_tensorrt.dynamo.runtime import ( + PythonTorchTensorRTModule, + TorchTensorRTModule, + ) + + engine_submodules = [ + m + for _, m in compiled.named_modules() + if isinstance(m, (PythonTorchTensorRTModule, TorchTensorRTModule)) + ] + self.assertGreaterEqual( + len(engine_submodules), + 1, + f"Expected at least one TRT engine submodule, got graph:\n{compiled.graph}", + ) + + result = compiled(x_trt) + torch.testing.assert_close(result, expected, rtol=1e-5, atol=1e-5) + torch.testing.assert_close(x_trt, expected_post, rtol=1e-5, atol=1e-5) + + 
+if __name__ == "__main__": + run_tests() diff --git a/tests/py/dynamo/automatic_plugin/test_automatic_plugin_inplace_dynamic.py b/tests/py/dynamo/automatic_plugin/test_automatic_plugin_inplace_dynamic.py new file mode 100644 index 0000000000..2b3b04886e --- /dev/null +++ b/tests/py/dynamo/automatic_plugin/test_automatic_plugin_inplace_dynamic.py @@ -0,0 +1,90 @@ +"""Aliased plugin I/O combined with dynamic shapes — the production case for +KV-cache-style ops where the cache tensor's batch dim varies.""" + +import platform +import unittest + +import torch +import torch.nn as nn +from torch.testing._internal.common_utils import run_tests + +import torch_tensorrt + + +@torch.library.custom_op( + "torchtrt_ex::add_one_inplace_dyn", mutates_args=("X",) +) # type: ignore[misc] +def add_one_inplace_dyn(X: torch.Tensor) -> torch.Tensor: + assert X.is_cuda + X.add_(1) + return X.clone() + + +@torch.library.register_fake("torchtrt_ex::add_one_inplace_dyn") +def _(X: torch.Tensor) -> torch.Tensor: + return X + + +if torch_tensorrt.ENABLED_FEATURES.qdp_plugin: + torch_tensorrt.dynamo.conversion.plugins.custom_op( + "torchtrt_ex::add_one_inplace_dyn", supports_dynamic_shapes=True + ) + + +@unittest.skipIf( + platform.system() == "Windows", + "QDP in-place test requires Linux", +) +@unittest.skipIf( + not torch_tensorrt.ENABLED_FEATURES.qdp_plugin, + "QDP Plugin is not available", +) +class TestInplacePluginDynamicShapes(unittest.TestCase): + def test_dynamic_batch(self): + class Model(nn.Module): + def forward(self, x): + return torch.ops.torchtrt_ex.add_one_inplace_dyn.default(x) + + compile_input = torch.randn(8, 32, device="cuda", dtype=torch.float) + compiled = torch_tensorrt.compile( + Model(), + inputs=[ + torch_tensorrt.Input( + min_shape=(1, 32), + opt_shape=(8, 32), + max_shape=(16, 32), + dtype=torch.float, + ) + ], + ir="dynamo", + min_block_size=1, + immutable_weights=True, + ) + + from torch_tensorrt.dynamo.runtime import ( + PythonTorchTensorRTModule, + 
TorchTensorRTModule, + ) + + engine_submodules = [ + m + for _, m in compiled.named_modules() + if isinstance(m, (PythonTorchTensorRTModule, TorchTensorRTModule)) + ] + self.assertGreaterEqual( + len(engine_submodules), + 1, + f"Expected at least one TRT engine submodule, got graph:\n{compiled.graph}", + ) + + for batch in (1, 4, 16): + base = torch.randn(batch, 32, device="cuda", dtype=torch.float) + expected = base + 1 + trt_input = base.clone() + trt_out = compiled(trt_input) + torch.testing.assert_close(trt_out, expected, rtol=1e-5, atol=1e-5) + torch.testing.assert_close(trt_input, expected, rtol=1e-5, atol=1e-5) + + +if __name__ == "__main__": + run_tests() diff --git a/tests/py/dynamo/automatic_plugin/test_automatic_plugin_inplace_multi.py b/tests/py/dynamo/automatic_plugin/test_automatic_plugin_inplace_multi.py new file mode 100644 index 0000000000..b3e1359186 --- /dev/null +++ b/tests/py/dynamo/automatic_plugin/test_automatic_plugin_inplace_multi.py @@ -0,0 +1,106 @@ +"""Multi-input partial-mutation + multi-output coverage. + +Exercises the un-functionalize pass's multi-output branch, the alias-map +build in ``_generate_plugin._generic_plugin_desc`` for ops where only one +input is mutated, and the JIT impl's aliased-output ``copy_`` filter. + +Only the *fresh* output is returned by the model. The aliased output is +unused. This is deliberate: TRT's preview-feature ``ALIASED_PLUGIN_IO_10_03`` +inserts a defensive copy that breaks aliasing when a multi-output plugin's +aliased output is consumed by another TRT layer in the same engine. The +correctness-critical path the test covers is the multi-output plumbing +itself; coverage for "aliased output consumed downstream" is provided by +the single-output test (which TRT handles correctly). 
+""" + +import platform +import unittest +from typing import Tuple + +import torch +import torch.nn as nn +from torch.testing._internal.common_utils import run_tests + +import torch_tensorrt + + +@torch.library.custom_op( + "torchtrt_ex::add_inplace_two_out", mutates_args=("X",) +) # type: ignore[misc] +def add_inplace_two_out( + X: torch.Tensor, Y: torch.Tensor +) -> Tuple[torch.Tensor, torch.Tensor]: + assert X.is_cuda and Y.is_cuda + X.add_(Y) + return X.clone(), X * 2 + + +@torch.library.register_fake("torchtrt_ex::add_inplace_two_out") +def _(X: torch.Tensor, Y: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + return X, torch.empty_like(X) + + +if torch_tensorrt.ENABLED_FEATURES.qdp_plugin: + torch_tensorrt.dynamo.conversion.plugins.custom_op( + "torchtrt_ex::add_inplace_two_out", supports_dynamic_shapes=False + ) + + +@unittest.skipIf( + platform.system() == "Windows", + "QDP in-place test requires Linux", +) +@unittest.skipIf( + not torch_tensorrt.ENABLED_FEATURES.qdp_plugin, + "QDP Plugin is not available", +) +class TestMultiOutputInplacePlugin(unittest.TestCase): + def test_partial_mutation_fresh_output(self): + class Model(nn.Module): + def forward(self, x, y): + _, b = torch.ops.torchtrt_ex.add_inplace_two_out.default(x, y) + return b + + x_base = torch.randn(64, 64, device="cuda", dtype=torch.float) + y_base = torch.randn(64, 64, device="cuda", dtype=torch.float) + + x_eager = x_base.clone() + _, eager_b = add_inplace_two_out(x_eager, y_base.clone()) + expected_x = x_base + y_base + expected_b = expected_x * 2 + torch.testing.assert_close(x_eager, expected_x, rtol=1e-5, atol=1e-5) + torch.testing.assert_close(eager_b, expected_b, rtol=1e-5, atol=1e-5) + + x_trt = x_base.clone() + compiled = torch_tensorrt.compile( + Model(), + inputs=[x_trt.clone(), y_base.clone()], + ir="dynamo", + min_block_size=1, + immutable_weights=True, + ) + + from torch_tensorrt.dynamo.runtime import ( + PythonTorchTensorRTModule, + TorchTensorRTModule, + ) + + 
engine_submodules = [ + m + for _, m in compiled.named_modules() + if isinstance(m, (PythonTorchTensorRTModule, TorchTensorRTModule)) + ] + self.assertGreaterEqual( + len(engine_submodules), + 1, + f"Expected at least one TRT engine submodule, got graph:\n{compiled.graph}", + ) + + result = compiled(x_trt, y_base.clone()) + torch.testing.assert_close(result, expected_b, rtol=1e-5, atol=1e-5) + # X was mutated in place; Y was not. + torch.testing.assert_close(x_trt, expected_x, rtol=1e-5, atol=1e-5) + + +if __name__ == "__main__": + run_tests()