Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions docsrc/getting_started/tensorrt_rtx.rst
Original file line number Diff line number Diff line change
Expand Up @@ -165,15 +165,15 @@ Once downloaded:

.. code-block:: sh

# If TensorRT-RTX is downloaded in C:\your_local_download_path\TensorRT-RTX-1.4.0.76
set PATH="%PATH%;C:\your_local_download_path\TensorRT-RTX-1.4.0.76\lib"
# If TensorRT-RTX is downloaded in C:\your_local_download_path\TensorRT-RTX-1.5.0.114
set PATH="%PATH%;C:\your_local_download_path\TensorRT-RTX-1.5.0.114\lib"
echo %PATH% | findstr TensorRT-RTX

Install TensorRT-RTX Wheel
~~~~~~~~~~~~~~~~~~~~~~~~~~

`tensorrt_rtx` wheel is published on PyPI.
During `torch_tensorrt_rtx` wheel installation,
During `torch_tensorrt_rtx` wheel installation,
it will automatically install the `tensorrt_rtx` wheel.

Build Torch-TensorRT with TensorRT-RTX
Expand Down
48 changes: 15 additions & 33 deletions py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

import numpy as np
import torch
from tensorrt import ITensor as TRTTensor
from torch.fx.node import Argument, Node, Target
from torch_tensorrt import ENABLED_FEATURES
from torch_tensorrt._features import needs_not_tensorrt_rtx
Expand All @@ -27,8 +28,6 @@
)
from torch_tensorrt.dynamo.utils import DYNAMIC_DIM

from tensorrt import ITensor as TRTTensor

_LOGGER: logging.Logger = logging.getLogger(__name__)


