From 745bc494d06501f0b79360ef725e1621350af050 Mon Sep 17 00:00:00 2001 From: dmoodie Date: Thu, 14 May 2026 15:30:08 -0400 Subject: [PATCH 1/4] Bugfix for using trtexec remote auto tuning for qdq autotune, need to push engine to the remote device for correct latency testing Signed-off-by: dmoodie --- docs/source/guides/9_autotune.rst | 4 + .../onnx/quantization/autotune/benchmark.py | 324 ++++++-- .../quantization/autotune/test_benchmark.py | 4 +- .../autotune/test_cli_pipeline.py | 211 +++++ .../autotune/test_trtexec_benchmark.py | 728 ++++++++++++++++++ 5 files changed, 1226 insertions(+), 45 deletions(-) create mode 100644 tests/unit/onnx/quantization/autotune/test_cli_pipeline.py create mode 100644 tests/unit/onnx/quantization/autotune/test_trtexec_benchmark.py diff --git a/docs/source/guides/9_autotune.rst b/docs/source/guides/9_autotune.rst index 583dfcb6ee8..0a3e1e52b6e 100644 --- a/docs/source/guides/9_autotune.rst +++ b/docs/source/guides/9_autotune.rst @@ -243,6 +243,10 @@ To use remote autotuning during Q/DQ placement optimization, run with ``trtexec` * Valid remote autotuning configuration * ``--use_trtexec`` must be set (benchmarking uses ``trtexec`` instead of the TensorRT Python API) * ``--safe --skipInference`` must be enabled via ``--trtexec_benchmark_args`` +* ssh and scp must be available on the local machine +* sshpass must be available on the local machine if using password authentication +* Only one instance of remote autotuning can be run at a time since the remote timing server and latency measurement processes share the GPU but do not coordinate execution; thus latency measurements would not be accurate if multiple instances are run concurrently. +* ``--useCudaGraph`` is added to the remote latency measurement command to improve accuracy. Replace ```` with an actual remote autotuning configuration string (see ``trtexec --help`` for more details). Other TensorRT benchmark options (e.g. 
``--timing_cache``, ``--warmup_runs``, ``--timing_runs``, ``--plugin_libraries``) are also available; run ``--help`` for details. diff --git a/modelopt/onnx/quantization/autotune/benchmark.py b/modelopt/onnx/quantization/autotune/benchmark.py index b87478a1572..8e4922c2928 100644 --- a/modelopt/onnx/quantization/autotune/benchmark.py +++ b/modelopt/onnx/quantization/autotune/benchmark.py @@ -30,12 +30,16 @@ import importlib.util import os import re +import shlex import shutil +import subprocess # nosec B404 import tempfile import time from abc import ABC, abstractmethod +from dataclasses import dataclass from pathlib import Path from typing import Any +from urllib.parse import parse_qs, urlparse import numpy as np import torch @@ -144,6 +148,185 @@ def _write_log_file(self, file: Path | str | None, content: str) -> None: self.logger.warning(f"Failed to save logs to {file}: {e}") +_SAFE_PATTERN = ( + r"\[\d{2}/\d{2}/\d{4}-\d{2}:\d{2}:\d{2}\]\s+\[I\]\s+" + r"Average over \d+ runs - GPU latency:\s*([\d.]+)\s*ms" +) +_STD_PATTERN = r"\[I\]\s+GPU Compute Time:.*?median\s*=\s*([\d.]+)\s*ms" + +_URL_PASSWORD_RE = re.compile(r"(://[^:/?#@]+):[^@/?#]+@") + + +def _redact_url_password(s: str) -> str: + """Replace any ``scheme://user:password@host`` substring with ``user:******@host``. + + Used so SSH passwords supplied via ``--remoteAutoTuningConfig`` don't leak + into log messages or exception strings. + """ + return _URL_PASSWORD_RE.sub(r"\1:******@", s) + + +def _build_base_trtexec_cmd( + *, + timing_runs: int, + warmup_runs: int, + engine_path: str, + timing_cache_file: str, + plugin_libraries: list[str] | None = None, + log: Any = None, +) -> list[str]: + """Build the static portion of the trtexec command line (no ``--onnx=`` yet). + + Plugin libraries that don't exist on disk are skipped with a warning if a + logger is supplied. The leading ``trtexec`` binary path is not included — + the caller is responsible for prepending it. 
+ + Args: + timing_runs: Value for ``--avgRuns`` and ``--iterations``. + warmup_runs: Value for ``--warmUp``. + engine_path: Path used for ``--saveEngine=``. + timing_cache_file: Path used for ``--timingCacheFile=``. + plugin_libraries: Paths to ``.so`` libraries for ``--staticPlugins``. + log: Optional logger used to warn about missing plugins and trace adds. + """ + cmd = [ + f"--avgRuns={timing_runs}", + f"--iterations={timing_runs}", + f"--warmUp={warmup_runs}", + "--stronglyTyped", + f"--saveEngine={engine_path}", + f"--timingCacheFile={timing_cache_file}", + ] + for plugin_lib in plugin_libraries or []: + plugin_path = Path(plugin_lib).resolve() + if not plugin_path.exists(): + if log is not None: + log.warning(f"Plugin library not found: {plugin_path}") + continue + cmd.append(f"--staticPlugins={plugin_path}") + if log is not None: + log.debug(f"Added plugin library: {plugin_path}") + return cmd + + +def _extract_remote_config_value(trtexec_args: list[str], *, log: Any = None) -> str | None: + """Find the value of ``--remoteAutoTuningConfig`` in ``trtexec_args``. + + Supports both inline (``--remoteAutoTuningConfig=value``) and split + (``--remoteAutoTuningConfig value``) forms. + + Returns: + The value as a string, or ``None`` if the flag is absent. Returning + an empty string is possible (e.g. ``--remoteAutoTuningConfig=``); the + caller decides whether to treat that as an error. + + Raises: + ValueError: If the flag appears more than once, has no value at the + end of the list, or is malformed (e.g. missing the ``=`` + separator). SSH passwords in malformed args are redacted before + being included in the error or debug log. 
+ """ + matches = [a for a in trtexec_args if "--remoteAutoTuningConfig" in a] + if not matches: + return None + if len(matches) != 1: + raise ValueError("Exactly one --remoteAutoTuningConfig argument is required") + + for i, arg in enumerate(trtexec_args): + if not arg.startswith("--remoteAutoTuningConfig"): + continue + if arg == "--remoteAutoTuningConfig": + if i + 1 >= len(trtexec_args): + raise ValueError("Missing value for --remoteAutoTuningConfig") + return trtexec_args[i + 1] + if arg.startswith("--remoteAutoTuningConfig="): + return arg.split("=", 1)[1] + # Malformed: starts with the flag name but neither uses ``=`` nor is + # the bare flag. Redact any embedded SSH password before surfacing. + redacted_arg = _redact_url_password(arg) + if log is not None: + log.debug(f"Parsing remoteAutoTuningConfig arg: {redacted_arg}") + raise ValueError(f"Malformed --remoteAutoTuningConfig argument: {redacted_arg}") + return None # pragma: no cover — unreachable; ``matches`` proved presence + + +@dataclass(frozen=True) +class _RemoteAutotuningConfig: + """Resolved remote-autotuning destination parsed from a ``ssh://`` URL.""" + + user: str + password: str # may be empty when no password was supplied + ip: str + port: int + options: dict[str, str] + bin_path: str # dirname of ``remote_exec_path`` + lib_path: str # value of ``remote_lib_path`` + + +def _parse_remote_autotuning_url(url: str) -> _RemoteAutotuningConfig: + """Parse a ``--remoteAutoTuningConfig`` URL into structured fields. + + Required URL form:: + + ssh://user[:password]@host[:port]?remote_exec_path=PATH&remote_lib_path=PATH + + Raises: + ValueError: If the scheme is not ``ssh://``; if user or host are + missing; or if required query parameters are missing or + duplicated. Duplicate keys are rejected explicitly because + silently collapsing them would produce empty remote paths + downstream. 
+ """ + if not url.startswith("ssh://"): + raise ValueError("Only 'ssh://' remote autotuning config URLs are supported") + parsed = urlparse(url) + if parsed.username is None: + raise ValueError("Unable to parse remote user from --remoteAutoTuningConfig") + if parsed.hostname is None: + raise ValueError("Unable to parse remote IP from --remoteAutoTuningConfig") + + parsed_query = parse_qs(parsed.query) + duplicates = sorted(k for k, v in parsed_query.items() if len(v) > 1) + if duplicates: + raise ValueError(f"Duplicate query parameters in --remoteAutoTuningConfig: {duplicates}") + options = {k: v[0] for k, v in parsed_query.items()} + + required_params = ["remote_exec_path", "remote_lib_path"] + missing = [p for p in required_params if p not in options] + if missing: + raise ValueError( + f"Missing required query parameters in --remoteAutoTuningConfig: {missing}" + ) + + return _RemoteAutotuningConfig( + user=parsed.username, + password=parsed.password or "", + ip=parsed.hostname, + port=parsed.port if parsed.port is not None else 22, + options=options, + bin_path=os.path.dirname(options["remote_exec_path"]), + lib_path=options["remote_lib_path"], + ) + + +def _ensure_remote_autotuning_flags(trtexec_args: list[str], *, log: Any = None) -> list[str]: + """Return ``trtexec_args`` with ``--safe`` and ``--skipInference`` appended if missing. + + Remote autotuning requires both flags. A warning is emitted for each flag + that has to be injected so the user sees that their argv was modified. + """ + result = list(trtexec_args) + for flag in ("--safe", "--skipInference"): + if flag in result: + continue + if log is not None: + log.warning( + f"Remote autotuning requires '{flag}' to be set. Adding it to trtexec arguments." + ) + result.append(flag) + return result + + class TrtExecBenchmark(Benchmark): """TensorRT benchmark using trtexec command-line tool. 
@@ -172,57 +355,58 @@ def __init__( Example: ['--fp16', '--workspace=4096', '--verbose'] """ super().__init__(timing_cache_file, warmup_runs, timing_runs, plugin_libraries) - self.trtexec_args = trtexec_args if trtexec_args is not None else [] + self.trtexec_args = list(trtexec_args) if trtexec_args is not None else [] self.temp_dir = tempfile.mkdtemp(prefix="trtexec_benchmark_") self.engine_path = os.path.join(self.temp_dir, "engine.trt") self.temp_model_path = os.path.join(self.temp_dir, "temp_model.onnx") self.logger.debug(f"Created temporary engine directory: {self.temp_dir}") self.logger.debug(f"Temporary model path: {self.temp_model_path}") - self.latency_pattern = r"\[I\]\s+Latency:.*?median\s*=\s*([\d.]+)\s*ms" - self._base_cmd = [ - f"--avgRuns={self.timing_runs}", - f"--iterations={self.timing_runs}", - f"--warmUp={self.warmup_runs}", - "--stronglyTyped", - f"--saveEngine={self.engine_path}", - f"--timingCacheFile={self.timing_cache_file}", - ] - - for plugin_lib in self.plugin_libraries: - plugin_path = Path(plugin_lib).resolve() - if not plugin_path.exists(): - self.logger.warning(f"Plugin library not found: {plugin_path}") - continue - self._base_cmd.append(f"--staticPlugins={plugin_path}") - self.logger.debug(f"Added plugin library: {plugin_path}") - - trtexec_args = self.trtexec_args or [] - has_remote_config = any("--remoteAutoTuningConfig" in arg for arg in trtexec_args) + self._base_cmd = _build_base_trtexec_cmd( + timing_runs=self.timing_runs, + warmup_runs=self.warmup_runs, + engine_path=self.engine_path, + timing_cache_file=self.timing_cache_file, + plugin_libraries=self.plugin_libraries, + log=self.logger, + ) - if has_remote_config: + # Defaults for remote-autotuning fields; overwritten when configured. 
+ self.has_remote_config: bool = False + self.remote_ip: str | None = None + self.remote_port: int = 22 + self.remote_user: str = "root" + self.remote_password: str = "" + self.remote_engine_path: str = "trtexec_benchmark_model.trt" + self.remote_bin_path: str = "trtexec" + self.remote_lib_path: str = "" + self.remote_options: dict[str, str] = {} + + remote_value = _extract_remote_config_value(self.trtexec_args, log=self.logger) + if remote_value is not None: + self.has_remote_config = True + if not remote_value: + raise ValueError("Could not parse --remoteAutoTuningConfig argument") + config = _parse_remote_autotuning_url(remote_value) + self.remote_user = config.user + self.remote_password = config.password + self.remote_ip = config.ip + self.remote_port = config.port + self.remote_options = config.options + self.remote_bin_path = config.bin_path + self.remote_lib_path = config.lib_path try: _check_for_trtexec(min_version="10.15") self.logger.debug("TensorRT Python API version >= 10.15 detected") - if "--safe" not in trtexec_args: - self.logger.warning( - "Remote autotuning requires '--safe' to be set. Adding it to trtexec arguments." - ) - self.trtexec_args.append("--safe") - if "--skipInference" not in trtexec_args: - self.logger.warning( - "Remote autotuning requires '--skipInference' to be set. Adding it to trtexec arguments." - ) - self.trtexec_args.append("--skipInference") except ImportError: self.logger.warning( - "Remote autotuning is not supported with TensorRT version < 10.15. " - "Removing --remoteAutoTuningConfig from trtexec arguments" + "Remote autotuning is not supported with TensorRT version < 10.15." 
) - trtexec_args = [ - arg for arg in trtexec_args if "--remoteAutoTuningConfig" not in arg - ] - self._base_cmd.extend(trtexec_args) + raise + self.trtexec_args = _ensure_remote_autotuning_flags(self.trtexec_args, log=self.logger) + + self.is_safe = "--safe" in self.trtexec_args + self._base_cmd.extend(self.trtexec_args) self.logger.debug(f"Base command template: {' '.join(self._base_cmd)}") @@ -287,17 +471,71 @@ def run( self.logger.error(f"trtexec failed with return code {result.returncode}") self.logger.error(f"stderr: {result.stderr}") return float("inf") + latency_pattern = _STD_PATTERN + if self.has_remote_config and self.is_safe: + ssh_pass = [] + if self.remote_password: + ssh_pass.append("sshpass") + ssh_pass.append("-p") + ssh_pass.append(self.remote_password) + # need to push the model to the device and use trtexec_safe to run + scp_cmd = [ + "scp", + f"-P{self.remote_port}", + self.engine_path, + f"{self.remote_user}@{self.remote_ip}:{shlex.quote(self.remote_engine_path)}", + ] + scp_cmd = ssh_pass + scp_cmd + result = subprocess.run(scp_cmd, capture_output=True, text=True) # nosec B603 + if result.returncode != 0: + self.logger.error(f"Failed to push engine to remote device: {result.stderr}") + return float("inf") + ld_path = f"LD_LIBRARY_PATH={shlex.quote(self.remote_lib_path)}:$LD_LIBRARY_PATH" + trt_path = f"{os.path.join(self.remote_bin_path, 'trtexec_safe')}" + trtexec_safe_cmd = [ + "ssh", + "-p", + f"{self.remote_port}", + f"{self.remote_user}@{self.remote_ip}", + f"{ld_path} {shlex.quote(trt_path)} --useCudaGraph " + f"--loadEngine={shlex.quote(self.remote_engine_path)}", + ] + trtexec_safe_cmd = ssh_pass + trtexec_safe_cmd + result = subprocess.run(trtexec_safe_cmd, capture_output=True, text=True) # nosec B603 + latency_pattern = _SAFE_PATTERN + if result.returncode != 0: + # fallback and try trtexec with "--safe" in case this is a safety proxy target + trt_path = f"{os.path.join(self.remote_bin_path, 'trtexec')}" + trtexec_safe_cmd = [ + 
"ssh", + "-p", + f"{self.remote_port}", + f"{self.remote_user}@{self.remote_ip}", + f"{ld_path} {shlex.quote(trt_path)} --safe --useCudaGraph " + f"--loadEngine={shlex.quote(self.remote_engine_path)}", + ] + trtexec_safe_cmd = ssh_pass + trtexec_safe_cmd - if not (match := re.search(self.latency_pattern, result.stdout, re.IGNORECASE)): - self.logger.warning("Could not parse median latency from trtexec output") - self.logger.debug(f"trtexec stdout:\n{result.stdout}") + result = subprocess.run(trtexec_safe_cmd, capture_output=True, text=True) # nosec B603 + latency_pattern = _STD_PATTERN + if result.returncode != 0: + self.logger.error( + f"Failed to run trtexec_safe or trtexec with '--safe'\n {result.stdout}" + ) + return float("inf") + if not (match := re.search(latency_pattern, result.stdout, re.IGNORECASE)): + # this could be due to creating a degenerate onnx file that can't be engine built. + # thus not a hard failure + self.logger.warning(f"trtexec stdout:\n{result.stdout}") + self.logger.error("Could not parse median latency from trtexec output") return float("inf") latency = float(match.group(1)) self.logger.info(f"TrtExec benchmark (median): {latency:.2f} ms") return latency - except FileNotFoundError: + except FileNotFoundError as e: self.logger.error( - "'trtexec' binary not found. Please ensure TensorRT is installed and 'trtexec' is in PATH." 
+ f"{e.filename} not found, please ensure system dependencies are installed and in the PATH: \n" + "ssh, scp, sshpass, trtexec" ) return float("inf") except Exception as e: diff --git a/tests/gpu/onnx/quantization/autotune/test_benchmark.py b/tests/gpu/onnx/quantization/autotune/test_benchmark.py index 925d45fffe6..f8ae482647b 100644 --- a/tests/gpu/onnx/quantization/autotune/test_benchmark.py +++ b/tests/gpu/onnx/quantization/autotune/test_benchmark.py @@ -193,7 +193,7 @@ def test_trtexec_run_returns_parsed_latency(trtexec_bench, tmp_path): mock_result = MagicMock() mock_result.returncode = 0 - mock_result.stdout = "[I] Latency: min = 2.50 ms, max = 4.00 ms, median = 3.14 ms" + mock_result.stdout = "[I] GPU Compute Time: min = 2.50 ms, max = 4.00 ms, median = 3.14 ms" mock_result.stderr = "" with patch("subprocess.run", return_value=mock_result): @@ -211,7 +211,7 @@ def test_trtexec_run_returns_inf_when_binary_not_found(trtexec_bench, tmp_path): def test_trtexec_run_accepts_bytes_input(trtexec_bench): mock_result = MagicMock() mock_result.returncode = 0 - mock_result.stdout = "[I] Latency: min = 4.00 ms, max = 6.00 ms, median = 5.00 ms" + mock_result.stdout = "[I] GPU Compute Time: min = 4.00 ms, max = 6.00 ms, median = 5.00 ms" mock_result.stderr = "" with patch("subprocess.run", return_value=mock_result): diff --git a/tests/unit/onnx/quantization/autotune/test_cli_pipeline.py b/tests/unit/onnx/quantization/autotune/test_cli_pipeline.py new file mode 100644 index 00000000000..0e1f1fd8612 --- /dev/null +++ b/tests/unit/onnx/quantization/autotune/test_cli_pipeline.py @@ -0,0 +1,211 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for the autotune CLI pipeline (``run_autotune`` in __main__.py). + +Mocks ``init_benchmark_instance`` and ``region_pattern_autotuning_workflow`` so +the end-to-end argv → exit-code path can be exercised without TensorRT or a +real benchmark. Real ONNX files are used for ``--onnx_path`` / ``--qdq_baseline`` +because ``validate_file_path`` exits the process on missing files. +""" + +from unittest.mock import MagicMock, patch + +import onnx +import pytest +from _test_utils.onnx.quantization.autotune.models import _create_simple_conv_onnx_model + +# The autotune CLI transitively imports ``tensorrt``; in environments where the +# package is locatable but its shared libs are missing, collection fails. Match +# the soft-skip pattern used by ``test_region_inspect.py``. +try: + from modelopt.onnx.quantization.autotune.__main__ import MODE_PRESETS, run_autotune +except ImportError: # pragma: no cover — exercised only in TRT-less envs + pytest.skip("Autotune CLI requires TensorRT", allow_module_level=True) + + +@pytest.fixture +def onnx_model_path(tmp_path): + """A real ONNX file on disk so ``validate_file_path`` succeeds.""" + path = tmp_path / "model.onnx" + onnx.save(_create_simple_conv_onnx_model(), str(path)) + return str(path) + + +@pytest.fixture +def mocked_pipeline(): + """Patch ``init_benchmark_instance`` and the autotuning workflow. + + Yields ``(init_mock, workflow_mock)`` so individual tests can inspect call + args. 
``init_mock`` returns a sentinel benchmark by default; tests that need + the failure path can override ``init_mock.return_value = None``. + """ + with ( + patch("modelopt.onnx.quantization.autotune.__main__.init_benchmark_instance") as init_mock, + patch( + "modelopt.onnx.quantization.autotune.__main__.region_pattern_autotuning_workflow" + ) as workflow_mock, + ): + init_mock.return_value = MagicMock(name="benchmark_instance") + yield init_mock, workflow_mock + + +def _run_with_argv(argv): + """Invoke ``run_autotune`` with a patched ``sys.argv``.""" + with patch("sys.argv", ["modelopt.onnx.quantization.autotune", *argv]): + return run_autotune() + + +# --- success paths --- + + +def test_run_autotune_minimal_argv_success(mocked_pipeline, onnx_model_path): + """Minimal argv (just --onnx_path) drives the pipeline to a clean exit.""" + init_mock, workflow_mock = mocked_pipeline + + exit_code = _run_with_argv(["--onnx_path", onnx_model_path]) + + assert exit_code == 0 + init_mock.assert_called_once() + workflow_mock.assert_called_once() + + +def test_run_autotune_default_uses_tensorrt_python_api(mocked_pipeline, onnx_model_path): + """Without --use_trtexec, the benchmark is the TensorRT-Python backend.""" + init_mock, _ = mocked_pipeline + + _run_with_argv(["--onnx_path", onnx_model_path]) + + assert init_mock.call_args.kwargs["use_trtexec"] is False + + +def test_run_autotune_use_trtexec_flag_propagates(mocked_pipeline, onnx_model_path): + """``--use_trtexec`` flips the backend selection passed to init.""" + init_mock, _ = mocked_pipeline + + _run_with_argv(["--onnx_path", onnx_model_path, "--use_trtexec"]) + + assert init_mock.call_args.kwargs["use_trtexec"] is True + + +def test_run_autotune_trtexec_args_split_to_list(mocked_pipeline, onnx_model_path): + """``--trtexec_benchmark_args`` is a quoted string at the CLI but a list at the API.""" + init_mock, _ = mocked_pipeline + + _run_with_argv( + [ + "--onnx_path", + onnx_model_path, + "--use_trtexec", + 
"--trtexec_benchmark_args", + "--fp16 --workspace=4096 --verbose", + ] + ) + + assert init_mock.call_args.kwargs["trtexec_args"] == [ + "--fp16", + "--workspace=4096", + "--verbose", + ] + + +def test_run_autotune_workflow_receives_model_and_options(mocked_pipeline, onnx_model_path): + """The workflow is called with model path, quant_type, default_dq_dtype, verbose.""" + _, workflow_mock = mocked_pipeline + + _run_with_argv( + [ + "--onnx_path", + onnx_model_path, + "--quant_type", + "fp8", + "--default_dq_dtype", + "float16", + "--verbose", + ] + ) + + kwargs = workflow_mock.call_args.kwargs + assert kwargs["model_or_path"] == onnx_model_path + assert kwargs["quant_type"] == "fp8" + assert kwargs["default_dq_dtype"] == "float16" + assert kwargs["verbose"] is True + + +def test_run_autotune_qdq_baseline_path_propagates(mocked_pipeline, onnx_model_path, tmp_path): + """``--qdq_baseline`` is validated then forwarded to the workflow.""" + baseline_path = tmp_path / "baseline.onnx" + onnx.save(_create_simple_conv_onnx_model(), str(baseline_path)) + _, workflow_mock = mocked_pipeline + + _run_with_argv(["--onnx_path", onnx_model_path, "--qdq_baseline", str(baseline_path)]) + + assert workflow_mock.call_args.kwargs["qdq_baseline_model"] == str(baseline_path) + + +def test_run_autotune_mode_preset_propagates_to_init(mocked_pipeline, onnx_model_path): + """``--mode quick`` overrides warmup/timing runs that init_benchmark_instance receives.""" + init_mock, workflow_mock = mocked_pipeline + preset = MODE_PRESETS["quick"] + + _run_with_argv(["--onnx_path", onnx_model_path, "--mode", "quick"]) + + assert init_mock.call_args.kwargs["warmup_runs"] == preset["warmup_runs"] + assert init_mock.call_args.kwargs["timing_runs"] == preset["timing_runs"] + assert workflow_mock.call_args.kwargs["num_schemes_per_region"] == preset["schemes_per_region"] + + +# --- failure paths --- + + +def test_run_autotune_returns_1_when_init_fails(mocked_pipeline, onnx_model_path): + """Benchmark init 
returning None short-circuits before the workflow runs.""" + init_mock, workflow_mock = mocked_pipeline + init_mock.return_value = None + + exit_code = _run_with_argv(["--onnx_path", onnx_model_path]) + + assert exit_code == 1 + workflow_mock.assert_not_called() + + +def test_run_autotune_returns_130_on_keyboard_interrupt(mocked_pipeline, onnx_model_path): + """``Ctrl+C`` during the workflow returns the conventional 130 exit code.""" + _, workflow_mock = mocked_pipeline + workflow_mock.side_effect = KeyboardInterrupt + + exit_code = _run_with_argv(["--onnx_path", onnx_model_path]) + + assert exit_code == 130 + + +def test_run_autotune_returns_1_on_workflow_exception(mocked_pipeline, onnx_model_path): + """Any other workflow exception is caught and reported as exit code 1.""" + _, workflow_mock = mocked_pipeline + workflow_mock.side_effect = RuntimeError("boom") + + exit_code = _run_with_argv(["--onnx_path", onnx_model_path]) + + assert exit_code == 1 + + +def test_run_autotune_missing_model_exits(tmp_path): + """``validate_file_path`` calls ``sys.exit(1)`` when --onnx_path does not exist.""" + missing = tmp_path / "does_not_exist.onnx" + + with pytest.raises(SystemExit) as exc_info: + _run_with_argv(["--onnx_path", str(missing)]) + + assert exc_info.value.code == 1 diff --git a/tests/unit/onnx/quantization/autotune/test_trtexec_benchmark.py b/tests/unit/onnx/quantization/autotune/test_trtexec_benchmark.py new file mode 100644 index 00000000000..af4f82e9704 --- /dev/null +++ b/tests/unit/onnx/quantization/autotune/test_trtexec_benchmark.py @@ -0,0 +1,728 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for ``TrtExecBenchmark`` — the trtexec-CLI benchmarking pipeline. + +These tests fully mock ``subprocess.run`` so the trtexec binary is never +invoked. They cover: + +- Construction of the ``trtexec`` base command (args, plugin libraries). +- ``--remoteAutoTuningConfig`` URL parsing (both ``--key=value`` and + ``--key value`` forms) and the validation errors it raises. +- Auto-injection of ``--safe`` and ``--skipInference`` for remote autotuning. +- The ``run()`` pipeline: standard local invocation; remote scp + ssh + ``trtexec_safe`` invocation; fallback to ``trtexec --safe`` when + ``trtexec_safe`` fails; ``sshpass`` prefix when a password is configured. +- Latency parsing from both ``_STD_PATTERN`` (GPU Compute Time) and + ``_SAFE_PATTERN`` (Average over N runs - GPU latency). +- Error paths: non-zero trtexec returncode, scp failure, missing trtexec + binary, unparseable stdout. +""" + +from types import SimpleNamespace +from unittest.mock import MagicMock, patch + +import pytest + +# Importing ``benchmark`` transitively triggers ``import tensorrt``. In +# environments where the package is locatable but its shared libs are missing +# (e.g. partial CUDA installs), collection would otherwise fail. Mirror the +# soft-skip pattern used by ``test_region_inspect.py``. 
+try: + from modelopt.onnx.quantization.autotune import benchmark as bm + from modelopt.onnx.quantization.autotune.benchmark import TrtExecBenchmark +except ImportError: # pragma: no cover — exercised only in TRT-less envs + pytest.skip("TrtExecBenchmark requires TensorRT", allow_module_level=True) + + +def _make_proc(returncode=0, stdout="", stderr=""): + """Build a ``subprocess.run``-style result object.""" + return SimpleNamespace(returncode=returncode, stdout=stdout, stderr=stderr) + + +# =========================================================================== +# Standalone helper tests — these exercise the module-level functions in +# isolation, with no TrtExecBenchmark construction or filesystem state. +# =========================================================================== + + +# --- _redact_url_password --- + + +@pytest.mark.parametrize( + ("raw", "expected"), + [ + ("ssh://alice:s3cret@host", "ssh://alice:******@host"), + ("ssh://alice:s3cret@host:22/path?q=1", "ssh://alice:******@host:22/path?q=1"), + ("ssh://alice@host", "ssh://alice@host"), # no password + ("ssh://host", "ssh://host"), # no userinfo + ("--flag=ssh://u:p@h", "--flag=ssh://u:******@h"), + ("https://user:secret@host", "https://user:******@host"), # any scheme + ], +) +def test_redact_url_password(raw, expected): + assert bm._redact_url_password(raw) == expected + + +# --- _build_base_trtexec_cmd --- + + +def test_build_base_trtexec_cmd_includes_expected_flags(tmp_path): + cmd = bm._build_base_trtexec_cmd( + timing_runs=11, + warmup_runs=3, + engine_path=str(tmp_path / "engine.trt"), + timing_cache_file=str(tmp_path / "cache.bin"), + ) + assert "--avgRuns=11" in cmd + assert "--iterations=11" in cmd + assert "--warmUp=3" in cmd + assert "--stronglyTyped" in cmd + assert any(arg.startswith("--saveEngine=") for arg in cmd) + assert any(arg.startswith("--timingCacheFile=") for arg in cmd) + + +def test_build_base_trtexec_cmd_skips_missing_plugin(tmp_path): + cmd = 
bm._build_base_trtexec_cmd( + timing_runs=1, + warmup_runs=1, + engine_path=str(tmp_path / "engine.trt"), + timing_cache_file=str(tmp_path / "cache.bin"), + plugin_libraries=[str(tmp_path / "absent.so")], + ) + assert not any("--staticPlugins" in arg for arg in cmd) + + +def test_build_base_trtexec_cmd_adds_existing_plugin(tmp_path): + plugin = tmp_path / "plugin.so" + plugin.write_bytes(b"") + cmd = bm._build_base_trtexec_cmd( + timing_runs=1, + warmup_runs=1, + engine_path=str(tmp_path / "engine.trt"), + timing_cache_file=str(tmp_path / "cache.bin"), + plugin_libraries=[str(plugin)], + ) + assert f"--staticPlugins={plugin.resolve()}" in cmd + + +def test_build_base_trtexec_cmd_warns_via_logger_on_missing_plugin(tmp_path): + log = MagicMock() + bm._build_base_trtexec_cmd( + timing_runs=1, + warmup_runs=1, + engine_path=str(tmp_path / "engine.trt"), + timing_cache_file=str(tmp_path / "cache.bin"), + plugin_libraries=[str(tmp_path / "absent.so")], + log=log, + ) + assert log.warning.called + assert "Plugin library not found" in log.warning.call_args.args[0] + + +# --- _extract_remote_config_value --- + + +def test_extract_remote_config_value_returns_none_when_absent(): + assert bm._extract_remote_config_value(["--fp16", "--workspace=4096"]) is None + + +def test_extract_remote_config_value_equals_form(): + args = ["--fp16", "--remoteAutoTuningConfig=ssh://a:p@h?x=1"] + assert bm._extract_remote_config_value(args) == "ssh://a:p@h?x=1" + + +def test_extract_remote_config_value_space_form(): + args = ["--fp16", "--remoteAutoTuningConfig", "ssh://a:p@h?x=1"] + assert bm._extract_remote_config_value(args) == "ssh://a:p@h?x=1" + + +def test_extract_remote_config_value_empty_value_returned_verbatim(): + """Empty value (``--remoteAutoTuningConfig=``) returned as ``""`` for the caller to flag.""" + assert bm._extract_remote_config_value(["--remoteAutoTuningConfig="]) == "" + + +def test_extract_remote_config_value_rejects_duplicates(): + args = [ + 
"--remoteAutoTuningConfig=ssh://a:p@h?x=1", + "--remoteAutoTuningConfig=ssh://a:p@h?x=2", + ] + with pytest.raises(ValueError, match="Exactly one"): + bm._extract_remote_config_value(args) + + +def test_extract_remote_config_value_missing_value_at_end_of_argv(): + with pytest.raises(ValueError, match="Missing value"): + bm._extract_remote_config_value(["--remoteAutoTuningConfig"]) + + +def test_extract_remote_config_value_malformed_redacts_password(): + secret = "SuperSecret-2026" + malformed = f"--remoteAutoTuningConfigssh://alice:{secret}@10.0.0.5" + with pytest.raises(ValueError, match="Malformed") as exc_info: + bm._extract_remote_config_value([malformed]) + assert secret not in str(exc_info.value) + assert "alice:******@" in str(exc_info.value) + + +def test_extract_remote_config_value_malformed_redacts_in_debug_log(): + secret = "SuperSecret-2026" + malformed = f"--remoteAutoTuningConfigssh://alice:{secret}@10.0.0.5" + log = MagicMock() + with pytest.raises(ValueError): + bm._extract_remote_config_value([malformed], log=log) + debug_msgs = [c.args[0] for c in log.debug.call_args_list] + assert all(secret not in m for m in debug_msgs) + + +# --- _parse_remote_autotuning_url --- + + +def test_parse_remote_autotuning_url_full(): + cfg = bm._parse_remote_autotuning_url( + "ssh://alice:s3cret@10.0.0.5:2222?" 
+ "remote_exec_path=/opt/trt/bin/trtexec&remote_lib_path=/opt/trt/lib" + ) + assert cfg.user == "alice" + assert cfg.password == "s3cret" + assert cfg.ip == "10.0.0.5" + assert cfg.port == 2222 + assert cfg.bin_path == "/opt/trt/bin" + assert cfg.lib_path == "/opt/trt/lib" + assert cfg.options == { + "remote_exec_path": "/opt/trt/bin/trtexec", + "remote_lib_path": "/opt/trt/lib", + } + + +def test_parse_remote_autotuning_url_defaults_port_to_22(): + cfg = bm._parse_remote_autotuning_url( + "ssh://alice@host?remote_exec_path=/x/trtexec&remote_lib_path=/y" + ) + assert cfg.port == 22 + + +def test_parse_remote_autotuning_url_empty_password_becomes_empty_string(): + cfg = bm._parse_remote_autotuning_url( + "ssh://alice@host?remote_exec_path=/x/trtexec&remote_lib_path=/y" + ) + assert cfg.password == "" + + +@pytest.mark.parametrize( + ("url", "match"), + [ + ("http://alice@host?remote_exec_path=/x&remote_lib_path=/y", "Only 'ssh://'"), + ("ssh://host?remote_exec_path=/x&remote_lib_path=/y", "remote user"), + ("ssh://alice@host?remote_exec_path=/x", "Missing required query parameters"), + ( + "ssh://alice@host?remote_exec_path=/a&remote_exec_path=/b&remote_lib_path=/y", + "Duplicate query parameters", + ), + ], +) +def test_parse_remote_autotuning_url_validation_errors(url, match): + with pytest.raises(ValueError, match=match): + bm._parse_remote_autotuning_url(url) + + +# --- _ensure_remote_autotuning_flags --- + + +def test_ensure_remote_autotuning_flags_appends_both_when_missing(): + result = bm._ensure_remote_autotuning_flags(["--fp16"]) + assert result == ["--fp16", "--safe", "--skipInference"] + + +def test_ensure_remote_autotuning_flags_preserves_user_supplied(): + result = bm._ensure_remote_autotuning_flags(["--safe", "--fp16"]) + assert result.count("--safe") == 1 + assert "--skipInference" in result + + +def test_ensure_remote_autotuning_flags_returns_new_list(): + original = ["--fp16"] + result = bm._ensure_remote_autotuning_flags(original) + assert result 
is not original + assert original == ["--fp16"] # input untouched + + +def test_ensure_remote_autotuning_flags_warns_per_injected_flag(): + log = MagicMock() + bm._ensure_remote_autotuning_flags([], log=log) + assert log.warning.call_count == 2 + + +# =========================================================================== +# Integration tests — exercise TrtExecBenchmark.__init__ end-to-end. +# =========================================================================== + + +@pytest.fixture +def bench(tmp_path): + """A plain ``TrtExecBenchmark`` instance with a temp timing cache.""" + return TrtExecBenchmark( + timing_cache_file=str(tmp_path / "cache.bin"), + warmup_runs=1, + timing_runs=2, + ) + + +# --- __init__ command construction --- + + +def test_init_builds_base_cmd_with_expected_flags(tmp_path): + """Base command contains the standard trtexec flags derived from ctor args.""" + cache = str(tmp_path / "cache.bin") + b = TrtExecBenchmark(timing_cache_file=cache, warmup_runs=3, timing_runs=7) + + assert "--avgRuns=7" in b._base_cmd + assert "--iterations=7" in b._base_cmd + assert "--warmUp=3" in b._base_cmd + assert "--stronglyTyped" in b._base_cmd + assert f"--timingCacheFile={cache}" in b._base_cmd + assert any(arg.startswith("--saveEngine=") for arg in b._base_cmd) + + +def test_init_extra_trtexec_args_appended(tmp_path): + """User-supplied trtexec args are appended to the base command.""" + b = TrtExecBenchmark( + timing_cache_file=str(tmp_path / "cache.bin"), + trtexec_args=["--fp16", "--workspace=4096"], + ) + assert b._base_cmd[-2:] == ["--fp16", "--workspace=4096"] + + +def test_init_missing_plugin_library_is_skipped(tmp_path): + """Missing plugin .so paths produce a warning and don't appear in the base cmd.""" + b = TrtExecBenchmark( + timing_cache_file=str(tmp_path / "cache.bin"), + plugin_libraries=[str(tmp_path / "absent_plugin.so")], + ) + assert not any("--staticPlugins" in arg for arg in b._base_cmd) + + +def 
test_init_existing_plugin_library_added(tmp_path): + """Plugin .so paths that exist on disk are added as ``--staticPlugins``.""" + plugin = tmp_path / "fake_plugin.so" + plugin.write_bytes(b"") + b = TrtExecBenchmark( + timing_cache_file=str(tmp_path / "cache.bin"), + plugin_libraries=[str(plugin)], + ) + assert any(arg == f"--staticPlugins={plugin.resolve()}" for arg in b._base_cmd) + + +def test_init_safe_flag_sets_is_safe(tmp_path): + """``--safe`` in trtexec_args (without remote config) flips ``is_safe``.""" + b = TrtExecBenchmark( + timing_cache_file=str(tmp_path / "cache.bin"), + trtexec_args=["--safe"], + ) + assert b.is_safe is True + assert b.has_remote_config is False + + +# --- --remoteAutoTuningConfig parsing --- + + +_REMOTE_URL = ( + "ssh://alice:s3cret@10.0.0.5:2222?" + "remote_exec_path=/opt/trt/bin/trtexec&remote_lib_path=/opt/trt/lib" +) + + +@pytest.fixture +def trtexec_version_ok(): + """Pretend trtexec >= 10.15 is available so --safe injection succeeds.""" + with patch( + "modelopt.onnx.quantization.autotune.benchmark._check_for_trtexec", + return_value="/usr/local/bin/trtexec", + ) as m: + yield m + + +@pytest.mark.usefixtures("trtexec_version_ok") +def test_remote_config_equals_form_parses(tmp_path): + """``--remoteAutoTuningConfig=ssh://...`` (single arg) is parsed correctly.""" + b = TrtExecBenchmark( + timing_cache_file=str(tmp_path / "cache.bin"), + trtexec_args=[f"--remoteAutoTuningConfig={_REMOTE_URL}"], + ) + assert b.has_remote_config is True + assert b.remote_user == "alice" + assert b.remote_password == "s3cret" + assert b.remote_ip == "10.0.0.5" + assert b.remote_port == 2222 + assert b.remote_bin_path == "/opt/trt/bin" + assert b.remote_lib_path == "/opt/trt/lib" + + +@pytest.mark.usefixtures("trtexec_version_ok") +def test_remote_config_space_form_parses(tmp_path): + """``--remoteAutoTuningConfig ssh://...`` (two args) is parsed correctly.""" + b = TrtExecBenchmark( + timing_cache_file=str(tmp_path / "cache.bin"), + 
trtexec_args=["--remoteAutoTuningConfig", _REMOTE_URL], + ) + assert b.has_remote_config is True + assert b.remote_ip == "10.0.0.5" + + +@pytest.mark.usefixtures("trtexec_version_ok") +def test_remote_config_injects_safe_and_skipinference(tmp_path): + """When remote config is set, ``--safe`` and ``--skipInference`` are added if missing.""" + b = TrtExecBenchmark( + timing_cache_file=str(tmp_path / "cache.bin"), + trtexec_args=[f"--remoteAutoTuningConfig={_REMOTE_URL}"], + ) + assert "--safe" in b.trtexec_args + assert "--skipInference" in b.trtexec_args + assert b.is_safe is True + + +@pytest.mark.usefixtures("trtexec_version_ok") +def test_remote_config_keeps_user_supplied_safe(tmp_path): + """If user already passed ``--safe``, it's not duplicated.""" + b = TrtExecBenchmark( + timing_cache_file=str(tmp_path / "cache.bin"), + trtexec_args=[f"--remoteAutoTuningConfig={_REMOTE_URL}", "--safe"], + ) + assert b.trtexec_args.count("--safe") == 1 + + +@pytest.mark.usefixtures("trtexec_version_ok") +def test_remote_config_default_port_is_22(tmp_path): + """A URL with no explicit port falls back to SSH port 22.""" + url = ( + "ssh://alice:s3cret@10.0.0.5?" 
+ "remote_exec_path=/opt/trt/bin/trtexec&remote_lib_path=/opt/trt/lib" + ) + b = TrtExecBenchmark( + timing_cache_file=str(tmp_path / "cache.bin"), + trtexec_args=[f"--remoteAutoTuningConfig={url}"], + ) + assert b.remote_port == 22 + + +@pytest.mark.parametrize( + ("bad_args", "match"), + [ + ( + [ + f"--remoteAutoTuningConfig={_REMOTE_URL}", + f"--remoteAutoTuningConfig={_REMOTE_URL}", + ], + "Exactly one", + ), + (["--remoteAutoTuningConfig"], "Missing value"), + ( + ["--remoteAutoTuningConfig=http://10.0.0.5/?remote_exec_path=x&remote_lib_path=y"], + "Only 'ssh://'", + ), + ( + ["--remoteAutoTuningConfig=ssh://10.0.0.5:22?remote_exec_path=/x&remote_lib_path=/y"], + "remote user", + ), + ( + ["--remoteAutoTuningConfig=ssh://alice@10.0.0.5:22?remote_exec_path=/x"], + "Missing required query parameters", + ), + ( + [ + "--remoteAutoTuningConfig=ssh://alice@10.0.0.5:22?" + "remote_exec_path=/a&remote_exec_path=/b&remote_lib_path=/y" + ], + "Duplicate query parameters", + ), + ], +) +@pytest.mark.usefixtures("trtexec_version_ok") +def test_remote_config_validation_errors(tmp_path, bad_args, match): + """Malformed ``--remoteAutoTuningConfig`` values raise ``ValueError``.""" + with pytest.raises(ValueError, match=match): + TrtExecBenchmark( + timing_cache_file=str(tmp_path / "cache.bin"), + trtexec_args=bad_args, + ) + + +@pytest.mark.usefixtures("trtexec_version_ok") +def test_malformed_remote_config_redacts_password(tmp_path, caplog): + """Malformed ``--remoteAutoTuningConfig`` ValueError + debug log must not leak the password. + + The else branch in ``__init__`` fires when the flag is mis-separated + (e.g. ``--remoteAutoTuningConfig`` followed by ``ssh://...`` with no ``=``). + Both the raised ``ValueError`` and any debug log line about the arg must + mask the SSH password. + """ + secret = "TopSecret-2026!" 
+ malformed = ( + f"--remoteAutoTuningConfigssh://alice:{secret}@10.0.0.5:22" + "/remote_exec_path=/x&remote_lib_path=/y" + ) + with ( + caplog.at_level("DEBUG", logger="modelopt.onnx"), + pytest.raises(ValueError, match="Malformed --remoteAutoTuningConfig") as exc_info, + ): + TrtExecBenchmark( + timing_cache_file=str(tmp_path / "cache.bin"), + trtexec_args=[malformed], + ) + + # The ValueError message must NOT contain the password — leaking it would + # surface secrets in stack traces and crash reports. + assert secret not in str(exc_info.value) + assert "alice:******@" in str(exc_info.value) + + # The debug log line for the arg must also not contain the password. + for record in caplog.records: + assert secret not in record.getMessage() + + +def test_remote_config_requires_trtexec_10_15(tmp_path): + """When trtexec is too old, remote autotuning surfaces an ImportError.""" + with ( + patch( + "modelopt.onnx.quantization.autotune.benchmark._check_for_trtexec", + side_effect=ImportError("trtexec < 10.15"), + ), + pytest.raises(ImportError, match="trtexec < 10.15"), + ): + TrtExecBenchmark( + timing_cache_file=str(tmp_path / "cache.bin"), + trtexec_args=[f"--remoteAutoTuningConfig={_REMOTE_URL}"], + ) + + +# --- run() — local trtexec pipeline --- + + +def test_run_invokes_trtexec_with_onnx_path(bench, tmp_path): + """``run(path)`` invokes trtexec with ``--onnx=`` appended.""" + model = tmp_path / "model.onnx" + model.write_bytes(b"") + + proc = _make_proc(stdout="[I] GPU Compute Time: min = 1.0 ms, max = 2.0 ms, median = 1.5 ms") + with patch("subprocess.run", return_value=proc) as run_mock: + latency = bench.run(str(model)) + + assert latency == pytest.approx(1.5) + cmd = run_mock.call_args.args[0] + assert cmd[0] == "trtexec" + assert f"--onnx={model}" in cmd + + +def test_run_writes_bytes_to_temp_file_before_invoking(bench): + """``run(bytes)`` writes the bytes to disk and points trtexec at that file.""" + proc = _make_proc(stdout="[I] GPU Compute Time: min = 1.0 
ms, max = 2.0 ms, median = 4.25 ms") + with patch("subprocess.run", return_value=proc) as run_mock: + latency = bench.run(b"\x08onnx-bytes") + + assert latency == pytest.approx(4.25) + cmd = run_mock.call_args.args[0] + assert f"--onnx={bench.temp_model_path}" in cmd + + +def test_run_writes_log_file_when_requested(bench, tmp_path): + """``log_file`` receives stdout, stderr and the constructed command.""" + log_file = tmp_path / "logs" / "trtexec.log" + proc = _make_proc( + stdout="[I] GPU Compute Time: median = 2.0 ms", + stderr="some warning", + ) + with patch("subprocess.run", return_value=proc): + bench.run(str(tmp_path / "model.onnx"), log_file=str(log_file)) + + contents = log_file.read_text() + assert "Command:" in contents + assert "STDOUT:" in contents and "STDERR:" in contents + assert "some warning" in contents + + +def test_run_returns_inf_on_nonzero_returncode(bench, tmp_path): + """Non-zero exit from trtexec yields ``inf`` and short-circuits parsing.""" + proc = _make_proc(returncode=1, stderr="engine build failed", stdout="") + with patch("subprocess.run", return_value=proc): + assert bench.run(str(tmp_path / "m.onnx")) == float("inf") + + +def test_run_returns_inf_when_latency_not_parseable(bench, tmp_path): + """Stdout that doesn't match either pattern yields ``inf``.""" + proc = _make_proc(stdout="all done, no latency line here") + with patch("subprocess.run", return_value=proc): + assert bench.run(str(tmp_path / "m.onnx")) == float("inf") + + +def test_run_returns_inf_when_trtexec_binary_missing(bench, tmp_path): + """A ``FileNotFoundError`` from subprocess.run is mapped to ``inf``.""" + with patch("subprocess.run", side_effect=FileNotFoundError): + assert bench.run(str(tmp_path / "m.onnx")) == float("inf") + + +def test_run_returns_inf_on_unexpected_exception(bench, tmp_path): + """Any non-FileNotFoundError raised mid-pipeline still yields ``inf``.""" + with patch("subprocess.run", side_effect=OSError("disk full")): + assert 
bench.run(str(tmp_path / "m.onnx")) == float("inf") + + +def test_call_dunder_forwards_to_run(bench, tmp_path): + """Calling the benchmark instance directly invokes ``run`` and returns its result.""" + proc = _make_proc(stdout="[I] GPU Compute Time: median = 9.81 ms") + with patch("subprocess.run", return_value=proc): + latency = bench(str(tmp_path / "m.onnx")) + assert latency == pytest.approx(9.81) + + +def test_del_swallows_cleanup_errors(tmp_path): + """``__del__`` warns but does not raise when ``shutil.rmtree`` errors.""" + b = TrtExecBenchmark(timing_cache_file=str(tmp_path / "cache.bin")) + with patch.object(bm.shutil, "rmtree", side_effect=PermissionError("denied")): + b.__del__() # Must not raise; logs a warning. + + +def test_run_parses_std_pattern(bench, tmp_path): + """``_STD_PATTERN`` matches the real ``GPU Compute Time`` line.""" + stdout = ( + "[01/15/2026-12:00:00] [I] === Performance summary ===\n" + "[I] GPU Compute Time: min = 0.8 ms, max = 1.2 ms, mean = 0.95 ms, " + "median = 0.92 ms, percentile(99%) = 1.18 ms\n" + ) + with patch("subprocess.run", return_value=_make_proc(stdout=stdout)): + assert bench.run(str(tmp_path / "m.onnx")) == pytest.approx(0.92) + + +# --- run() — remote scp + ssh trtexec_safe pipeline --- + + +@pytest.fixture +def remote_bench(tmp_path, trtexec_version_ok): + """A ``TrtExecBenchmark`` configured for remote autotuning. + + Requires ``trtexec_version_ok`` so ``_check_for_trtexec`` is patched during + ``TrtExecBenchmark.__init__``. 
+ """ + del trtexec_version_ok # consumed via pytest fixture injection + return TrtExecBenchmark( + timing_cache_file=str(tmp_path / "cache.bin"), + trtexec_args=[f"--remoteAutoTuningConfig={_REMOTE_URL}"], + ) + + +def test_remote_run_scp_then_ssh_trtexec_safe(remote_bench, tmp_path): + """The remote path runs trtexec → scp → ssh trtexec_safe, parsing _SAFE_PATTERN.""" + trtexec_proc = _make_proc(stdout="") # build only; --skipInference + scp_proc = _make_proc() + safe_stdout = "[01/15/2026-12:00:00] [I] Average over 10 runs - GPU latency: 3.42 ms\n" + ssh_proc = _make_proc(stdout=safe_stdout) + + with patch("subprocess.run", side_effect=[trtexec_proc, scp_proc, ssh_proc]) as run_mock: + latency = remote_bench.run(str(tmp_path / "m.onnx")) + + assert latency == pytest.approx(3.42) + assert run_mock.call_count == 3 + trtexec_cmd, scp_cmd, ssh_cmd = (c.args[0] for c in run_mock.call_args_list) + assert trtexec_cmd[0] == "trtexec" + # The remote URL in this test carries a password, so scp/ssh are prefixed with sshpass. + assert "scp" in scp_cmd + assert "alice@10.0.0.5:" in scp_cmd[-1] + assert "ssh" in ssh_cmd + assert "alice@10.0.0.5" in ssh_cmd + # The remote command string runs trtexec_safe with the engine path. 
+ remote_cmd_str = ssh_cmd[-1] + assert "trtexec_safe" in remote_cmd_str + assert "--loadEngine=" in remote_cmd_str + + +def test_remote_run_uses_sshpass_when_password_set(remote_bench, tmp_path): + """When the URL carries a password, both scp and ssh are prefixed with ``sshpass``.""" + trtexec_proc = _make_proc(stdout="") + scp_proc = _make_proc() + ssh_proc = _make_proc(stdout="[I] Average over 5 runs - GPU latency: 2.0 ms") + + with patch("subprocess.run", side_effect=[trtexec_proc, scp_proc, ssh_proc]) as run_mock: + remote_bench.run(str(tmp_path / "m.onnx")) + + _, scp_cmd, ssh_cmd = (c.args[0] for c in run_mock.call_args_list) + assert scp_cmd[:3] == ["sshpass", "-p", "s3cret"] + assert ssh_cmd[:3] == ["sshpass", "-p", "s3cret"] + + +def test_remote_run_scp_failure_returns_inf(remote_bench, tmp_path): + """If scp fails, the pipeline short-circuits before ssh and returns ``inf``.""" + trtexec_proc = _make_proc(stdout="") + scp_proc = _make_proc(returncode=1, stderr="permission denied") + + with patch("subprocess.run", side_effect=[trtexec_proc, scp_proc]) as run_mock: + latency = remote_bench.run(str(tmp_path / "m.onnx")) + + assert latency == float("inf") + assert run_mock.call_count == 2 # no ssh call + + +def test_remote_run_falls_back_to_trtexec_safe_flag(remote_bench, tmp_path): + """If ``trtexec_safe`` errors, fall back to ``trtexec --safe`` and parse _STD_PATTERN.""" + trtexec_proc = _make_proc(stdout="") + scp_proc = _make_proc() + safe_bin_fail = _make_proc(returncode=127, stderr="trtexec_safe: not found") + fallback_stdout = "[I] GPU Compute Time: median = 5.55 ms" + fallback_proc = _make_proc(stdout=fallback_stdout) + + with patch( + "subprocess.run", + side_effect=[trtexec_proc, scp_proc, safe_bin_fail, fallback_proc], + ) as run_mock: + latency = remote_bench.run(str(tmp_path / "m.onnx")) + + assert latency == pytest.approx(5.55) + fallback_cmd = run_mock.call_args_list[-1].args[0] + remote_cmd_str = fallback_cmd[-1] + assert "trtexec --safe" in 
remote_cmd_str + assert "trtexec_safe" not in remote_cmd_str + + +def test_remote_run_both_safe_paths_fail_returns_inf(remote_bench, tmp_path): + """If both ``trtexec_safe`` and the ``trtexec --safe`` fallback fail, return ``inf``.""" + trtexec_proc = _make_proc(stdout="") + scp_proc = _make_proc() + safe_bin_fail = _make_proc(returncode=127, stderr="not found") + fallback_fail = _make_proc(returncode=1, stderr="also failed") + + with patch( + "subprocess.run", + side_effect=[trtexec_proc, scp_proc, safe_bin_fail, fallback_fail], + ): + assert remote_bench.run(str(tmp_path / "m.onnx")) == float("inf") + + +# --- pattern constants --- + + +def test_std_pattern_matches_gpu_compute_time_line(): + """The std pattern matches a typical ``[I] GPU Compute Time: … median = X ms`` line.""" + import re + + text = "[I] GPU Compute Time: min = 1 ms, max = 2 ms, median = 1.42 ms" + match = re.search(bm._STD_PATTERN, text, re.IGNORECASE) + assert match and match.group(1) == "1.42" + + +def test_safe_pattern_matches_average_over_runs_line(): + """The safe pattern matches the trtexec_safe ``Average over N runs - GPU latency`` line.""" + import re + + text = "[01/15/2026-12:00:00] [I] Average over 10 runs - GPU latency: 7.89 ms" + match = re.search(bm._SAFE_PATTERN, text, re.IGNORECASE) + assert match and match.group(1) == "7.89" From 7718b567b131696fc3d9949e2569eed245861c08 Mon Sep 17 00:00:00 2001 From: dmoodie Date: Thu, 14 May 2026 19:37:49 -0400 Subject: [PATCH 2/4] add timeout for just scp and remote latency profiling Signed-off-by: dmoodie --- .../onnx/quantization/autotune/benchmark.py | 30 ++++- .../autotune/test_trtexec_benchmark.py | 123 ++++++++++++++++++ 2 files changed, 149 insertions(+), 4 deletions(-) diff --git a/modelopt/onnx/quantization/autotune/benchmark.py b/modelopt/onnx/quantization/autotune/benchmark.py index 8e4922c2928..82d160be5e7 100644 --- a/modelopt/onnx/quantization/autotune/benchmark.py +++ b/modelopt/onnx/quantization/autotune/benchmark.py @@ -342,6 
+342,7 @@ def __init__( timing_runs: int = 10, plugin_libraries: list[str] | None = None, trtexec_args: list[str] | None = None, + network_timeout_seconds: float = 60 * 5, # 5 minutes ): """Initialize the trtexec benchmark. @@ -353,12 +354,16 @@ def __init__( trtexec_args: Additional command-line arguments to pass to trtexec. These are appended after the standard arguments. Example: ['--fp16', '--workspace=4096', '--verbose'] + network_timeout_seconds: Timeout for network operations in seconds. + Default is 5 minutes. This is the timeout for uploading an engine to the remote device + and running trtexec_safe. If the timeout is exceeded, the benchmark will fail. """ super().__init__(timing_cache_file, warmup_runs, timing_runs, plugin_libraries) self.trtexec_args = list(trtexec_args) if trtexec_args is not None else [] self.temp_dir = tempfile.mkdtemp(prefix="trtexec_benchmark_") self.engine_path = os.path.join(self.temp_dir, "engine.trt") self.temp_model_path = os.path.join(self.temp_dir, "temp_model.onnx") + self.network_timeout_seconds = network_timeout_seconds self.logger.debug(f"Created temporary engine directory: {self.temp_dir}") self.logger.debug(f"Temporary model path: {self.temp_model_path}") @@ -448,7 +453,9 @@ def run( cmd = [*self._base_cmd, f"--onnx={model_path}"] full_cmd = ["trtexec", *cmd] self.logger.debug(f"Running: {' '.join(full_cmd)}") - result = _run_trtexec(cmd) + # We do not specify a timeout for engine build since this could take a very long time + # trtexec has its own timeout wrt the remote timing server + result = _run_trtexec(cmd, timeout=None) self._write_log_file( log_file, "\n".join( @@ -486,7 +493,9 @@ def run( f"{self.remote_user}@{self.remote_ip}:{shlex.quote(self.remote_engine_path)}", ] scp_cmd = ssh_pass + scp_cmd - result = subprocess.run(scp_cmd, capture_output=True, text=True) # nosec B603 + result = subprocess.run( + scp_cmd, capture_output=True, text=True, timeout=self.network_timeout_seconds + ) # nosec B603 if 
result.returncode != 0: self.logger.error(f"Failed to push engine to remote device: {result.stderr}") return float("inf") @@ -501,7 +510,12 @@ def run( f"--loadEngine={shlex.quote(self.remote_engine_path)}", ] trtexec_safe_cmd = ssh_pass + trtexec_safe_cmd - result = subprocess.run(trtexec_safe_cmd, capture_output=True, text=True) # nosec B603 + result = subprocess.run( + trtexec_safe_cmd, + capture_output=True, + text=True, + timeout=self.network_timeout_seconds, + ) # nosec B603 latency_pattern = _SAFE_PATTERN if result.returncode != 0: # fallback and try trtexec with "--safe" in case this is a safety proxy target @@ -516,7 +530,12 @@ def run( ] trtexec_safe_cmd = ssh_pass + trtexec_safe_cmd - result = subprocess.run(trtexec_safe_cmd, capture_output=True, text=True) # nosec B603 + result = subprocess.run( + trtexec_safe_cmd, + capture_output=True, + text=True, + timeout=self.network_timeout_seconds, + ) # nosec B603 latency_pattern = _STD_PATTERN if result.returncode != 0: self.logger.error( @@ -538,6 +557,9 @@ def run( "ssh, scp, sshpass, trtexec" ) return float("inf") + except subprocess.TimeoutExpired as e: + self.logger.error(f"Benchmark timed out: {e}") + return float("inf") except Exception as e: self.logger.error(f"Benchmark failed: {e}") return float("inf") diff --git a/tests/unit/onnx/quantization/autotune/test_trtexec_benchmark.py b/tests/unit/onnx/quantization/autotune/test_trtexec_benchmark.py index af4f82e9704..1cde34b8fc4 100644 --- a/tests/unit/onnx/quantization/autotune/test_trtexec_benchmark.py +++ b/tests/unit/onnx/quantization/autotune/test_trtexec_benchmark.py @@ -707,6 +707,129 @@ def test_remote_run_both_safe_paths_fail_returns_inf(remote_bench, tmp_path): assert remote_bench.run(str(tmp_path / "m.onnx")) == float("inf") +# --- network_timeout_seconds --- + + +def test_network_timeout_default_is_five_minutes(tmp_path): + """Default network timeout is 5 minutes (300s).""" + b = TrtExecBenchmark(timing_cache_file=str(tmp_path / "cache.bin")) + 
assert b.network_timeout_seconds == 300 + + +def test_network_timeout_custom_value_stored(tmp_path): + """User-supplied ``network_timeout_seconds`` is stored on the instance.""" + b = TrtExecBenchmark( + timing_cache_file=str(tmp_path / "cache.bin"), + network_timeout_seconds=12.5, + ) + assert b.network_timeout_seconds == 12.5 + + +def test_local_trtexec_call_uses_no_timeout(bench, tmp_path): + """The local engine build path passes ``timeout=None`` (engine builds can be long).""" + proc = _make_proc(stdout="[I] GPU Compute Time: median = 1.0 ms") + with patch("subprocess.run", return_value=proc) as run_mock: + bench.run(str(tmp_path / "m.onnx")) + + # Exactly one subprocess call for the local pipeline; timeout must be None. + assert run_mock.call_count == 1 + assert run_mock.call_args.kwargs.get("timeout") is None + + +@pytest.mark.usefixtures("trtexec_version_ok") +def test_remote_pipeline_passes_timeout_to_scp_and_ssh(tmp_path): + """scp, ssh trtexec_safe, and the ssh fallback all receive ``network_timeout_seconds``.""" + timeout = 7.0 + b = TrtExecBenchmark( + timing_cache_file=str(tmp_path / "cache.bin"), + trtexec_args=[f"--remoteAutoTuningConfig={_REMOTE_URL}"], + network_timeout_seconds=timeout, + ) + + trtexec_proc = _make_proc(stdout="") + scp_proc = _make_proc() + safe_fail = _make_proc(returncode=1, stderr="trtexec_safe not found") + fallback_proc = _make_proc(stdout="[I] GPU Compute Time: median = 4.0 ms") + + with patch( + "subprocess.run", + side_effect=[trtexec_proc, scp_proc, safe_fail, fallback_proc], + ) as run_mock: + b.run(str(tmp_path / "m.onnx")) + + # Engine build (call 0) has no timeout; the three remote calls all use it. 
+ assert run_mock.call_args_list[0].kwargs.get("timeout") is None + for idx in (1, 2, 3): # scp, ssh trtexec_safe, ssh fallback + assert run_mock.call_args_list[idx].kwargs.get("timeout") == timeout, ( + f"call {idx} did not receive timeout={timeout}" + ) + + +@pytest.mark.usefixtures("trtexec_version_ok") +def test_scp_timeout_returns_inf_and_logs(tmp_path, caplog): + """A ``subprocess.TimeoutExpired`` during the scp step returns ``inf`` and is logged.""" + import subprocess + + b = TrtExecBenchmark( + timing_cache_file=str(tmp_path / "cache.bin"), + trtexec_args=[f"--remoteAutoTuningConfig={_REMOTE_URL}"], + network_timeout_seconds=1.0, + ) + trtexec_proc = _make_proc(stdout="") + timeout_exc = subprocess.TimeoutExpired(cmd=["scp"], timeout=1.0) + + with ( + caplog.at_level("ERROR", logger="modelopt.onnx"), + patch("subprocess.run", side_effect=[trtexec_proc, timeout_exc]), + ): + assert b.run(str(tmp_path / "m.onnx")) == float("inf") + + assert any("timed out" in r.getMessage().lower() for r in caplog.records) + + +@pytest.mark.usefixtures("trtexec_version_ok") +def test_ssh_trtexec_safe_timeout_returns_inf(tmp_path): + """A timeout on the ssh ``trtexec_safe`` call also returns ``inf``.""" + import subprocess + + b = TrtExecBenchmark( + timing_cache_file=str(tmp_path / "cache.bin"), + trtexec_args=[f"--remoteAutoTuningConfig={_REMOTE_URL}"], + network_timeout_seconds=1.0, + ) + trtexec_proc = _make_proc(stdout="") + scp_proc = _make_proc() + timeout_exc = subprocess.TimeoutExpired(cmd=["ssh"], timeout=1.0) + + with patch( + "subprocess.run", + side_effect=[trtexec_proc, scp_proc, timeout_exc], + ): + assert b.run(str(tmp_path / "m.onnx")) == float("inf") + + +@pytest.mark.usefixtures("trtexec_version_ok") +def test_ssh_fallback_timeout_returns_inf(tmp_path): + """A timeout on the ``trtexec --safe`` fallback ssh call returns ``inf``.""" + import subprocess + + b = TrtExecBenchmark( + timing_cache_file=str(tmp_path / "cache.bin"), + 
trtexec_args=[f"--remoteAutoTuningConfig={_REMOTE_URL}"], + network_timeout_seconds=1.0, + ) + trtexec_proc = _make_proc(stdout="") + scp_proc = _make_proc() + safe_fail = _make_proc(returncode=1, stderr="trtexec_safe failed") + timeout_exc = subprocess.TimeoutExpired(cmd=["ssh"], timeout=1.0) + + with patch( + "subprocess.run", + side_effect=[trtexec_proc, scp_proc, safe_fail, timeout_exc], + ): + assert b.run(str(tmp_path / "m.onnx")) == float("inf") + + # --- pattern constants --- From ca186058d34df8d896ac9fe869682f5afae6161e Mon Sep 17 00:00:00 2001 From: dmoodie Date: Sun, 17 May 2026 20:28:58 -0400 Subject: [PATCH 3/4] address PR concerns around argv-smuggling Signed-off-by: dmoodie --- .../onnx/quantization/autotune/benchmark.py | 18 +++++++- .../autotune/test_trtexec_benchmark.py | 45 +++++++++++++++++++ 2 files changed, 61 insertions(+), 2 deletions(-) diff --git a/modelopt/onnx/quantization/autotune/benchmark.py b/modelopt/onnx/quantization/autotune/benchmark.py index 82d160be5e7..88ada52a8ac 100644 --- a/modelopt/onnx/quantization/autotune/benchmark.py +++ b/modelopt/onnx/quantization/autotune/benchmark.py @@ -272,8 +272,9 @@ def _parse_remote_autotuning_url(url: str) -> _RemoteAutotuningConfig: Raises: ValueError: If the scheme is not ``ssh://``; if user or host are - missing; or if required query parameters are missing or - duplicated. Duplicate keys are rejected explicitly because + missing or start with ``-`` (argv-smuggling guard, see + CVE-2017-1000117); or if required query parameters are missing + or duplicated. Duplicate keys are rejected explicitly because silently collapsing them would produce empty remote paths downstream. 
""" @@ -284,6 +285,19 @@ def _parse_remote_autotuning_url(url: str) -> _RemoteAutotuningConfig: raise ValueError("Unable to parse remote user from --remoteAutoTuningConfig") if parsed.hostname is None: raise ValueError("Unable to parse remote IP from --remoteAutoTuningConfig") + # Reject argv-smuggling attempts: a username or host that starts with ``-`` + # would be reinterpreted as a flag by ssh/scp when we build + # ``f"{user}@{host}:..."`` (CVE-2017-1000117 class). ``urlparse`` itself + # does not filter these — e.g. ``ssh://-oProxyCommand=evil@host`` parses + # cleanly into ``username='-oProxyCommand=evil'``. + if parsed.username.startswith("-"): + raise ValueError( + "Remote user in --remoteAutoTuningConfig must not start with '-' (argv-smuggling guard)" + ) + if parsed.hostname.startswith("-"): + raise ValueError( + "Remote host in --remoteAutoTuningConfig must not start with '-' (argv-smuggling guard)" + ) parsed_query = parse_qs(parsed.query) duplicates = sorted(k for k, v in parsed_query.items() if len(v) > 1) diff --git a/tests/unit/onnx/quantization/autotune/test_trtexec_benchmark.py b/tests/unit/onnx/quantization/autotune/test_trtexec_benchmark.py index 1cde34b8fc4..063f30dc6a6 100644 --- a/tests/unit/onnx/quantization/autotune/test_trtexec_benchmark.py +++ b/tests/unit/onnx/quantization/autotune/test_trtexec_benchmark.py @@ -238,6 +238,51 @@ def test_parse_remote_autotuning_url_validation_errors(url, match): bm._parse_remote_autotuning_url(url) +# --- _parse_remote_autotuning_url argv-smuggling guards (CVE-2017-1000117 class) --- + + +@pytest.mark.parametrize( + "evil_user", + [ + "-oProxyCommand=evil", # specific known argv-smuggling payload + "-l", # single dash + letter (short flag for ssh) + "--debug", # long flag form + ], +) +def test_parse_remote_autotuning_url_rejects_user_starting_with_dash(evil_user): + """A username beginning with ``-`` would be reinterpreted as an ssh/scp flag. 
+ + Without this guard, ``ssh://-oProxyCommand=evil@host/...`` would expand to + ``scp -oProxyCommand=evil@host:...`` and execute the attacker's command. + """ + url = ( + f"ssh://{evil_user}@10.0.0.5?" + "remote_exec_path=/opt/trt/bin/trtexec&remote_lib_path=/opt/trt/lib" + ) + with pytest.raises(ValueError, match="Remote user.*must not start with '-'"): + bm._parse_remote_autotuning_url(url) + + +def test_parse_remote_autotuning_url_rejects_host_starting_with_dash(): + """A hostname beginning with ``-`` is the same argv-smuggling vector via the host position.""" + # urlparse lowercases hostnames, so capitalization doesn't matter here. + url = ( + "ssh://alice@-oproxycommand?" + "remote_exec_path=/opt/trt/bin/trtexec&remote_lib_path=/opt/trt/lib" + ) + with pytest.raises(ValueError, match="Remote host.*must not start with '-'"): + bm._parse_remote_autotuning_url(url) + + +def test_parse_remote_autotuning_url_accepts_normal_user_and_host(): + """Regression guard: usernames and hosts not starting with ``-`` are still accepted.""" + cfg = bm._parse_remote_autotuning_url( + "ssh://alice@10.0.0.5?remote_exec_path=/opt/trt/bin/trtexec&remote_lib_path=/opt/trt/lib" + ) + assert cfg.user == "alice" + assert cfg.ip == "10.0.0.5" + + # --- _ensure_remote_autotuning_flags --- From f7d3c024a7263df61378289fcc317d70e2f2bb7a Mon Sep 17 00:00:00 2001 From: dmoodie Date: Sun, 17 May 2026 20:45:32 -0400 Subject: [PATCH 4/4] log stderr on failure Signed-off-by: dmoodie --- modelopt/onnx/quantization/autotune/benchmark.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/modelopt/onnx/quantization/autotune/benchmark.py b/modelopt/onnx/quantization/autotune/benchmark.py index 88ada52a8ac..14ac8e933ad 100644 --- a/modelopt/onnx/quantization/autotune/benchmark.py +++ b/modelopt/onnx/quantization/autotune/benchmark.py @@ -553,7 +553,7 @@ def run( latency_pattern = _STD_PATTERN if result.returncode != 0: self.logger.error( - f"Failed to run trtexec_safe or trtexec with 
'--safe'\n {result.stdout}" + f"Failed to run trtexec_safe or trtexec with '--safe'\n{result.stdout}\n{result.stderr}" ) return float("inf") if not (match := re.search(latency_pattern, result.stdout, re.IGNORECASE)):