Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
146 changes: 133 additions & 13 deletions modelopt/onnx/quantization/autotune/benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,13 +30,15 @@
import importlib.util
import os
import re
import shlex
import shutil
import subprocess # nosec B404
import tempfile
import time
from abc import ABC, abstractmethod
from pathlib import Path
from typing import Any
from urllib.parse import parse_qs, urlparse

import numpy as np
import torch
Expand Down Expand Up @@ -145,6 +147,13 @@ def _write_log_file(self, file: Path | str | None, content: str) -> None:
self.logger.warning(f"Failed to save logs to {file}: {e}")


# Regex patterns for extracting a GPU latency value (in ms) from trtexec stdout.
# Both patterns expose exactly one capture group: the latency number, parsed
# later with float(match.group(1)).
#
# `safe_pattern` matches the timestamped log line emitted by the safety runtime
# binary (trtexec_safe), e.g.:
#   "[01/02/2024-12:34:56] [I] Average over 10 runs - GPU latency: 1.23 ms"
safe_pattern = (
    r"\[\d{2}/\d{2}/\d{4}-\d{2}:\d{2}:\d{2}\]\s+\[I\]\s+"
    r"Average over \d+ runs - GPU latency:\s*([\d.]+)\s*ms"
)
# `std_pattern` matches the standard trtexec performance summary line, e.g.:
#   "[I] GPU Compute Time: min = ..., median = 1.23 ms, ..."
# and captures the median. NOTE(review): `.*?` does not cross newlines (no
# re.DOTALL), so this assumes the whole summary is on one line — which it is
# in trtexec output.
std_pattern = r"\[I\]\s+GPU Compute Time:.*?median\s*=\s*([\d.]+)\s*ms"