Expand Down Expand Up @@ -2764,46 +2763,29 @@ def aten_ops_le(
def convolution_capability_validator(
node: Node, settings: Optional[CompilationSettings] = None
) -> bool:
"""Reject unsupported convolution variants on TensorRT-RTX.

Falls back to PyTorch for:
1. Depthwise convolutions in BF16 (no kernel support on TRT-RTX).
2. Grouped 3D deconvolutions (crash on TRT-RTX).
"""Reject transposed convolutions (deconvolutions) that combine
stride > 1 with dilation > 1 on TensorRT-RTX — there is no kernel
for this case and the build fails with "Strided & Dilated Deconv
are currently not supported". Applies to 1D / 2D / 3D ConvTranspose;
regular convolutions are unaffected.
"""
if not ENABLED_FEATURES.tensorrt_rtx:
return True

if (input_meta := getattr(node.args[0], "meta", {}).get("tensor_meta")) is None:
return True

groups = args_bounds_check(node.args, 8)
is_grouped = groups is not None and groups > 1
is_transposed = bool(args_bounds_check(node.args, 6))
is_3d = input_meta.shape is not None and len(input_meta.shape) == 5
is_bf16 = input_meta.dtype == torch.bfloat16

# WAR: Grouped 3D deconvolutions crash on TRT-RTX (any dtype).
if is_transposed and is_grouped and is_3d:
if (
args_bounds_check(node.args, 6) # transposed?
and (stride := args_bounds_check(node.args, 3))
and (dilation := args_bounds_check(node.args, 5))
and any(s > 1 for s in stride)
and any(d > 1 for d in dilation)
):
_LOGGER.debug(
"Grouped 3D deconvolution '%s' (groups=%d) is not supported on "
"TensorRT-RTX. Falling back to PyTorch for this layer.",
"Strided + dilated deconvolution '%s' is not supported on "
"TensorRT-RTX. Falling back to PyTorch.",
node.name,
groups,
)
return False

# WAR: Depthwise convolutions in BF16 are not supported on TRT-RTX.
if is_bf16 and is_grouped:
if (
weight_meta := getattr(node.args[1], "meta", {}).get("tensor_meta")
) is not None and groups == weight_meta.shape[0]:
_LOGGER.debug(
"Depthwise convolution '%s' with BF16 is not supported on "
"TensorRT-RTX. Falling back to PyTorch for this layer.",
node.name,
)
return False

return True


Expand Down
30 changes: 28 additions & 2 deletions tests/py/dynamo/conversion/test_deconvolution_aten.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,27 @@
from .harness import DispatchTestCase


def _any_gt_1(v):
return any(x > 1 for x in v) if isinstance(v, (tuple, list)) else v > 1


def _skip_if_rtx_strided_dilated_deconv(testcase, stride, dilation):
"""Mirror convolution_capability_validator: skip on TRT-RTX when a
deconv combines stride > 1 with dilation > 1.

The converter test harness drives TRTInterpreter directly and skips
the partitioner, so a validator-rejected op raises
UnsupportedOperatorException here instead of falling back to PyTorch
as it would in torch_tensorrt.compile.
"""
if (
torch_tensorrt.ENABLED_FEATURES.tensorrt_rtx
and _any_gt_1(stride)
and _any_gt_1(dilation)
):
testcase.skipTest("Strided + dilated deconv falls back to PyTorch on TRT-RTX")


class TestDeconvolutionConverter(DispatchTestCase):
@parameterized.expand(
[
Expand All @@ -34,6 +55,7 @@ class TestDeconvolutionConverter(DispatchTestCase):
groups=3,
output_padding=1,
),
param("strided_dilated", 3, stride=2, padding=2, dilation=2),
]
)
def test_deconv1d(
Expand All @@ -47,6 +69,8 @@ def test_deconv1d(
bias=True,
output_padding=0,
):
_skip_if_rtx_strided_dilated_deconv(self, stride, dilation)

class TestModule(torch.nn.Module):
def __init__(self):
super().__init__()
Expand Down Expand Up @@ -139,6 +163,7 @@ def forward(self, x):
groups=3,
output_padding=1,
),
param("strided_dilated", 3, stride=2, padding=2, dilation=2),
]
)
def test_deconv2d(
Expand All @@ -152,6 +177,8 @@ def test_deconv2d(
bias=True,
output_padding=0,
):
_skip_if_rtx_strided_dilated_deconv(self, stride, dilation)

class TestModule(torch.nn.Module):
def __init__(self):
super().__init__()
Expand Down Expand Up @@ -238,8 +265,7 @@ def test_deconv3d(
bias=True,
output_padding=0,
):
if groups > 1 and torch_tensorrt.ENABLED_FEATURES.tensorrt_rtx:
self.skipTest("Grouped 3D deconvolutions fall back to PyTorch on TRT-RTX")
_skip_if_rtx_strided_dilated_deconv(self, stride, dilation)

class TestModule(torch.nn.Module):
def __init__(self):
Expand Down
4 changes: 0 additions & 4 deletions tests/py/dynamo/models/test_engine_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -640,10 +640,6 @@ def forward(self, c, d):
@unittest.skipIf(
not importlib.util.find_spec("torchvision"), "torchvision not installed"
)
@unittest.skipIf(
torch_trt.ENABLED_FEATURES.tensorrt_rtx,
"Engine caching compilation time assertion is unreliable with TensorRT-RTX",
)
def test_caching_small_model(self):
from torch_tensorrt.dynamo._refit import refit_module_weights

Expand Down
50 changes: 0 additions & 50 deletions tests/py/dynamo/models/test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -600,53 +600,3 @@ def forward(self, x):

# Clean up model env
torch._dynamo.reset()


@pytest.mark.unit
@unittest.skipIf(
not torchtrt.ENABLED_FEATURES.tensorrt_rtx,
"Grouped 3D deconv fallback WAR is TensorRT-RTX specific",
)
def test_grouped_deconv3d_fallback(ir):
"""Grouped 3D deconvolutions fall back to PyTorch on TRT-RTX.

The convolution_capability_validator rejects grouped ConvTranspose3d ops
so that the partitioner keeps them in PyTorch while other ops run on TRT.
"""

class MyModule(torch.nn.Module):
def __init__(self):
super().__init__()
self.conv = torch.nn.Conv3d(3, 16, 3, padding=1)
self.relu = torch.nn.ReLU()
self.deconv = torch.nn.ConvTranspose3d(16, 16, 3, padding=1, groups=16)

def forward(self, x):
out = self.conv(x)
out = self.relu(out)
out = self.deconv(out)
return out

model = MyModule().eval().cuda()
input = torch.randn((1, 3, 16, 16, 16), device="cuda")

compile_spec = {
"inputs": [torchtrt.Input(input.shape, dtype=torch.float32)],
"device": torchtrt.Device("cuda:0"),
"ir": ir,
"pass_through_build_failures": True,
"min_block_size": 1,
"cache_built_engines": False,
"reuse_cached_engines": False,
}

trt_mod = torchtrt.compile(model, **compile_spec)
cos_sim = cosine_similarity(model(input), trt_mod(input))

assertions.assertTrue(
cos_sim > COSINE_THRESHOLD,
msg=f"Grouped 3D deconv fallback model TRT outputs don't match with the original model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}",
)

# Clean up model env
torch._dynamo.reset()
11 changes: 4 additions & 7 deletions tests/py/dynamo/models/test_weight_stripped_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -268,14 +268,12 @@ def test_engine_caching_saves_weight_stripped_engine(self):
not importlib.util.find_spec("torchvision"),
"torchvision is not installed",
)
@unittest.skipIf(
torch_trt.ENABLED_FEATURES.tensorrt_rtx,
"Engine caching compilation time assertion is unreliable with TensorRT-RTX",
)
def test_dynamo_compile_with_refittable_weight_stripped_engine(self):
pyt_model = models.resnet18(pretrained=True).eval().to("cuda")
example_inputs = (torch.randn((100, 3, 224, 224)).to("cuda"),)
exp_program = torch.export.export(pyt_model, args=example_inputs)
# Use the same inputs for both export and compile to avoid a
# static-shape mismatch between the exported program and the engine.
inputs = [torch.randn((100, 3, 224, 224)).to("cuda")]
exp_program = torch.export.export(pyt_model, args=tuple(inputs))
Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@zewenli98 : Without this change, both TRT standard and TRT-RTX fails this test. I am not too sure whether and with what cadence it is running on the CI currently (my understanding is that this is a L2 test as its not marked with pytest.mark.critical)

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good catch! I found the error was caught since the PR #4222, but not sure why it was not caught in the previous CI.


engine_cache_dir = (
"/tmp/test_dynamo_compile_with_refittable_weight_stripped_engine"
Expand All @@ -291,7 +289,6 @@ def remove_timing_cache(path=TIMING_CACHE_PATH):
# The 2nd and 3rd iterations are to measure the compilation time with engine caching.
# Since the 2nd iteration needs to compile and save the engine, it will be slower than the 1st iteration.
# The 3rd iteration should be faster than the 1st iteration because it loads the cached engine.
inputs = [torch.rand((128, 3, 224, 224)).to("cuda")]
results = []
times = []
start = torch.cuda.Event(enable_timing=True)
Expand Down
Loading