diff --git a/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py b/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py
index 7293269961..87ac3cbcd0 100644
--- a/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py
+++ b/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py
@@ -202,8 +202,8 @@ def _all_precisions_supported(enabled_precisions: Set[dtype]) -> bool:
 
     def validate_compile_settings(self) -> None:
         if ENABLED_FEATURES.tensorrt_rtx:
-            if dtype.bfloat16 in self.compilation_settings.enabled_precisions:
-                raise RuntimeError("TensorRT-RTX does not support bfloat16!")
+            # NOTE: bfloat16 check disabled — depthwise conv BF16 limitation
+            # is now handled per-layer via capability_validator
             return
 
         if (
diff --git a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
index 67b499c068..dbd49452ea 100644
--- a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
+++ b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
@@ -6,6 +6,7 @@
 
 import numpy as np
 import torch
+from tensorrt import ITensor as TRTTensor
 from torch.fx.node import Argument, Node, Target
 from torch_tensorrt import ENABLED_FEATURES
 from torch_tensorrt._features import needs_not_tensorrt_rtx
@@ -27,8 +28,6 @@
 )
 from torch_tensorrt.dynamo.utils import DYNAMIC_DIM
 
-from tensorrt import ITensor as TRTTensor
-
 _LOGGER: logging.Logger = logging.getLogger(__name__)
 
 
@@ -2755,8 +2754,53 @@ def aten_ops_le(
     )
 
 
+def depthwise_bf16_validator(
+    node: Node, settings: Optional[CompilationSettings] = None
+) -> bool:
+    """Reject depthwise conv/deconv with BF16 on TensorRT-RTX.
+
+    TensorRT-RTX does not support depthwise convolutions in BF16. Returning
+    False causes the partitioner to fall back to PyTorch for these specific
+    nodes, while all other convolutions remain on TRT.
+
+    A convolution counts as depthwise when each group has exactly one input
+    channel (groups == in_channels), regardless of the channel multiplier.
+    """
+    if not ENABLED_FEATURES.tensorrt_rtx:
+        return True
+    # Check if the input tensor is BF16 (via FX node metadata)
+    input_node = node.args[0]
+    input_meta = getattr(input_node, "meta", {}).get("tensor_meta")
+    if input_meta is None or input_meta.dtype != torch.bfloat16:
+        return True
+    groups = args_bounds_check(node.args, 8)
+    if groups is None or groups <= 1:
+        return True
+    weight_meta = getattr(node.args[1], "meta", {}).get("tensor_meta")
+    if weight_meta is None:
+        return True
+    # aten.convolution weight layout:
+    #   conv:   (out_channels, in_channels // groups, *kernel)
+    #   deconv: (in_channels, out_channels // groups, *kernel)
+    # shape[0] == groups is exact for deconv and covers channel multiplier 1
+    # for conv; shape[1] == 1 also catches conv with a larger multiplier.
+    transposed = bool(args_bounds_check(node.args, 6))
+    is_depthwise = weight_meta.shape[0] == groups or (
+        not transposed and weight_meta.shape[1] == 1
+    )
+    if not is_depthwise:
+        return True
+    _LOGGER.debug(
+        "Depthwise convolution '%s' with BF16 is not supported on "
+        "TensorRT-RTX. Falling back to PyTorch for this layer.",
+        node.name,
+    )
+    return False
+
+
 @dynamo_tensorrt_converter(
     torch.ops.aten.convolution.default,
+    capability_validator=depthwise_bf16_validator,
     supports_dynamic_shapes=True,
 )
 @enforce_tensor_types(
diff --git a/tests/py/dynamo/conversion/test_binary_ops_aten.py b/tests/py/dynamo/conversion/test_binary_ops_aten.py
index 16b82b9858..d7c7a554c0 100644
--- a/tests/py/dynamo/conversion/test_binary_ops_aten.py
+++ b/tests/py/dynamo/conversion/test_binary_ops_aten.py
@@ -237,10 +237,6 @@ def forward(self, x, y):
             if op[0].__name__ not in ["pow.Tensor_Tensor", "fmod.Tensor"]
         ]
     )
-    @unittest.skipIf(
-        torch_tensorrt.ENABLED_FEATURES.tensorrt_rtx,
-        "bf16 is not supported for tensorrt_rtx",
-    )
     def test_elementwise_ops_bf16(self, _, orig_op):
         class TestModule(nn.Module):
             def __init__(self, orig_op):
diff --git a/tests/py/dynamo/conversion/test_casts.py b/tests/py/dynamo/conversion/test_casts.py
index 62920c9610..cb79001f4f 100644
--- a/tests/py/dynamo/conversion/test_casts.py
+++ b/tests/py/dynamo/conversion/test_casts.py
@@ -67,10 +67,6 @@ def forward(self, x):
             precision=torch.float,
         )
 
-    @unittest.skipIf(
-        torch_tensorrt.ENABLED_FEATURES.tensorrt_rtx,
-        "bf16 is not supported for tensorrt_rtx",
-    )
     def test_to_copy_bfloat16(self):
         class ToCopyBFloat16(nn.Module):
             def forward(self, x):