class TrtExecBenchmark(Benchmark):
"""TensorRT benchmark using trtexec command-line tool.

Expand Down Expand Up @@ -183,7 +192,6 @@ def __init__(
self.temp_model_path = os.path.join(self.temp_dir, "temp_model.onnx")
self.logger.debug(f"Created temporary engine directory: {self.temp_dir}")
self.logger.debug(f"Temporary model path: {self.temp_model_path}")
self.latency_pattern = r"\[I\]\s+Latency:.*?median\s*=\s*([\d.]+)\s*ms"

self._base_cmd = [
self.trtexec_path,
Expand All @@ -204,9 +212,68 @@ def __init__(
self.logger.debug(f"Added plugin library: {plugin_path}")

trtexec_args = self.trtexec_args or []
has_remote_config = any("--remoteAutoTuningConfig" in arg for arg in trtexec_args)

if has_remote_config:
self.has_remote_config = any("--remoteAutoTuningConfig" in arg for arg in trtexec_args)
self.remote_ip: str | None = None
self.remote_port: int = 22
self.remote_user: str = "root"
self.remote_password: str = ""
self.remote_engine_path: str = "trtexec_benchmark_model.trt"
self.remote_bin_path: str = "trtexec"

if self.has_remote_config:
remote_config = [arg for arg in trtexec_args if "--remoteAutoTuningConfig" in arg]
if len(remote_config) != 1:
raise ValueError("Exactly one --remoteAutoTuningConfig argument is required")
# Parse --remoteAutoTuningConfig argument, which may be given as:
# ('--remoteAutoTuningConfig=ssh://user:pass@host:port?...') or
# ('--remoteAutoTuningConfig', 'ssh://user:pass@host:port?...')
#
# The logic: find the arg starting with '--remoteAutoTuningConfig'
# If formatted as '--remoteAutoTuningConfig=...', split off the '='
# Otherwise, grab the next argument.
config_arg_value: str | None = None
for i, arg in enumerate(trtexec_args):
if arg.startswith("--remoteAutoTuningConfig"):
if arg == "--remoteAutoTuningConfig":
# Value should be the next argument
if i + 1 < len(trtexec_args):
config_arg_value = trtexec_args[i + 1]
else:
raise ValueError("Missing value for --remoteAutoTuningConfig")
elif arg.startswith("--remoteAutoTuningConfig="):
config_arg_value = arg.split("=", 1)[1]
else:
raise ValueError(f"Malformed --remoteAutoTuningConfig argument: {arg}")
break
if not config_arg_value:
raise ValueError("Could not parse --remoteAutoTuningConfig argument")
remote_config_str: str = config_arg_value

if not remote_config_str.startswith("ssh://"):
raise ValueError("Only 'ssh://' remote autotuning config URLs are supported")
parsed = urlparse(remote_config_str)
self.remote_user = parsed.username
self.remote_password = parsed.password
self.remote_ip = parsed.hostname
self.remote_port = parsed.port
if self.remote_user is None:
raise ValueError("Unable to parse remote user from --remoteAutoTuningConfig")
if self.remote_ip is None:
raise ValueError("Unable to parse remote IP from --remoteAutoTuningConfig")
if self.remote_port is None:
self.remote_port = 22
Comment thread
coderabbitai[bot] marked this conversation as resolved.
# Parse query options into a dict
self.remote_options = {
k: v[0] if len(v) == 1 else v for k, v in parse_qs(parsed.query).items()
}
required_params = ["remote_exec_path", "remote_lib_path"]
missing = [p for p in required_params if p not in self.remote_options]
if missing:
raise ValueError(
f"Missing required query parameters in --remoteAutoTuningConfig: {missing}"
)
self.remote_bin_path = os.path.dirname(str(self.remote_options["remote_exec_path"]))
Comment thread
coderabbitai[bot] marked this conversation as resolved.
self.remote_lib_path = str(self.remote_options["remote_lib_path"])
try:
_check_for_trtexec(min_version="10.15")
self.logger.debug("TensorRT Python API version >= 10.15 detected")
Expand All @@ -215,19 +282,19 @@ def __init__(
"Remote autotuning requires '--safe' to be set. Adding it to trtexec arguments."
)
self.trtexec_args.append("--safe")
self.is_safe = True
if "--skipInference" not in trtexec_args:
self.logger.warning(
"Remote autotuning requires '--skipInference' to be set. Adding it to trtexec arguments."
)
self.trtexec_args.append("--skipInference")
except ImportError:
except ImportError as e:
self.logger.warning(
"Remote autotuning is not supported with TensorRT version < 10.15. "
"Removing --remoteAutoTuningConfig from trtexec arguments"
"Remote autotuning is not supported with TensorRT version < 10.15."
)
trtexec_args = [
arg for arg in trtexec_args if "--remoteAutoTuningConfig" not in arg
]
raise e
Comment thread
coderabbitai[bot] marked this conversation as resolved.

self.is_safe = "--safe" in trtexec_args
self._base_cmd.extend(trtexec_args)

self.logger.debug(f"Base command template: {' '.join(self._base_cmd)}")
Expand Down Expand Up @@ -292,10 +359,63 @@ def run(
self.logger.error(f"trtexec failed with return code {result.returncode}")
self.logger.error(f"stderr: {result.stderr}")
return float("inf")
latency_pattern = std_pattern
if self.has_remote_config and self.is_safe:
ssh_pass = []
if self.remote_password:
ssh_pass.append("sshpass")
ssh_pass.append("-p")
ssh_pass.append(self.remote_password)
# need to push the model to the device and use trtexec_safe to run
scp_cmd = [
"scp",
f"-P{self.remote_port}",
self.engine_path,
f"{self.remote_user}@{self.remote_ip}:{shlex.quote(self.remote_engine_path)}",
]
scp_cmd = ssh_pass + scp_cmd
result = subprocess.run(scp_cmd, capture_output=True, text=True) # nosec B603
if result.returncode != 0:
self.logger.error(f"Failed to push engine to remote device: {result.stderr}")
return float("inf")
ld_path = f"LD_LIBRARY_PATH={shlex.quote(self.remote_lib_path)}:$LD_LIBRARY_PATH"
trt_path = f"{os.path.join(self.remote_bin_path, 'trtexec_safe')}"
trtexec_safe_cmd = [
"ssh",
"-p",
f"{self.remote_port}",
f"{self.remote_user}@{self.remote_ip}",
f"{ld_path} {shlex.quote(trt_path)} --useCudaGraphs "
f"--loadEngine={shlex.quote(self.remote_engine_path)}",
]
trtexec_safe_cmd = ssh_pass + trtexec_safe_cmd
result = subprocess.run(trtexec_safe_cmd, capture_output=True, text=True) # nosec B603
latency_pattern = safe_pattern
if result.returncode != 0:
# fallback and try trtexec with "--safe" in case this is a safety proxy target
trt_path = f"{os.path.join(self.remote_bin_path, 'trtexec')}"
trtexec_safe_cmd = [
"ssh",
"-p",
f"{self.remote_port}",
f"{self.remote_user}@{self.remote_ip}",
f"{ld_path} {shlex.quote(trt_path)} --safe --useCudaGraphs "
f"--loadEngine={shlex.quote(self.remote_engine_path)}",
]
trtexec_safe_cmd = ssh_pass + trtexec_safe_cmd

if not (match := re.search(self.latency_pattern, result.stdout, re.IGNORECASE)):
self.logger.warning("Could not parse median latency from trtexec output")
self.logger.debug(f"trtexec stdout:\n{result.stdout}")
result = subprocess.run(trtexec_safe_cmd, capture_output=True, text=True) # nosec B603
latency_pattern = std_pattern
if result.returncode != 0:
self.logger.error(
f"Failed to run trtexec_safe or trtexec with '--safe'\n {result.stdout}"
)
return float("inf")
if not (match := re.search(latency_pattern, result.stdout, re.IGNORECASE)):
# this could be due to creating a degenerate onnx file that can't be engine built.
# thus not a hard failure
self.logger.warning(f"trtexec stdout:\n{result.stdout}")
self.logger.error("Could not parse median latency from trtexec output")
return float("inf")
latency = float(match.group(1))
self.logger.info(f"TrtExec benchmark (median): {latency:.2f} ms")
Expand Down