Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py
Original file line number Diff line number Diff line change
Expand Up @@ -202,8 +202,8 @@ def _all_precisions_supported(enabled_precisions: Set[dtype]) -> bool:

def validate_compile_settings(self) -> None:
if ENABLED_FEATURES.tensorrt_rtx:
if dtype.bfloat16 in self.compilation_settings.enabled_precisions:
raise RuntimeError("TensorRT-RTX does not support bfloat16!")
# NOTE: bfloat16 check disabled — depthwise conv BF16 limitation
# is now handled per-layer via capability_validator
return

if (
Expand Down
34 changes: 32 additions & 2 deletions py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

import numpy as np
import torch
from tensorrt import ITensor as TRTTensor
from torch.fx.node import Argument, Node, Target
from torch_tensorrt import ENABLED_FEATURES
from torch_tensorrt._features import needs_not_tensorrt_rtx
Expand All @@ -27,8 +28,6 @@
)
from torch_tensorrt.dynamo.utils import DYNAMIC_DIM

from tensorrt import ITensor as TRTTensor

_LOGGER: logging.Logger = logging.getLogger(__name__)


Expand Down Expand Up @@ -2694,8 +2693,39 @@ def aten_ops_le(
)


def depthwise_bf16_validator(
    node: Node, settings: Optional[CompilationSettings] = None
) -> bool:
    """Reject depthwise conv/deconv with BF16 on TensorRT-RTX.

    TensorRT-RTX does not support depthwise convolutions in BF16. Returning
    False causes the partitioner to fall back to PyTorch for these specific
    nodes, while all other convolutions remain on TRT.

    Args:
        node: FX node for ``aten.convolution.default``; its args follow the
            schema ``(input, weight, bias, stride, padding, dilation,
            transposed, output_padding, groups)``.
        settings: Current compilation settings (unused; present to satisfy
            the capability-validator interface).

    Returns:
        ``False`` only for a BF16 depthwise convolution on TensorRT-RTX;
        ``True`` otherwise, keeping the node eligible for TRT conversion.
    """
    # Only TensorRT-RTX has the depthwise-BF16 limitation.
    if not ENABLED_FEATURES.tensorrt_rtx:
        return True

    # Use bounds-safe access throughout (a validator must never crash):
    # nodes with a truncated or kwarg-built arg list yield None here.
    input_node = args_bounds_check(node.args, 0)
    input_meta = getattr(input_node, "meta", {}).get("tensor_meta")
    # Without dtype metadata we cannot prove the node is problematic, so
    # keep it on TRT rather than pessimistically falling back.
    if input_meta is None or input_meta.dtype != torch.bfloat16:
        return True

    groups = args_bounds_check(node.args, 8)
    if groups is not None and groups > 1:
        weight_node = args_bounds_check(node.args, 1)
        weight_meta = getattr(weight_node, "meta", {}).get("tensor_meta")
        # Depthwise heuristic: groups == out_channels (weight.shape[0]).
        # NOTE(review): this misses depthwise convs with a channel
        # multiplier (out_channels == groups * m, m > 1) — confirm whether
        # the RTX limitation applies to those as well.
        if weight_meta is not None and groups == weight_meta.shape[0]:
            _LOGGER.debug(
                "Depthwise convolution '%s' with BF16 is not supported on "
                "TensorRT-RTX. Falling back to PyTorch for this layer.",
                node.name,
            )
            return False
    return True


@dynamo_tensorrt_converter(
torch.ops.aten.convolution.default,
capability_validator=depthwise_bf16_validator,
supports_dynamic_shapes=True,
)
@enforce_tensor_types(
Expand Down
2 changes: 0 additions & 2 deletions tests/py/dynamo/llm/test_llm_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,6 @@ def test_llm_decoder_layer(precision):
from run_llm import compile_torchtrt
from torchtrt_ext import register_sdpa

if torch_tensorrt.ENABLED_FEATURES.tensorrt_rtx and precision == "BF16":
pytest.skip("TensorRT-RTX does not support bfloat16, skipping test")
with torch.inference_mode():
args = argparse.Namespace()
args.debug = False
Expand Down
3 changes: 0 additions & 3 deletions tests/py/dynamo/models/test_dyn_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,9 +189,6 @@ def test_resnet_dynamic(ir, dtype):
"""
Tests the Resnet18 model (which is fully convertible) with dynamic shapes
"""
if torchtrt.ENABLED_FEATURES.tensorrt_rtx and dtype == torch.bfloat16:
pytest.skip("TensorRT-RTX does not support bfloat16")

import torchvision.models as models

model = models.resnet18(pretrained=True).eval().to("cuda").to(dtype)
Expand Down
13 changes: 0 additions & 13 deletions tests/py/dynamo/models/test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,9 +197,6 @@ def test_resnet18_torch_exec_ops(ir):
"torchvision is not installed",
)
def test_mobilenet_v2(ir, dtype):
if torchtrt.ENABLED_FEATURES.tensorrt_rtx and dtype == torch.bfloat16:
pytest.skip("TensorRT-RTX does not support bfloat16")

model = models.mobilenet_v2(pretrained=True).eval().to("cuda").to(dtype)
input = torch.randn((1, 3, 224, 224)).to("cuda").to(dtype)

Expand Down Expand Up @@ -239,9 +236,6 @@ def test_mobilenet_v2(ir, dtype):
"timm or torchvision not installed",
)
def test_efficientnet_b0(ir, dtype):
if torchtrt.ENABLED_FEATURES.tensorrt_rtx and dtype == torch.bfloat16:
pytest.skip("TensorRT-RTX does not support bfloat16")

model = (
timm.create_model("efficientnet_b0", pretrained=True)
.eval()
Expand Down Expand Up @@ -286,9 +280,6 @@ def test_efficientnet_b0(ir, dtype):
"transformers is required to run this test",
)
def test_bert_base_uncased(ir, dtype):
if torchtrt.ENABLED_FEATURES.tensorrt_rtx and dtype == torch.bfloat16:
pytest.skip("TensorRT-RTX does not support bfloat16")

from transformers import BertModel

model = BertModel.from_pretrained("bert-base-uncased").cuda().eval().to(dtype)
Expand Down Expand Up @@ -430,10 +421,6 @@ def test_resnet18_half(ir):


@pytest.mark.unit
@unittest.skipIf(
torchtrt.ENABLED_FEATURES.tensorrt_rtx,
"tensorrt_rtx does not support bfloat16",
)
def test_cosmos_true_div(ir):
class CosmosLearnablePositionalEmbed(torch.nn.Module):
def __init__(
Expand Down
Loading