diff --git a/tests/py/dynamo/llm/test_llm_models.py b/tests/py/dynamo/llm/test_llm_models.py
index d08ad8d84e..88c7a862f1 100644
--- a/tests/py/dynamo/llm/test_llm_models.py
+++ b/tests/py/dynamo/llm/test_llm_models.py
@@ -28,8 +28,6 @@ def test_llm_decoder_layer(precision):
     from run_llm import compile_torchtrt
     from torchtrt_ext import register_sdpa
 
-    if torch_tensorrt.ENABLED_FEATURES.tensorrt_rtx and precision == "BF16":
-        pytest.skip("TensorRT-RTX does not support bfloat16, skipping test")
     with torch.inference_mode():
         args = argparse.Namespace()
         args.debug = False
diff --git a/tests/py/dynamo/models/test_dtype_support.py b/tests/py/dynamo/models/test_dtype_support.py
index 6c02db6b68..42507968f7 100644
--- a/tests/py/dynamo/models/test_dtype_support.py
+++ b/tests/py/dynamo/models/test_dtype_support.py
@@ -200,10 +200,6 @@ def forward(self, x):
     ),
     "Platform does not have BF16 support",
 )
-@unittest.skipIf(
-    torch_tensorrt.ENABLED_FEATURES.tensorrt_rtx,
-    "bf16 is not supported for tensorrt_rtx",
-)
 class TestBF16Support(TestCase):
     @unittest.skipIf(
         not torch_tensorrt.ENABLED_FEATURES.torch_tensorrt_runtime,
diff --git a/tests/py/dynamo/models/test_dyn_models.py b/tests/py/dynamo/models/test_dyn_models.py
index 28d72433b7..b54e67c2fb 100644
--- a/tests/py/dynamo/models/test_dyn_models.py
+++ b/tests/py/dynamo/models/test_dyn_models.py
@@ -189,9 +189,6 @@ def test_resnet_dynamic(ir, dtype):
     """
     Tests the Resnet18 model (which is fully convertible) with dynamic shapes
     """
-    if torchtrt.ENABLED_FEATURES.tensorrt_rtx and dtype == torch.bfloat16:
-        pytest.skip("TensorRT-RTX does not support bfloat16")
-
     import torchvision.models as models
 
     model = models.resnet18(pretrained=True).eval().to("cuda").to(dtype)
diff --git a/tests/py/dynamo/models/test_models.py b/tests/py/dynamo/models/test_models.py
index b1435540e0..4c0f2ccef4 100644
--- a/tests/py/dynamo/models/test_models.py
+++ b/tests/py/dynamo/models/test_models.py
@@ -197,9 +197,6 @@ def test_resnet18_torch_exec_ops(ir):
     "torchvision is not installed",
 )
 def test_mobilenet_v2(ir, dtype):
-    if torchtrt.ENABLED_FEATURES.tensorrt_rtx and dtype == torch.bfloat16:
-        pytest.skip("TensorRT-RTX does not support bfloat16")
-
     model = models.mobilenet_v2(pretrained=True).eval().to("cuda").to(dtype)
     input = torch.randn((1, 3, 224, 224)).to("cuda").to(dtype)
 
@@ -239,9 +236,6 @@ def test_mobilenet_v2(ir, dtype):
     "timm or torchvision not installed",
 )
 def test_efficientnet_b0(ir, dtype):
-    if torchtrt.ENABLED_FEATURES.tensorrt_rtx and dtype == torch.bfloat16:
-        pytest.skip("TensorRT-RTX does not support bfloat16")
-
     model = (
         timm.create_model("efficientnet_b0", pretrained=True)
         .eval()
@@ -286,9 +280,6 @@ def test_efficientnet_b0(ir, dtype):
     "transformers is required to run this test",
 )
 def test_bert_base_uncased(ir, dtype):
-    if torchtrt.ENABLED_FEATURES.tensorrt_rtx and dtype == torch.bfloat16:
-        pytest.skip("TensorRT-RTX does not support bfloat16")
-
     from transformers import BertModel
 
     model = BertModel.from_pretrained("bert-base-uncased").cuda().eval().to(dtype)
@@ -430,10 +421,6 @@ def test_resnet18_half(ir):
 
 
 @pytest.mark.unit
-@unittest.skipIf(
-    torchtrt.ENABLED_FEATURES.tensorrt_rtx,
-    "tensorrt_rtx does not support bfloat16",
-)
 def test_cosmos_true_div(ir):
     class CosmosLearnablePositionalEmbed(torch.nn.Module):
         def __init__(
@@ -532,10 +519,6 @@ def forward(
 
 
 @pytest.mark.unit
-@unittest.skipIf(
-    torchtrt.ENABLED_FEATURES.tensorrt_rtx,
-    "bf16 is not supported for tensorrt_rtx",
-)
 @pytest.mark.critical
 def test_bf16_model(ir):
     class MyModule(torch.nn.Module):
@@ -581,10 +564,6 @@ def forward(self, x):
 
 
 @pytest.mark.unit
-@unittest.skipIf(
-    torchtrt.ENABLED_FEATURES.tensorrt_rtx,
-    "bf16 is not supported for tensorrt_rtx",
-)
 @pytest.mark.critical
 def test_bf16_fallback_model(ir):
     class MyModule(torch.nn.Module):