TRT 10.x Myelin compiler: ForeignNode[Conv+Split] has no valid tactic despite sufficient workspace
Description
TensorRT 10.x fails to compile an RF-DETR segmentation model's CSPRepLayer projector. The Myelin compiler backend fuses Conv → ReLU → Split into a ForeignNode with a single tactic (`0x0000000000000000`) that always reports "insufficient memory", even when the requested size is 4 MB and 8 GB of workspace is available. No fallback tactic exists, so the build fails unconditionally.
The same ONNX model compiles and runs correctly on TRT 8.6.2.
Reproduces on both x86 (A10G) and Jetson AGX Orin across TRT 10.3.0 and 10.7.0.
Related: #4723 documents another TRT 10.x regression for DINOv2/v3 ViT models (FP16 softmax overflow in decomposed attention). Our issue is in the projector subgraph (not attention) and is a compilation failure rather than an accuracy bug.
Environment
Tested on two platforms:
Platform 1: Jetson AGX Orin
TensorRT Version: 10.3.0
NVIDIA GPU: Jetson AGX Orin 64GB iGPU (SM 8.7)
CUDA Version: 12.6 (JetPack 6)
CUDNN Version: 9.x (JetPack 6)
Operating System: Linux (JetPack 6 / L4T)
Python Version (if applicable): 3.12
PyTorch Version (if applicable): 2.8.0 (used for ONNX export only)
Baremetal or Container (if so, version): nvcr.io/nvidia/tritonserver:24.12-py3-igpu
Platform 2: Cloud A10G
TensorRT Version: 10.7.0
NVIDIA GPU: NVIDIA A10G 24GB (SM 8.6)
CUDA Version: 12.6
CUDNN Version: 9.x
Operating System: Linux
Python Version (if applicable): 3.12
PyTorch Version (if applicable): 2.8.0 (used for ONNX export only)
Baremetal or Container (if so, version): nvcr.io/nvidia/tensorrt:24.12-py3
Relevant Files
Model link: ONNX model too large to attach (RF-DETR Seg2XLarge, ~134 MB, static input [1, 3, 768, 768], opset 17, preprocessed with onnxsim).
To reproduce from scratch using only public files:
```shell
pip install rfdetr onnxsim
python -c "
from rfdetr import RFDETRSeg2XLarge
model = RFDETRSeg2XLarge()  # downloads default public checkpoint from https://storage.googleapis.com/rfdetr/rf-detr-seg-2xl-ft.pth
model.export(output_file='rfdetr_seg2xl.onnx')
"
onnxsim rfdetr_seg2xl.onnx rfdetr_seg2xl_sim.onnx
```
The failing subgraph is a CSPRepLayer (Cross Stage Partial) projector, common in YOLO-family and detection transformer architectures:
```
Conv(in=256, out=256, kernel=3x3, stride=1, pad=1)
  → ReLU
  → Split(axis=1, splits=[128, 128])
```
Input spatial dimensions: 64×64. After `onnxsim`, BatchNorm is folded into the Conv (the tensor name `cv1.bn.weight` is a remnant of the absorbed BN weight).
The Myelin compiler creates a single ForeignNode spanning Conv → ReLU → Split with exactly one tactic (`0x0000000000000000`), which unconditionally fails.
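For triage, the failing pattern can be sketched as a standalone module and exported at opset 17. This is a hypothetical minimal repro: it mirrors the subgraph's shapes and ops, but we have not verified that this module in isolation triggers the Myelin failure.

```python
# Hypothetical minimal repro of the failing Conv -> ReLU -> Split pattern
# (a sketch of the CSPRepLayer projector subgraph; not verified to reproduce
# the Myelin failure on its own).
import torch
import torch.nn as nn

class ConvReluSplit(nn.Module):
    def __init__(self):
        super().__init__()
        # BN already folded into Conv, matching the post-onnxsim graph
        self.conv = nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1)
        self.act = nn.ReLU()

    def forward(self, x):
        y = self.act(self.conv(x))
        # Split(axis=1, splits=[128, 128]) as in the projector
        return torch.split(y, [128, 128], dim=1)

def export_onnx(path="conv_relu_split.onnx"):
    model = ConvReluSplit().eval()
    x = torch.randn(1, 256, 64, 64)  # 64x64 spatial dims, as in the full model
    torch.onnx.export(model, x, path, opset_version=17)
```

Feeding the exported file to the same `trtexec` command below would confirm whether the fusion alone is sufficient to hit the bug.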
Steps To Reproduce
Commands or scripts:
```shell
# 1. Export and simplify the ONNX model as described above
# 2. Compile with trtexec (fails on both FP16 and FP32)
trtexec \
  --onnx=rfdetr_seg2xl_sim.onnx \
  --saveEngine=model.engine \
  --fp16 \
  --memPoolSize=workspace:8192MiB \
  --verbose
```
Full error output — Jetson AGX Orin (TRT 10.3.0):
```
&&&& RUNNING TensorRT.trtexec [TensorRT v100300] # trtexec --onnx=/tmp/rfdetr_model_sim.onnx --saveEngine=/tmp/rfdetr_sim_fp16.engine --fp16 --memPoolSize=workspace:8192MiB --verbose
[04/03/2026-14:33:22] [V] [TRT] =============== Computing costs for
{ForeignNode[backbone.0.projector.stages.0.0.cv1.bn.weight +
ONNXTRT_Broadcast_884.../backbone/backbone.0/projector/stages.0/stages.0.0/Split_889]}
[04/03/2026-14:33:22] [V] [TRT] --------------- Timing Runner: {ForeignNode[...Split_889]} (Myelin[0x80000023])
[04/03/2026-14:33:22] [W] [TRT] UNSUPPORTED_STATE: Skipping tactic 0 due to insufficient memory on requested size of 8404992 detected for tactic 0x0000000000000000.
[04/03/2026-14:33:22] [V] [TRT] Fastest Tactic: 0xd15ea5edd15ea5ed Time: inf
[04/03/2026-14:33:22] [V] [TRT] --------------- Timing Runner: {ForeignNode[...Split_889]} (Myelin[0x80000023])
[04/03/2026-14:33:22] [W] [TRT] UNSUPPORTED_STATE: Skipping tactic 0 due to insufficient memory on requested size of 6307840 detected for tactic 0x0000000000000000.
[04/03/2026-14:33:22] [V] [TRT] Fastest Tactic: 0xd15ea5edd15ea5ed Time: inf
[04/03/2026-14:33:22] [V] [TRT] --------------- Timing Runner: {ForeignNode[...Split_889]} (Myelin[0x80000023])
[04/03/2026-14:33:22] [W] [TRT] UNSUPPORTED_STATE: Skipping tactic 0 due to insufficient memory on requested size of 4210688 detected for tactic 0x0000000000000000.
[04/03/2026-14:33:22] [V] [TRT] Fastest Tactic: 0xd15ea5edd15ea5ed Time: inf
[04/03/2026-14:33:22] [E] Error[4]: IBuilder::buildSerializedNetwork: Error Code 4: Internal Error
(Could not find any implementation for node
{ForeignNode[backbone.0.projector.stages.0.0.cv1.bn.weight +
ONNXTRT_Broadcast_884.../backbone/backbone.0/projector/stages.0/stages.0.0/Split_889]}
due to insufficient workspace.)
&&&& FAILED TensorRT.trtexec [TensorRT v100300]
```
Note: the tactic requests only 4–8 MB yet fails with 8 GB of workspace available; this is not a real memory issue.
Full error output — NVIDIA A10G (TRT 10.7.0):
```
[W] [TRT] UNSUPPORTED_STATE: Skipping tactic 0 due to insufficient memory on requested size of 832086912 detected for tactic 0x0000000000000000.
[W] [TRT] UNSUPPORTED_STATE: Skipping tactic 0 due to insufficient memory on requested size of 8404992 detected for tactic 0x0000000000000000.
[W] [TRT] UNSUPPORTED_STATE: Skipping tactic 0 due to insufficient memory on requested size of 6307840 detected for tactic 0x0000000000000000.
[W] [TRT] UNSUPPORTED_STATE: Skipping tactic 0 due to insufficient memory on requested size of 4210688 detected for tactic 0x0000000000000000.
[E] Error[4]: IBuilder::buildSerializedNetwork: Error Code 4: Internal Error
(Could not find any implementation for node
{ForeignNode[backbone.0.projector.stages.0.0.cv1.bn.weight +
ONNXTRT_Broadcast_884.../backbone/backbone.0/projector/stages.0/stages.0.0/Split_889]}
due to insufficient workspace.)
```
Same failure pattern: 16 GB workspace available, tactic unconditionally rejected.
Workarounds attempted (all failed):
| Workaround | Result |
| --- | --- |
| `onnxsim` BN folding (eliminates standalone BN) | Same ForeignNode; Conv absorbed the BN weights |
| `Identity` fusion-breaker nodes after Conv | TRT/Myelin optimizes through `Identity` |
| `Mul(x, 1.0)` fusion-breaker nodes | TRT/Myelin still creates the same ForeignNode |
| Replace `Split` with equivalent `Slice` ops | TRT fuses Conv+Slice identically |
| `--builderOptimizationLevel=0` | Same ForeignNode, Error[10] instead of [4] |
| `--builderOptimizationLevel=3` | Same failure |
| `--noTF32` | Same failure |
| FP32 only (no `--fp16`) | Same failure |
| `--tacticSources=-JIT_CONVOLUTIONS` | Same failure (Myelin is not a tactic source) |
| `--tacticSources=+CUBLAS,+CUBLAS_LT,+CUDNN` | Same failure |
There appears to be no way to disable the Myelin compiler backend via `trtexec` flags or `--tacticSources`.
Have you tried the latest release?: Tested on TRT 10.3.0 (JetPack 6) and TRT 10.7.0 (24.12 container). Both fail identically.
Can this model run on other frameworks? Yes. The ONNX model runs correctly on:
- ONNX Runtime (CPU and GPU)
- TensorRT 8.6.2 (`nvcr.io/nvidia/tritonserver:24.07-py3-igpu` on Orin): compiles and runs at ~8.6 FPS in FP16
The regression is specific to TRT 10.x.