diff --git a/modelopt/onnx/quantization/autotune/benchmark.py b/modelopt/onnx/quantization/autotune/benchmark.py
index df6dbc877d..db3b09606a 100644
--- a/modelopt/onnx/quantization/autotune/benchmark.py
+++ b/modelopt/onnx/quantization/autotune/benchmark.py
@@ -30,6 +30,7 @@
 import importlib.util
 import os
 import re
+import shlex
 import shutil
 import subprocess  # nosec B404
 import tempfile
@@ -37,6 +38,7 @@
 from abc import ABC, abstractmethod
 from pathlib import Path
 from typing import Any
+from urllib.parse import parse_qs, unquote, urlparse
 
 import numpy as np
 import torch
@@ -145,6 +147,13 @@ def _write_log_file(self, file: Path | str | None, content: str) -> None:
             self.logger.warning(f"Failed to save logs to {file}: {e}")
 
 
+safe_pattern = (
+    r"\[\d{2}/\d{2}/\d{4}-\d{2}:\d{2}:\d{2}\]\s+\[I\]\s+"
+    r"Average over \d+ runs - GPU latency:\s*([\d.]+)\s*ms"
+)
+std_pattern = r"\[I\]\s+GPU Compute Time:.*?median\s*=\s*([\d.]+)\s*ms"
+
+
 class TrtExecBenchmark(Benchmark):
     """TensorRT benchmark using trtexec command-line tool.
 
@@ -183,7 +192,6 @@ def __init__(
         self.temp_model_path = os.path.join(self.temp_dir, "temp_model.onnx")
         self.logger.debug(f"Created temporary engine directory: {self.temp_dir}")
         self.logger.debug(f"Temporary model path: {self.temp_model_path}")
-        self.latency_pattern = r"\[I\]\s+Latency:.*?median\s*=\s*([\d.]+)\s*ms"
 
         self._base_cmd = [
             self.trtexec_path,
@@ -204,9 +212,68 @@ def __init__(
                 self.logger.debug(f"Added plugin library: {plugin_path}")
 
         trtexec_args = self.trtexec_args or []
-        has_remote_config = any("--remoteAutoTuningConfig" in arg for arg in trtexec_args)
-
-        if has_remote_config:
+        self.has_remote_config = any("--remoteAutoTuningConfig" in arg for arg in trtexec_args)
+        self.remote_ip: str | None = None
+        self.remote_port: int = 22
+        self.remote_user: str = "root"
+        self.remote_password: str = ""
+        self.remote_engine_path: str = "trtexec_benchmark_model.trt"
+        self.remote_bin_path: str = "trtexec"
+
+        if self.has_remote_config:
+            remote_config = [arg for arg in trtexec_args if "--remoteAutoTuningConfig" in arg]
+            if len(remote_config) != 1:
+                raise ValueError("Exactly one --remoteAutoTuningConfig argument is required")
+            # Parse --remoteAutoTuningConfig argument, which may be given as:
+            # ('--remoteAutoTuningConfig=ssh://user:pass@host:port?...') or
+            # ('--remoteAutoTuningConfig', 'ssh://user:pass@host:port?...')
+            #
+            # The logic: find the arg starting with '--remoteAutoTuningConfig'
+            # If formatted as '--remoteAutoTuningConfig=...', split off the '='
+            # Otherwise, grab the next argument.
+            config_arg_value: str | None = None
+            for i, arg in enumerate(trtexec_args):
+                if arg.startswith("--remoteAutoTuningConfig"):
+                    if arg == "--remoteAutoTuningConfig":
+                        # Value should be the next argument
+                        if i + 1 < len(trtexec_args):
+                            config_arg_value = trtexec_args[i + 1]
+                        else:
+                            raise ValueError("Missing value for --remoteAutoTuningConfig")
+                    elif arg.startswith("--remoteAutoTuningConfig="):
+                        config_arg_value = arg.split("=", 1)[1]
+                    else:
+                        raise ValueError(f"Malformed --remoteAutoTuningConfig argument: {arg}")
+                    break
+            if not config_arg_value:
+                raise ValueError("Could not parse --remoteAutoTuningConfig argument")
+            remote_config_str: str = config_arg_value
+
+            if not remote_config_str.startswith("ssh://"):
+                raise ValueError("Only 'ssh://' remote autotuning config URLs are supported")
+            parsed = urlparse(remote_config_str)
+            if not parsed.username:
+                raise ValueError("Unable to parse remote user from --remoteAutoTuningConfig")
+            if not parsed.hostname:
+                raise ValueError("Unable to parse remote IP from --remoteAutoTuningConfig")
+            # urlparse() leaves the userinfo percent-encoded; decode it before it is handed
+            # to ssh/sshpass. A missing password stays "" so run() skips sshpass entirely.
+            self.remote_user = unquote(parsed.username)
+            self.remote_password = unquote(parsed.password or "")
+            self.remote_ip = parsed.hostname
+            self.remote_port = parsed.port or 22
+            # Parse query options into a dict
+            self.remote_options = {
+                k: v[0] if len(v) == 1 else v for k, v in parse_qs(parsed.query).items()
+            }
+            required_params = ["remote_exec_path", "remote_lib_path"]
+            missing = [p for p in required_params if p not in self.remote_options]
+            if missing:
+                raise ValueError(
+                    f"Missing required query parameters in --remoteAutoTuningConfig: {missing}"
+                )
+            self.remote_bin_path = os.path.dirname(str(self.remote_options["remote_exec_path"]))
+            self.remote_lib_path = str(self.remote_options["remote_lib_path"])
             try:
                 _check_for_trtexec(min_version="10.15")
                 self.logger.debug("TensorRT Python API version >= 10.15 detected")
@@ -215,19 +282,19 @@ def __init__(
                         "Remote autotuning requires '--safe' to be set. Adding it to trtexec arguments."
                     )
                     self.trtexec_args.append("--safe")
+                    self.is_safe = True
                 if "--skipInference" not in trtexec_args:
                     self.logger.warning(
                         "Remote autotuning requires '--skipInference' to be set. Adding it to trtexec arguments."
                     )
                     self.trtexec_args.append("--skipInference")
             except ImportError:
                 self.logger.warning(
-                    "Remote autotuning is not supported with TensorRT version < 10.15. "
-                    "Removing --remoteAutoTuningConfig from trtexec arguments"
+                    "Remote autotuning is not supported with TensorRT version < 10.15."
                 )
-                trtexec_args = [
-                    arg for arg in trtexec_args if "--remoteAutoTuningConfig" not in arg
-                ]
+                raise
+
+        self.is_safe = "--safe" in trtexec_args
         self._base_cmd.extend(trtexec_args)
 
         self.logger.debug(f"Base command template: {' '.join(self._base_cmd)}")
@@ -292,10 +359,69 @@ def run(
                 self.logger.error(f"trtexec failed with return code {result.returncode}")
                 self.logger.error(f"stderr: {result.stderr}")
                 return float("inf")
+            latency_pattern = std_pattern
+            if self.has_remote_config and self.is_safe:
+                # Pass the password through $SSHPASS ('sshpass -e'); '-p' would leak it via argv.
+                ssh_pass: list[str] = []
+                ssh_env = dict(os.environ)
+                if self.remote_password:
+                    ssh_pass = ["sshpass", "-e"]
+                    ssh_env["SSHPASS"] = self.remote_password
+                # need to push the model to the device and use trtexec_safe to run
+                scp_cmd = [
+                    "scp",
+                    f"-P{self.remote_port}",
+                    self.engine_path,
+                    f"{self.remote_user}@{self.remote_ip}:{shlex.quote(self.remote_engine_path)}",
+                ]
+                scp_cmd = ssh_pass + scp_cmd
+                result = subprocess.run(scp_cmd, capture_output=True, text=True, env=ssh_env)  # nosec B603
+                if result.returncode != 0:
+                    self.logger.error(f"Failed to push engine to remote device: {result.stderr}")
+                    return float("inf")
+                ld_path = f"LD_LIBRARY_PATH={shlex.quote(self.remote_lib_path)}:$LD_LIBRARY_PATH"
+                trt_path = os.path.join(self.remote_bin_path, "trtexec_safe")
+                trtexec_safe_cmd = [
+                    "ssh",
+                    "-p",
+                    f"{self.remote_port}",
+                    f"{self.remote_user}@{self.remote_ip}",
+                    f"{ld_path} {shlex.quote(trt_path)} --useCudaGraphs "
+                    f"--loadEngine={shlex.quote(self.remote_engine_path)}",
+                ]
+                trtexec_safe_cmd = ssh_pass + trtexec_safe_cmd
+                result = subprocess.run(  # nosec B603
+                    trtexec_safe_cmd, capture_output=True, text=True, env=ssh_env
+                )
+                latency_pattern = safe_pattern
+                if result.returncode != 0:
+                    # fallback and try trtexec with "--safe" in case this is a safety proxy target
+                    trt_path = os.path.join(self.remote_bin_path, "trtexec")
+                    trtexec_safe_cmd = [
+                        "ssh",
+                        "-p",
+                        f"{self.remote_port}",
+                        f"{self.remote_user}@{self.remote_ip}",
+                        f"{ld_path} {shlex.quote(trt_path)} --safe --useCudaGraphs "
+                        f"--loadEngine={shlex.quote(self.remote_engine_path)}",
+                    ]
+                    trtexec_safe_cmd = ssh_pass + trtexec_safe_cmd
 
-            if not (match := re.search(self.latency_pattern, result.stdout, re.IGNORECASE)):
-                self.logger.warning("Could not parse median latency from trtexec output")
-                self.logger.debug(f"trtexec stdout:\n{result.stdout}")
+                    result = subprocess.run(  # nosec B603
+                        trtexec_safe_cmd, capture_output=True, text=True, env=ssh_env
+                    )
+                    latency_pattern = std_pattern
+                    if result.returncode != 0:
+                        self.logger.error(
+                            f"Failed to run trtexec_safe or trtexec with '--safe'\n {result.stdout}"
+                            f"\n {result.stderr}"
+                        )
+                        return float("inf")
+            if not (match := re.search(latency_pattern, result.stdout, re.IGNORECASE)):
+                # this could be due to creating a degenerate onnx file that can't be engine built.
+                # thus not a hard failure
+                self.logger.warning(f"trtexec stdout:\n{result.stdout}")
+                self.logger.error("Could not parse median latency from trtexec output")
                 return float("inf")
             latency = float(match.group(1))
             self.logger.info(f"TrtExec benchmark (median): {latency:.2f} ms")