From ae0753e5042daeb27731b7844d2d86932e725415 Mon Sep 17 00:00:00 2001 From: tejaswinp Date: Thu, 28 May 2026 14:46:20 -0700 Subject: [PATCH] chore(rtx): remove TRT-RTX 1.4-era WARs, narrow validator to strided+dilated deconv TensorRT-RTX 1.5 (PR #4297) resolves the upstream cuDNN and JIT issues that the original convolution capability validator and test skips were guarding against. The remaining TRT-RTX limitation in this area is 1D/2D/3D transposed convolutions that combine stride > 1 with dilation > 1, which have no kernel support and crash the build with "Strided & Dilated Deconv are currently not supported". Regular convolutions are unaffected. Changes: 1. py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py - Drop the old WARs in convolution_capability_validator: a. Depthwise conv/deconv BF16 fallback to PyTorch. b. Grouped 3D deconv fallback to PyTorch (any dtype). Both ops now run on TRT directly. - Keep the validator with a single, narrower rule: any transposed convolution (1D/2D/3D) with both stride > 1 and dilation > 1 still falls back to PyTorch. 2. tests/py/dynamo/conversion/test_deconvolution_aten.py - Drop the previous in-test guard for grouped 3D deconv. - Add a shared `_skip_if_rtx_strided_dilated_deconv` helper that mirrors the validator predicate and document why the converter test harness needs it (it bypasses the partitioner, so a validator-rejected op raises UnsupportedOperatorException rather than falling back to PyTorch). - Wire the helper into test_deconv1d/2d/3d. - Add explicit `strided_dilated` parametrize entries to test_deconv1d and test_deconv2d (test_deconv3d's existing combined_params already covers the case). All three skip cleanly on TRT-RTX. 3. tests/py/dynamo/models/test_models.py - Delete test_grouped_deconv3d_fallback; the asserted fallback behavior no longer exists. 4. tests/py/dynamo/models/test_engine_cache.py - Remove the unittest.skipIf(tensorrt_rtx, "Engine caching compilation time assertion is unreliable...") decorator on test_caching_small_model. Refit-engine perf is now reliable on TRT-RTX 1.5. 5. tests/py/dynamo/models/test_weight_stripped_engine.py - Drop the old TRT-RTX timing-based skip on test_dynamo_compile_with_refittable_weight_stripped_engine and fix the underlying test bug it was masking: example_inputs to torch.export.export (batch 100) and arg_inputs to torch_trt.dynamo.compile (batch 128) disagreed, so the engine was built for the export shape and runtime failed when fed the compile-inputs shape. Reuse a single `inputs` list at both call sites so the shapes can't drift. Verified passing on both standard TRT and TRT-RTX nightlies. 6. docsrc/getting_started/tensorrt_rtx.rst - Bump the Windows install-path example from TensorRT-RTX-1.4.0.76 to TensorRT-RTX-1.5.0.114; the Linux example was updated in the 1.5 bump but the Windows block was missed. --- docsrc/getting_started/tensorrt_rtx.rst | 6 +-- .../dynamo/conversion/aten_ops_converters.py | 48 ++++++------------ .../conversion/test_deconvolution_aten.py | 30 ++++++++++- tests/py/dynamo/models/test_engine_cache.py | 4 -- tests/py/dynamo/models/test_models.py | 50 ------------------- .../models/test_weight_stripped_engine.py | 11 ++-- 6 files changed, 50 insertions(+), 99 deletions(-) diff --git a/docsrc/getting_started/tensorrt_rtx.rst b/docsrc/getting_started/tensorrt_rtx.rst index f46f744595..526b09c6d1 100644 --- a/docsrc/getting_started/tensorrt_rtx.rst +++ b/docsrc/getting_started/tensorrt_rtx.rst @@ -165,15 +165,15 @@ Once downloaded: .. code-block:: sh - # If TensorRT-RTX is downloaded in C:\your_local_download_path\TensorRT-RTX-1.4.0.76 - set PATH="%PATH%;C:\your_local_download_path\TensorRT-RTX-1.4.0.76\lib" + # If TensorRT-RTX is downloaded in C:\your_local_download_path\TensorRT-RTX-1.5.0.114 + set PATH="%PATH%;C:\your_local_download_path\TensorRT-RTX-1.5.0.114\lib" echo %PATH% | findstr TensorRT-RTX Install TensorRT-RTX Wheel ~~~~~~~~~~~~~~~~~~~~~~~~~~ `tensorrt_rtx` wheel is published on PyPI. -During `torch_tensorrt_rtx` wheel installation, +During `torch_tensorrt_rtx` wheel installation, it will automatically install the `tensorrt_rtx` wheel. Build Torch-TensorRT with TensorRT-RTX diff --git a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py index be3482ef84..112b04c187 100644 --- a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py +++ b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py @@ -6,6 +6,7 @@ import numpy as np import torch +from tensorrt import ITensor as TRTTensor from torch.fx.node import Argument, Node, Target from torch_tensorrt import ENABLED_FEATURES from torch_tensorrt._features import needs_not_tensorrt_rtx @@ -27,8 +28,6 @@ ) from torch_tensorrt.dynamo.utils import DYNAMIC_DIM -from tensorrt import ITensor as TRTTensor - _LOGGER: logging.Logger = logging.getLogger(__name__) @@ -2764,46 +2763,29 @@ def aten_ops_le( def convolution_capability_validator( node: Node, settings: Optional[CompilationSettings] = None ) -> bool: - """Reject unsupported convolution variants on TensorRT-RTX. - - Falls back to PyTorch for: - 1. Depthwise convolutions in BF16 (no kernel support on TRT-RTX). - 2. Grouped 3D deconvolutions (crash on TRT-RTX). + """Reject transposed convolutions (deconvolutions) that combine + stride > 1 with dilation > 1 on TensorRT-RTX — there is no kernel + for this case and the build fails with "Strided & Dilated Deconv + are currently not supported". Applies to 1D / 2D / 3D ConvTranspose; + regular convolutions are unaffected. """ if not ENABLED_FEATURES.tensorrt_rtx: return True - if (input_meta := getattr(node.args[0], "meta", {}).get("tensor_meta")) is None: - return True - - groups = args_bounds_check(node.args, 8) - is_grouped = groups is not None and groups > 1 - is_transposed = bool(args_bounds_check(node.args, 6)) - is_3d = input_meta.shape is not None and len(input_meta.shape) == 5 - is_bf16 = input_meta.dtype == torch.bfloat16 - - # WAR: Grouped 3D deconvolutions crash on TRT-RTX (any dtype). - if is_transposed and is_grouped and is_3d: + if ( + args_bounds_check(node.args, 6) # transposed? + and (stride := args_bounds_check(node.args, 3)) + and (dilation := args_bounds_check(node.args, 5)) + and any(s > 1 for s in stride) + and any(d > 1 for d in dilation) + ): _LOGGER.debug( - "Grouped 3D deconvolution '%s' (groups=%d) is not supported on " - "TensorRT-RTX. Falling back to PyTorch for this layer.", + "Strided + dilated deconvolution '%s' is not supported on " + "TensorRT-RTX. Falling back to PyTorch.", node.name, - groups, ) return False - # WAR: Depthwise convolutions in BF16 are not supported on TRT-RTX. - if is_bf16 and is_grouped: - if ( - weight_meta := getattr(node.args[1], "meta", {}).get("tensor_meta") - ) is not None and groups == weight_meta.shape[0]: - _LOGGER.debug( - "Depthwise convolution '%s' with BF16 is not supported on " - "TensorRT-RTX. Falling back to PyTorch for this layer.", - node.name, - ) - return False - return True diff --git a/tests/py/dynamo/conversion/test_deconvolution_aten.py b/tests/py/dynamo/conversion/test_deconvolution_aten.py index 188a538632..8a6d5d1893 100644 --- a/tests/py/dynamo/conversion/test_deconvolution_aten.py +++ b/tests/py/dynamo/conversion/test_deconvolution_aten.py @@ -9,6 +9,27 @@ from .harness import DispatchTestCase +def _any_gt_1(v): + return any(x > 1 for x in v) if isinstance(v, (tuple, list)) else v > 1 + + +def _skip_if_rtx_strided_dilated_deconv(testcase, stride, dilation): + """Mirror convolution_capability_validator: skip on TRT-RTX when a + deconv combines stride > 1 with dilation > 1. + + The converter test harness drives TRTInterpreter directly and skips + the partitioner, so a validator-rejected op raises + UnsupportedOperatorException here instead of falling back to PyTorch + as it would in torch_tensorrt.compile. + """ + if ( + torch_tensorrt.ENABLED_FEATURES.tensorrt_rtx + and _any_gt_1(stride) + and _any_gt_1(dilation) + ): + testcase.skipTest("Strided + dilated deconv falls back to PyTorch on TRT-RTX") + + class TestDeconvolutionConverter(DispatchTestCase): @parameterized.expand( [ @@ -34,6 +55,7 @@ class TestDeconvolutionConverter(DispatchTestCase): groups=3, output_padding=1, ), + param("strided_dilated", 3, stride=2, padding=2, dilation=2), ] ) def test_deconv1d( @@ -47,6 +69,8 @@ def test_deconv1d( bias=True, output_padding=0, ): + _skip_if_rtx_strided_dilated_deconv(self, stride, dilation) + class TestModule(torch.nn.Module): def __init__(self): super().__init__() @@ -139,6 +163,7 @@ def forward(self, x): groups=3, output_padding=1, ), + param("strided_dilated", 3, stride=2, padding=2, dilation=2), ] ) def test_deconv2d( @@ -152,6 +177,8 @@ def test_deconv2d( bias=True, output_padding=0, ): + _skip_if_rtx_strided_dilated_deconv(self, stride, dilation) + class TestModule(torch.nn.Module): def __init__(self): super().__init__() @@ -238,8 +265,7 @@ def test_deconv3d( bias=True, output_padding=0, ): - if groups > 1 and torch_tensorrt.ENABLED_FEATURES.tensorrt_rtx: - self.skipTest("Grouped 3D deconvolutions fall back to PyTorch on TRT-RTX") + _skip_if_rtx_strided_dilated_deconv(self, stride, dilation) class TestModule(torch.nn.Module): def __init__(self): diff --git a/tests/py/dynamo/models/test_engine_cache.py b/tests/py/dynamo/models/test_engine_cache.py index 55b92ef73f..7850bb274e 100644 --- a/tests/py/dynamo/models/test_engine_cache.py +++ b/tests/py/dynamo/models/test_engine_cache.py @@ -640,10 +640,6 @@ def forward(self, c, d): @unittest.skipIf( not importlib.util.find_spec("torchvision"), "torchvision not installed" ) - @unittest.skipIf( - torch_trt.ENABLED_FEATURES.tensorrt_rtx, - "Engine caching compilation time assertion is unreliable with TensorRT-RTX", - ) def test_caching_small_model(self): from torch_tensorrt.dynamo._refit import refit_module_weights diff --git a/tests/py/dynamo/models/test_models.py b/tests/py/dynamo/models/test_models.py index 351b5326df..ac374a439d 100644 --- a/tests/py/dynamo/models/test_models.py +++ b/tests/py/dynamo/models/test_models.py @@ -600,53 +600,3 @@ def forward(self, x): # Clean up model env torch._dynamo.reset() - - -@pytest.mark.unit -@unittest.skipIf( - not torchtrt.ENABLED_FEATURES.tensorrt_rtx, - "Grouped 3D deconv fallback WAR is TensorRT-RTX specific", -) -def test_grouped_deconv3d_fallback(ir): - """Grouped 3D deconvolutions fall back to PyTorch on TRT-RTX. - - The convolution_capability_validator rejects grouped ConvTranspose3d ops - so that the partitioner keeps them in PyTorch while other ops run on TRT. - """ - - class MyModule(torch.nn.Module): - def __init__(self): - super().__init__() - self.conv = torch.nn.Conv3d(3, 16, 3, padding=1) - self.relu = torch.nn.ReLU() - self.deconv = torch.nn.ConvTranspose3d(16, 16, 3, padding=1, groups=16) - - def forward(self, x): - out = self.conv(x) - out = self.relu(out) - out = self.deconv(out) - return out - - model = MyModule().eval().cuda() - input = torch.randn((1, 3, 16, 16, 16), device="cuda") - - compile_spec = { - "inputs": [torchtrt.Input(input.shape, dtype=torch.float32)], - "device": torchtrt.Device("cuda:0"), - "ir": ir, - "pass_through_build_failures": True, - "min_block_size": 1, - "cache_built_engines": False, - "reuse_cached_engines": False, - } - - trt_mod = torchtrt.compile(model, **compile_spec) - cos_sim = cosine_similarity(model(input), trt_mod(input)) - - assertions.assertTrue( - cos_sim > COSINE_THRESHOLD, - msg=f"Grouped 3D deconv fallback model TRT outputs don't match with the original model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}", - ) - - # Clean up model env - torch._dynamo.reset() diff --git a/tests/py/dynamo/models/test_weight_stripped_engine.py b/tests/py/dynamo/models/test_weight_stripped_engine.py index d835ac39db..beae4c4ec0 100644 --- a/tests/py/dynamo/models/test_weight_stripped_engine.py +++ b/tests/py/dynamo/models/test_weight_stripped_engine.py @@ -268,14 +268,12 @@ def test_engine_caching_saves_weight_stripped_engine(self): not importlib.util.find_spec("torchvision"), "torchvision is not installed", ) - @unittest.skipIf( - torch_trt.ENABLED_FEATURES.tensorrt_rtx, - "Engine caching compilation time assertion is unreliable with TensorRT-RTX", - ) def test_dynamo_compile_with_refittable_weight_stripped_engine(self): pyt_model = models.resnet18(pretrained=True).eval().to("cuda") - example_inputs = (torch.randn((100, 3, 224, 224)).to("cuda"),) - exp_program = torch.export.export(pyt_model, args=example_inputs) + # Use the same inputs for both export and compile to avoid a + # static-shape mismatch between the exported program and the engine. + inputs = [torch.randn((100, 3, 224, 224)).to("cuda")] + exp_program = torch.export.export(pyt_model, args=tuple(inputs)) engine_cache_dir = ( "/tmp/test_dynamo_compile_with_refittable_weight_stripped_engine" @@ -291,7 +289,6 @@ def remove_timing_cache(path=TIMING_CACHE_PATH): # The 2nd and 3rd iterations are to measure the compilation time with engine caching. # Since the 2nd iteration needs to compile and save the engine, it will be slower than the 1st iteration. # The 3rd iteration should be faster than the 1st iteration because it loads the cached engine. - inputs = [torch.rand((128, 3, 224, 224)).to("cuda")] results = [] times = [] start = torch.cuda.Event(enable_timing=True)