From 5db92e5caee060f45a9372644b24771cb8ae44e7 Mon Sep 17 00:00:00 2001 From: tejaswinp Date: Mon, 25 May 2026 18:31:27 -0700 Subject: [PATCH 1/6] fix(runtime): restore TRT-RTX runtime-cache and dynamic-shapes setup on TRTEngine The Python runtime rework moved the runtime-cache and dynamic-shape kernel specialization machinery off of PythonTorchTensorRTModule, but only the ``runtime_config = None`` stub was carried over to the new TRTEngine. Everything that actually populated those attributes -- the runtime config build-up, the file-locked cache load/save, the dynamic-shape strategy mapping -- was dropped. On non-RTX builds this manifested as ``AttributeError: 'TRTEngine' object has no attribute '_runtime_cache'`` in test_no_runtime_config_for_standard_trt; on RTX builds the runtime cache was effectively dead code. This restores the missing pieces on TRTEngine: - New module-level helper _get_dynamic_shapes_kernel_strategy mapping the setting string ('lazy' / 'eager' / 'none') to the TRT-RTX enum. - __init__ and __setstate__ unconditionally initialize runtime_config and runtime_cache to None so the destructor's save path is safe even on partially-loaded engines. runtime_cache_path is set in _load_serialized_info once compilation settings are decoded, so it is always available regardless of build flavor. - _setup_engine calls _setup_runtime_config on RTX builds and rebuilds the execution context so it picks up the new IRuntimeConfig. _create_execution_context routes through runtime_config when one is present, falling back to the strategy-based path otherwise. - Three new helpers ported from the old PythonTorchTensorRTModule: _setup_runtime_config (builds the IRuntimeConfig, sets allocation strategy + dynamic-shape strategy, creates the runtime cache, hydrates it from disk, binds it to the config), _load_runtime_cache (shared file-locked deserialize) and _save_runtime_cache (exclusive file-locked serialize, getattr-guarded so destructor paths on partially-constructed engines stay safe). - __del__ now calls _save_runtime_cache before reset_captured_graph, with an exception swallow so engine teardown never throws. Tests were also unified on the public attribute names: - test_000_runtime_cache: the four assertions that previously read the underscored ``engine._runtime_config`` / ``engine._runtime_cache`` now use the public names (matching what test_001 was already expecting). This is what makes test_no_runtime_config_for_standard_trt stop raising AttributeError on non-RTX CI. - test_001_dynamic_shapes_kernel_strategy: the non-RTX assertion was a defensive ``getattr(engine, '_runtime_config', None)`` that would have silently passed even with the attribute missing; switched to a direct ``engine.runtime_config`` read so it actually exercises the contract. Verified on an A100 RTX build: all 12 RTX runtime-cache tests and all 6 RTX dynamic-shape strategy tests pass; non-RTX gated tests are skipped on this build as expected. --- .../dynamo/runtime/_TRTEngine.py | 123 +++++++++++++++++- .../dynamo/runtime/test_000_runtime_cache.py | 12 +- ...test_001_dynamic_shapes_kernel_strategy.py | 2 +- 3 files changed, 128 insertions(+), 9 deletions(-) diff --git a/py/torch_tensorrt/dynamo/runtime/_TRTEngine.py b/py/torch_tensorrt/dynamo/runtime/_TRTEngine.py index c9c6f8a433..31b6e64e30 100644 --- a/py/torch_tensorrt/dynamo/runtime/_TRTEngine.py +++ b/py/torch_tensorrt/dynamo/runtime/_TRTEngine.py @@ -11,6 +11,7 @@ import base64 import copy import logging +import os import pickle import tempfile from contextlib import nullcontext @@ -56,6 +57,16 @@ logger = logging.getLogger(__name__) + +def _get_dynamic_shapes_kernel_strategy(strategy_str: str) -> Any: + """Map strategy string to TRT enum. Only meaningful on TensorRT-RTX builds.""" + return { + "lazy": trt.DynamicShapesKernelSpecializationStrategy.LAZY, + "eager": trt.DynamicShapesKernelSpecializationStrategy.EAGER, + "none": trt.DynamicShapesKernelSpecializationStrategy.NONE, + }.get(strategy_str, trt.DynamicShapesKernelSpecializationStrategy.LAZY) + + # --------------------------------------------------------------------------- # TRT I/O helpers # --------------------------------------------------------------------------- @@ -219,7 +230,14 @@ def __init__( torch_tensorrt.runtime.get_cudagraphs_mode() ) self.resource_allocation_strategy = 0 - self._runtime_config = None + # TensorRT-RTX runtime cache state. ``runtime_cache_path`` is filled in + # by ``_load_serialized_info`` once compilation settings are available; + # ``runtime_config`` and ``runtime_cache`` are populated by + # ``_setup_runtime_config`` on RTX builds. Initialized to ``None`` here + # so the destructor can safely save the cache even if ``_setup_engine`` + # never runs. + self.runtime_config: Any = None + self.runtime_cache: Any = None # NCCL communicator is bound lazily on the first forward pass for # engines compiled with native multi-device collective layers. self._nccl_comm: Optional[Any] = None @@ -228,6 +246,12 @@ def __init__( self._setup_engine() def __del__(self) -> None: + # Persist the TensorRT-RTX runtime cache before tearing the engine + # down; no-op when ``runtime_cache`` was never populated. + try: + self._save_runtime_cache() + except Exception: + pass self.reset_captured_graph() def __deepcopy__(self, memo: dict[int, Any]) -> "TRTEngine": @@ -279,7 +303,10 @@ def __setstate__(self, state: Any) -> None: torch_tensorrt.runtime.get_cudagraphs_mode() ) self.resource_allocation_strategy = 0 - self._runtime_config = None + # See ``__init__`` for the rationale: pre-init these so a destructor + # firing on a partially-loaded engine never trips an ``AttributeError``. + self.runtime_config = None + self.runtime_cache = None # NCCL communicators cannot be pickled; rebind lazily on the next # forward pass via setup_nccl_comm(). self._nccl_comm = None @@ -344,6 +371,9 @@ def _load_serialized_info( metadata = self.decode_metadata(self.serialized_metadata) self.settings = metadata.get("settings", CompilationSettings()) + # Path used by ``_load_runtime_cache`` / ``_save_runtime_cache`` on + # TensorRT-RTX. Always set so non-RTX engines also expose it. + self.runtime_cache_path = self.settings.runtime_cache_path self.weight_name_map = metadata.get("weight_name_map") self.symbolic_shape_expressions = metadata.get("inout_symexprs") self.output_tensors_are_unowned = metadata.get( @@ -375,6 +405,18 @@ def close(self) -> None: self.reset_captured_graph() def _create_execution_context(self) -> trt.IExecutionContext: + # On TensorRT-RTX builds the allocation strategy lives on the + # ``IRuntimeConfig`` (set by ``_setup_runtime_config``), so once the + # runtime config is built we route context creation through it. The + # first call from ``_setup_engine`` precedes ``_setup_runtime_config`` + # and falls through to the strategy-based path below. + if ( + ENABLED_FEATURES.tensorrt_rtx + and getattr(self, "runtime_config", None) is not None + ): + context = self.cuda_engine.create_execution_context(self.runtime_config) + assert context is not None, "Failed to create execution context" + return context strategy = trt.ExecutionContextAllocationStrategy.STATIC if self.resource_allocation_strategy: strategy = trt.ExecutionContextAllocationStrategy.USER_MANAGED @@ -417,6 +459,14 @@ def _setup_engine(self) -> None: ) dist.barrier() + # On TensorRT-RTX, build the IRuntimeConfig (with runtime cache and + # dynamic-shape kernel specialization strategy) and rebuild the + # execution context so it picks them up. The NCCL barrier above runs + # against the initial strategy-based context. + if ENABLED_FEATURES.tensorrt_rtx: + self._setup_runtime_config() + self.context = self._create_execution_context() + if not self.in_binding_names and not self.out_binding_names: input_names: List[str] = [] output_names: List[str] = [] @@ -453,6 +503,75 @@ def _setup_engine(self) -> None: if self.requires_output_allocator: self.create_output_allocator() + # --- TensorRT-RTX runtime cache / dynamic shapes strategy --- + + def _setup_runtime_config(self) -> None: + """Build an ``IRuntimeConfig`` with runtime cache and dynamic-shape strategy. + + The runtime cache stores JIT compilation results so kernel/graph + compilation is not repeated across inference runs. The dynamic-shape + kernel specialization strategy controls how the JIT compiles + shape-specialized kernels (``lazy``, ``eager``, or ``none``). + """ + self.runtime_config = self.cuda_engine.create_runtime_config() + alloc_strategy = trt.ExecutionContextAllocationStrategy.STATIC + if self.resource_allocation_strategy: + alloc_strategy = trt.ExecutionContextAllocationStrategy.USER_MANAGED + self.runtime_config.set_execution_context_allocation_strategy(alloc_strategy) + self.runtime_config.dynamic_shapes_kernel_specialization_strategy = ( + _get_dynamic_shapes_kernel_strategy( + self.settings.dynamic_shapes_kernel_specialization_strategy + ) + ) + logger.info( + "Dynamic shapes kernel specialization strategy: " + f"{self.settings.dynamic_shapes_kernel_specialization_strategy}" + ) + self.runtime_cache = self.runtime_config.create_runtime_cache() + self._load_runtime_cache() + self.runtime_config.set_runtime_cache(self.runtime_cache) + logger.info("TensorRT-RTX runtime cache configured") + + def _load_runtime_cache(self) -> None: + """Load runtime cache from disk if it exists (with a shared file lock).""" + if self.runtime_cache is None: + return + if not os.path.isfile(self.runtime_cache_path): + logger.debug(f"No existing runtime cache at {self.runtime_cache_path}") + return + try: + from filelock import FileLock + + lock = FileLock(self.runtime_cache_path + ".lock") + with lock.acquire(timeout=10): + with open(self.runtime_cache_path, "rb") as f: + data = f.read() + if data: + self.runtime_cache.deserialize(data) + logger.info(f"Loaded runtime cache from {self.runtime_cache_path}") + except Exception as e: + logger.warning(f"Failed to load runtime cache: {e}") + + def _save_runtime_cache(self) -> None: + """Save runtime cache to disk (with an exclusive file lock).""" + if getattr(self, "runtime_cache", None) is None: + return + try: + host_mem = self.runtime_cache.serialize() + if host_mem is None: + return + os.makedirs(os.path.dirname(self.runtime_cache_path), exist_ok=True) + + from filelock import FileLock + + lock = FileLock(self.runtime_cache_path + ".lock") + with lock.acquire(timeout=10): + with open(self.runtime_cache_path, "wb") as f: + f.write(memoryview(host_mem)) + logger.info(f"Saved runtime cache to {self.runtime_cache_path}") + except Exception as e: + logger.warning(f"Failed to save runtime cache: {e}") + # --- distributed / NCCL --- @property diff --git a/tests/py/dynamo/runtime/test_000_runtime_cache.py b/tests/py/dynamo/runtime/test_000_runtime_cache.py index 14700cbd14..013fd2a92a 100644 --- a/tests/py/dynamo/runtime/test_000_runtime_cache.py +++ b/tests/py/dynamo/runtime/test_000_runtime_cache.py @@ -73,10 +73,10 @@ def test_runtime_config_created(self): engine = _find_python_trt_engine(compiled) self.assertIsNotNone(engine, "No Python TRTEngine found in compiled model") self.assertIsNotNone( - engine._runtime_config, "runtime_config should be set for RTX" + engine.runtime_config, "runtime_config should be set for RTX" ) self.assertIsNotNone( - engine._runtime_cache, "runtime_cache should be set for RTX" + engine.runtime_cache, "runtime_cache should be set for RTX" ) def test_context_created_successfully(self): @@ -268,15 +268,15 @@ def test_no_runtime_config_for_standard_trt(self): compiled, _ = _compile_simple() engine = _find_python_trt_engine(compiled) if engine is not None: - # The TRT-RTX runtime cache machinery is exposed via the private - # ``_runtime_config``/``runtime_cache`` attributes on the Python + # The TRT-RTX runtime cache machinery is exposed via the + # ``runtime_config`` / ``runtime_cache`` attributes on the Python # engine. On non-RTX builds neither should be populated. self.assertIsNone( - engine._runtime_config, + engine.runtime_config, "runtime_config should be None for standard TRT", ) self.assertIsNone( - engine._runtime_cache, + engine.runtime_cache, "runtime_cache should be None for standard TRT", ) diff --git a/tests/py/dynamo/runtime/test_001_dynamic_shapes_kernel_strategy.py b/tests/py/dynamo/runtime/test_001_dynamic_shapes_kernel_strategy.py index 359d6bbc9d..11998f9b7a 100644 --- a/tests/py/dynamo/runtime/test_001_dynamic_shapes_kernel_strategy.py +++ b/tests/py/dynamo/runtime/test_001_dynamic_shapes_kernel_strategy.py @@ -139,7 +139,7 @@ def test_setting_ignored_on_non_rtx(self): engine = _find_python_trt_engine(compiled) if engine is not None: self.assertIsNone( - getattr(engine, "_runtime_config", None), + engine.runtime_config, "runtime_config should be None for standard TRT", ) # Inference should still work From f691f167578a4c3455dafaa1cd2ec5ed813092b6 Mon Sep 17 00:00:00 2001 From: tejaswinp Date: Mon, 25 May 2026 20:48:48 -0700 Subject: [PATCH 2/6] feat(runtime): add TRT-RTX native CUDA graph support to TRTEngine TensorRT-RTX has native CUDA graph support via IRuntimeConfig.cuda_graph_strategy, where the JIT compiler handles capture/replay/invalidation internally. This is strictly safer than manual torch.cuda.CUDAGraph capture on RTX because: - lazy-compiled specialized kernels can replace a captured fallback path on the fly, which would invalidate a manually captured graph - runtime allocation / data-dependent shapes can cause cudaStreamBeginCapture to fail outright - the JIT compiler tracks graph staleness (shape changes, pointer changes, kernel readiness) for us Settings + entry points (additive, no behavior change for non-RTX): - _defaults: new CUDA_GRAPH_STRATEGY = "disabled" - _settings.CompilationSettings: new cuda_graph_strategy field (str) - _compiler.compile / cross_compile_for_windows / convert_exported_program_to_serialized_trt_engine: new cuda_graph_strategy kwarg threaded through TRTEngine wiring: - New module-level helper _get_cuda_graph_strategy mapping "disabled" / "whole_graph_capture" -> trt.CudaGraphStrategy. - __init__ / __setstate__ initialize self._rtx_native_cudagraphs = False so the forward path can always read it (including on partially-constructed engines). - _setup_runtime_config also sets self.runtime_config.cuda_graph_strategy from settings on RTX builds (with a paired log line). - _setup_engine latches self._rtx_native_cudagraphs = (RTX and cuda_graph_strategy != "disabled") right after _setup_runtime_config. - New _is_monolithic_capturable(stream): non-RTX returns True (preserves existing behavior); RTX returns False if the IExecutionContext is not stream-capturable or if dynamic-shape strategy is "lazy" (lazy-compiled specialized kernels would invalidate a captured graph). - New _enable_rtx_native_cudagraphs(): rewrites cuda_graph_strategy on the IRuntimeConfig to WHOLE_GRAPH_CAPTURE and rebuilds the execution context. - _execute_standard reads cudagraphs_enabled once at the top; on RTX + cudagraphs enabled + not yet RTX-native, transparently switches to RTX-native (with a warning that tells the user how to set it at compile time). Computes effective_cudagraphs = cudagraphs_enabled and not _rtx_native_cudagraphs and uses it everywhere downstream so the manual torch.cuda.CUDAGraph capture path is bypassed when TRT-RTX owns capture. - Debug log appends " (RTX native)" when _rtx_native_cudagraphs is set. CudaGraphsTorchTensorRTModule wiring (whole-graph mode with mixed TRT + PyTorch): - New _check_monolithic_capturability(stream) iterates the compiled subgraph looking for TorchTensorRTModule whose .engine is a TRTEngine. For each, it calls engine._is_monolithic_capturable and raises RuntimeError if any fails. If an engine has RTX-native cudagraphs on, this turns them off (sets the IRuntimeConfig back to DISABLED and rebuilds the context) so the inner RTX capture cannot interfere with the outer torch.cuda.CUDAGraph capture. - The check fires from forward() just before need_cudagraphs_record allocates the outer torch.cuda.CUDAGraph. Tests: - runtime/test_001_cuda_graph_strategy.py (new): 17 cases covering setup, RTX-native override under SUBGRAPH cudagraphs, _is_monolithic_capturable for each dynamic-shape strategy, context-recreation on _enable_rtx_native_cudagraphs, cudagraphs mode toggle, and a non-RTX gated case. Mirrors the test-helper convention from test_001_dynamic_shapes_kernel_strategy.py (_find_python_trt_engine returns the engine, not the wrapping module). - models/test_cuda_graph_strategy_models.py (new): end-to-end resnet18 and dynamic-batch ConvModel tests for both "disabled" and "whole_graph_capture" strategies. Verified on an A100 RTX build: - test_001_cuda_graph_strategy: 17 passed, 1 skipped (non-RTX gated) - test_000_runtime_cache: 12 passed, 2 skipped (no regression vs. commit 1) - test_001_dynamic_shapes_kernel_strategy: 6 passed, 1 skipped (no regression) --- py/torch_tensorrt/dynamo/_compiler.py | 9 + py/torch_tensorrt/dynamo/_defaults.py | 1 + py/torch_tensorrt/dynamo/_settings.py | 3 + .../runtime/_CudaGraphsTorchTensorRTModule.py | 54 +++ .../dynamo/runtime/_TRTEngine.py | 105 +++++- .../models/test_cuda_graph_strategy_models.py | 186 +++++++++ .../runtime/test_001_cuda_graph_strategy.py | 357 ++++++++++++++++++ 7 files changed, 705 insertions(+), 10 deletions(-) create mode 100644 tests/py/dynamo/models/test_cuda_graph_strategy_models.py create mode 100644 tests/py/dynamo/runtime/test_001_cuda_graph_strategy.py diff --git a/py/torch_tensorrt/dynamo/_compiler.py b/py/torch_tensorrt/dynamo/_compiler.py index 29c2ed076a..8671dd5860 100644 --- a/py/torch_tensorrt/dynamo/_compiler.py +++ b/py/torch_tensorrt/dynamo/_compiler.py @@ -91,6 +91,7 @@ def cross_compile_for_windows( timing_cache_path: str = _defaults.TIMING_CACHE_PATH, runtime_cache_path: str = _defaults.RUNTIME_CACHE_PATH, dynamic_shapes_kernel_specialization_strategy: str = _defaults.DYNAMIC_SHAPES_KERNEL_SPECIALIZATION_STRATEGY, + cuda_graph_strategy: str = _defaults.CUDA_GRAPH_STRATEGY, lazy_engine_init: bool = _defaults.LAZY_ENGINE_INIT, cache_built_engines: bool = _defaults.CACHE_BUILT_ENGINES, reuse_cached_engines: bool = _defaults.REUSE_CACHED_ENGINES, @@ -171,6 +172,7 @@ def cross_compile_for_windows( timing_cache_path (str): Path to the timing cache if it exists (or) where it will be saved after compilation. Not used for TensorRT-RTX. runtime_cache_path (str): Path to the runtime cache for TensorRT-RTX JIT compilation results. Not used for standard TensorRT. dynamic_shapes_kernel_specialization_strategy (str): Strategy for dynamic shape kernel specialization at runtime (TensorRT-RTX only). Options: "lazy", "eager", "none". Default: "lazy". + cuda_graph_strategy (str): Strategy for CUDA graph capture/replay (TensorRT-RTX only). Options: "disabled" (manual capture), "whole_graph_capture" (TRT-RTX handles internally). Default: "disabled". lazy_engine_init (bool): Defer setting up engines until the compilation of all engines is complete. Can allow larger models with multiple graph breaks to compile but can lead to oversubscription of GPU memory at runtime. cache_built_engines (bool): Whether to save the compiled TRT engines to storage reuse_cached_engines (bool): Whether to load the compiled TRT engines from storage @@ -318,6 +320,7 @@ def cross_compile_for_windows( "timing_cache_path": timing_cache_path, "runtime_cache_path": runtime_cache_path, "dynamic_shapes_kernel_specialization_strategy": dynamic_shapes_kernel_specialization_strategy, + "cuda_graph_strategy": cuda_graph_strategy, "lazy_engine_init": lazy_engine_init, "cache_built_engines": cache_built_engines, "reuse_cached_engines": reuse_cached_engines, @@ -432,6 +435,7 @@ def compile( timing_cache_path: str = _defaults.TIMING_CACHE_PATH, runtime_cache_path: str = _defaults.RUNTIME_CACHE_PATH, dynamic_shapes_kernel_specialization_strategy: str = _defaults.DYNAMIC_SHAPES_KERNEL_SPECIALIZATION_STRATEGY, + cuda_graph_strategy: str = _defaults.CUDA_GRAPH_STRATEGY, lazy_engine_init: bool = _defaults.LAZY_ENGINE_INIT, cache_built_engines: bool = _defaults.CACHE_BUILT_ENGINES, reuse_cached_engines: bool = _defaults.REUSE_CACHED_ENGINES, @@ -527,6 +531,7 @@ def compile( timing_cache_path (str): Path to the timing cache if it exists (or) where it will be saved after compilation. Not used for TensorRT-RTX. runtime_cache_path (str): Path to the runtime cache for TensorRT-RTX JIT compilation results. Not used for standard TensorRT. dynamic_shapes_kernel_specialization_strategy (str): Strategy for dynamic shape kernel specialization at runtime (TensorRT-RTX only). Options: "lazy", "eager", "none". Default: "lazy". + cuda_graph_strategy (str): Strategy for CUDA graph capture/replay (TensorRT-RTX only). Options: "disabled" (manual capture), "whole_graph_capture" (TRT-RTX handles internally). Default: "disabled". lazy_engine_init (bool): Defer setting up engines until the compilation of all engines is complete. Can allow larger models with multiple graph breaks to compile but can lead to oversubscription of GPU memory at runtime. cache_built_engines (bool): Whether to save the compiled TRT engines to storage reuse_cached_engines (bool): Whether to load the compiled TRT engines from storage @@ -706,6 +711,7 @@ def compile( "timing_cache_path": timing_cache_path, "runtime_cache_path": runtime_cache_path, "dynamic_shapes_kernel_specialization_strategy": dynamic_shapes_kernel_specialization_strategy, + "cuda_graph_strategy": cuda_graph_strategy, "lazy_engine_init": lazy_engine_init, "cache_built_engines": cache_built_engines, "reuse_cached_engines": reuse_cached_engines, @@ -1226,6 +1232,7 @@ def convert_exported_program_to_serialized_trt_engine( timing_cache_path: str = _defaults.TIMING_CACHE_PATH, runtime_cache_path: str = _defaults.RUNTIME_CACHE_PATH, dynamic_shapes_kernel_specialization_strategy: str = _defaults.DYNAMIC_SHAPES_KERNEL_SPECIALIZATION_STRATEGY, + cuda_graph_strategy: str = _defaults.CUDA_GRAPH_STRATEGY, lazy_engine_init: bool = _defaults.LAZY_ENGINE_INIT, cache_built_engines: bool = _defaults.CACHE_BUILT_ENGINES, reuse_cached_engines: bool = _defaults.REUSE_CACHED_ENGINES, @@ -1302,6 +1309,7 @@ def convert_exported_program_to_serialized_trt_engine( timing_cache_path (str): Path to the timing cache if it exists (or) where it will be saved after compilation. Not used for TensorRT-RTX. runtime_cache_path (str): Path to the runtime cache for TensorRT-RTX JIT compilation results. Not used for standard TensorRT. dynamic_shapes_kernel_specialization_strategy (str): Strategy for dynamic shape kernel specialization at runtime (TensorRT-RTX only). Options: "lazy", "eager", "none". Default: "lazy". + cuda_graph_strategy (str): Strategy for CUDA graph capture/replay (TensorRT-RTX only). Options: "disabled" (manual capture), "whole_graph_capture" (TRT-RTX handles internally). Default: "disabled". lazy_engine_init (bool): Defer setting up engines until the compilation of all engines is complete. Can allow larger models with multiple graph breaks to compile but can lead to oversubscription of GPU memory at runtime. cache_built_engines (bool): Whether to save the compiled TRT engines to storage reuse_cached_engines (bool): Whether to load the compiled TRT engines from storage @@ -1458,6 +1466,7 @@ def convert_exported_program_to_serialized_trt_engine( "timing_cache_path": timing_cache_path, "runtime_cache_path": runtime_cache_path, "dynamic_shapes_kernel_specialization_strategy": dynamic_shapes_kernel_specialization_strategy, + "cuda_graph_strategy": cuda_graph_strategy, "lazy_engine_init": lazy_engine_init, "cache_built_engines": cache_built_engines, "reuse_cached_engines": reuse_cached_engines, diff --git a/py/torch_tensorrt/dynamo/_defaults.py b/py/torch_tensorrt/dynamo/_defaults.py index 784066cc75..00bc39bce5 100644 --- a/py/torch_tensorrt/dynamo/_defaults.py +++ b/py/torch_tensorrt/dynamo/_defaults.py @@ -70,6 +70,7 @@ DECOMPOSE_ATTENTION = False ATTN_BIAS_IS_CAUSAL = True DYNAMIC_SHAPES_KERNEL_SPECIALIZATION_STRATEGY = "lazy" +CUDA_GRAPH_STRATEGY = "disabled" USE_PYTHON_RUNTIME = False if platform.system() == "Linux": diff --git a/py/torch_tensorrt/dynamo/_settings.py b/py/torch_tensorrt/dynamo/_settings.py index 3fe18e0a0d..694b1a7000 100644 --- a/py/torch_tensorrt/dynamo/_settings.py +++ b/py/torch_tensorrt/dynamo/_settings.py @@ -17,6 +17,7 @@ AUTOCAST_MAX_OUTPUT_THRESHOLD, CACHE_BUILT_ENGINES, CPU_MEMORY_BUDGET, + CUDA_GRAPH_STRATEGY, DECOMPOSE_ATTENTION, DISABLE_TF32, DLA_GLOBAL_DRAM_SIZE, @@ -95,6 +96,7 @@ class CompilationSettings: timing_cache_path (str): Path to the timing cache if it exists (or) where it will be saved after compilation. Not used for TensorRT-RTX (no autotuning). runtime_cache_path (str): Path to the runtime cache for TensorRT-RTX JIT compilation results. The cache is loaded on engine setup and saved on module cleanup. Uses file locking for concurrent access safety. Not used for standard TensorRT. dynamic_shapes_kernel_specialization_strategy (str): Strategy for compiling shape-specialized kernels at runtime for dynamic shapes (TensorRT-RTX only). Options: "lazy" (compile in background, use fallback until ready), "eager" (compile immediately, blocking), "none" (always use fallback kernels). Default: "lazy". + cuda_graph_strategy (str): Strategy for CUDA graph capture/replay (TensorRT-RTX only). Options: "disabled" (no native CUDA graphs, uses manual capture if cudagraphs mode is enabled), "whole_graph_capture" (TRT-RTX handles CUDA graph capture internally). When set to "whole_graph_capture", the manual torch CUDA graph capture/replay in forward() is bypassed. Default: "disabled". cache_built_engines (bool): Whether to save the compiled TRT engines to storage reuse_cached_engines (bool): Whether to load the compiled TRT engines from storage use_fp32_acc (bool): This option inserts cast to FP32 nodes around matmul layers and TensorRT ensures the accumulation of matmul happens in FP32. @@ -150,6 +152,7 @@ class CompilationSettings: dynamic_shapes_kernel_specialization_strategy: str = ( DYNAMIC_SHAPES_KERNEL_SPECIALIZATION_STRATEGY ) + cuda_graph_strategy: str = CUDA_GRAPH_STRATEGY lazy_engine_init: bool = LAZY_ENGINE_INIT cache_built_engines: bool = CACHE_BUILT_ENGINES reuse_cached_engines: bool = REUSE_CACHED_ENGINES diff --git a/py/torch_tensorrt/dynamo/runtime/_CudaGraphsTorchTensorRTModule.py b/py/torch_tensorrt/dynamo/runtime/_CudaGraphsTorchTensorRTModule.py index 3fae323704..537444a0b9 100644 --- a/py/torch_tensorrt/dynamo/runtime/_CudaGraphsTorchTensorRTModule.py +++ b/py/torch_tensorrt/dynamo/runtime/_CudaGraphsTorchTensorRTModule.py @@ -127,6 +127,59 @@ def __del__(self) -> None: def set_use_output_allocator(self, enable: bool) -> None: self.use_output_allocator_outputs = enable + def _check_monolithic_capturability(self, stream: torch.cuda.Stream) -> None: + """Verify every TRT submodule is safe for monolithic stream capture. + + Whole-graph CUDA graph mode wraps mixed TRT + PyTorch ops in a + single outer ``torch.cuda.CUDAGraph`` capture. On TRT-RTX, each + engine must opt out of RTX-native CUDA graphs (which would + interfere with the outer capture) and must pass the + ``IExecutionContext.is_stream_capturable`` check. Raises + ``RuntimeError`` if any TRT engine is not monolithically + capturable. No-op on non-RTX builds. + """ + from torch_tensorrt._features import ENABLED_FEATURES + + if not ENABLED_FEATURES.tensorrt_rtx: + return + from torch_tensorrt.dynamo.runtime._TorchTensorRTModule import ( + TorchTensorRTModule, + ) + from torch_tensorrt.dynamo.runtime._TRTEngine import ( + TRTEngine, + _get_cuda_graph_strategy, + ) + + for name, mod in self.compiled_module.named_modules(): + if not ( + isinstance(mod, TorchTensorRTModule) + and isinstance(mod.engine, TRTEngine) + ): + continue + engine = mod.engine + if not engine._is_monolithic_capturable(stream): + raise RuntimeError( + f"CUDA graph capture failed: TRT submodule '{name}' is " + "not monolithically capturable (lazy kernel " + "specialization or non-capturable stream). Whole-graph " + "CUDA graph mode with mixed TRT + PyTorch ops requires " + "all TRT engines to be capturable. Consider using " + "cuda_graph_strategy='whole_graph_capture' with " + "set_cudagraphs_mode(True) instead of enable_cudagraphs()." + ) + # Disable RTX-native CUDA graphs on this engine so they don't + # interfere with the outer monolithic capture. + if engine._rtx_native_cudagraphs: + engine.runtime_config.cuda_graph_strategy = _get_cuda_graph_strategy( + "disabled" + ) + engine.context = engine._create_execution_context() + engine._rtx_native_cudagraphs = False + logger.info( + f"Disabled RTX-native CUDA graphs for '{name}' " + "(using outer monolithic capture instead)" + ) + def forward( self, *args: Any, **kwargs: Any ) -> torch.Tensor | Tuple[torch.Tensor, ...]: @@ -212,6 +265,7 @@ def forward( with torch.cuda.stream(self._engine_stream): if need_cudagraphs_record: + self._check_monolithic_capturability(self._engine_stream) self.cudagraph = torch.cuda.CUDAGraph() with torch.cuda.graph(self.cudagraph, stream=self._engine_stream): self._output_buffers = self.compiled_module(*args, **kwargs) diff --git a/py/torch_tensorrt/dynamo/runtime/_TRTEngine.py b/py/torch_tensorrt/dynamo/runtime/_TRTEngine.py index 31b6e64e30..7e95d760eb 100644 --- a/py/torch_tensorrt/dynamo/runtime/_TRTEngine.py +++ b/py/torch_tensorrt/dynamo/runtime/_TRTEngine.py @@ -67,6 +67,14 @@ def _get_dynamic_shapes_kernel_strategy(strategy_str: str) -> Any: }.get(strategy_str, trt.DynamicShapesKernelSpecializationStrategy.LAZY) +def _get_cuda_graph_strategy(strategy_str: str) -> Any: + """Map strategy string to TRT CudaGraphStrategy enum. Only meaningful on RTX.""" + return { + "disabled": trt.CudaGraphStrategy.DISABLED, + "whole_graph_capture": trt.CudaGraphStrategy.WHOLE_GRAPH_CAPTURE, + }.get(strategy_str, trt.CudaGraphStrategy.DISABLED) + + # --------------------------------------------------------------------------- # TRT I/O helpers # --------------------------------------------------------------------------- @@ -238,6 +246,12 @@ def __init__( # never runs. self.runtime_config: Any = None self.runtime_cache: Any = None + # True once an IRuntimeConfig.cuda_graph_strategy other than + # ``"disabled"`` is in effect (either set at compile time or installed + # at runtime by ``_enable_rtx_native_cudagraphs``). When true, + # ``_execute_standard`` must skip manual torch.cuda.CUDAGraph capture + # because TRT-RTX handles it internally. + self._rtx_native_cudagraphs: bool = False # NCCL communicator is bound lazily on the first forward pass for # engines compiled with native multi-device collective layers. self._nccl_comm: Optional[Any] = None @@ -307,6 +321,7 @@ def __setstate__(self, state: Any) -> None: # firing on a partially-loaded engine never trips an ``AttributeError``. self.runtime_config = None self.runtime_cache = None + self._rtx_native_cudagraphs = False # NCCL communicators cannot be pickled; rebind lazily on the next # forward pass via setup_nccl_comm(). self._nccl_comm = None @@ -459,12 +474,16 @@ def _setup_engine(self) -> None: ) dist.barrier() - # On TensorRT-RTX, build the IRuntimeConfig (with runtime cache and - # dynamic-shape kernel specialization strategy) and rebuild the - # execution context so it picks them up. The NCCL barrier above runs - # against the initial strategy-based context. + # On TensorRT-RTX, build the IRuntimeConfig (with runtime cache, + # dynamic-shape kernel specialization strategy, and CUDA graph + # strategy) and rebuild the execution context so it picks them up. + # The NCCL barrier above runs against the initial strategy-based + # context. if ENABLED_FEATURES.tensorrt_rtx: self._setup_runtime_config() + self._rtx_native_cudagraphs = ( + self.settings.cuda_graph_strategy != "disabled" + ) self.context = self._create_execution_context() if not self.in_binding_names and not self.out_binding_names: @@ -527,6 +546,10 @@ def _setup_runtime_config(self) -> None: "Dynamic shapes kernel specialization strategy: " f"{self.settings.dynamic_shapes_kernel_specialization_strategy}" ) + self.runtime_config.cuda_graph_strategy = _get_cuda_graph_strategy( + self.settings.cuda_graph_strategy + ) + logger.info(f"CUDA graph strategy: {self.settings.cuda_graph_strategy}") self.runtime_cache = self.runtime_config.create_runtime_cache() self._load_runtime_cache() self.runtime_config.set_runtime_cache(self.runtime_cache) @@ -572,6 +595,40 @@ def _save_runtime_cache(self) -> None: except Exception as e: logger.warning(f"Failed to save runtime cache: {e}") + def _is_monolithic_capturable(self, stream: torch.cuda.Stream) -> bool: + """Return True iff manual ``torch.cuda.CUDAGraph`` capture is safe. + + Non-RTX builds always return True (existing behavior). On RTX, + capture is unsafe when the TRT-RTX context cannot be stream-captured + (e.g. due to runtime allocation or data-dependent shapes) or when + the dynamic-shape strategy is ``"lazy"`` -- a later lazy-compiled + specialized kernel would invalidate the captured graph. + """ + if not ENABLED_FEATURES.tensorrt_rtx: + return True + if not self.context.is_stream_capturable(stream.cuda_stream): + return False + if self.settings.dynamic_shapes_kernel_specialization_strategy == "lazy": + return False + return True + + def _enable_rtx_native_cudagraphs(self) -> None: + """Switch this engine to TRT-RTX native CUDA graphs. + + Sets the runtime config's ``cuda_graph_strategy`` to + ``WHOLE_GRAPH_CAPTURE`` and rebuilds the execution context so it + picks up the new strategy. No-op on non-RTX or when the runtime + config is not present. + """ + if self.runtime_config is None: + return + self.runtime_config.cuda_graph_strategy = _get_cuda_graph_strategy( + "whole_graph_capture" + ) + self.context = self._create_execution_context() + self._rtx_native_cudagraphs = True + logger.info("Switched to TRT-RTX native CUDA graphs") + # --- distributed / NCCL --- @property @@ -845,6 +902,33 @@ def _prepare_streams(self, contiguous_inputs: List[torch.Tensor]) -> bool: def _execute_standard( self, contiguous_inputs: List[torch.Tensor] ) -> torch.Tensor | Tuple[torch.Tensor, ...]: + # On RTX, manual ``torch.cuda.CUDAGraph`` capture is not safe + # (lazy kernel specialization can invalidate captured graphs and + # runtime allocation can prevent stream capture). If the user + # requested SUBGRAPH cudagraphs without explicitly setting + # ``cuda_graph_strategy="whole_graph_capture"``, transparently + # switch to RTX-native CUDA graphs and warn. + cudagraphs_enabled = torch_tensorrt.runtime.get_cudagraphs_mode() + if ( + ENABLED_FEATURES.tensorrt_rtx + and cudagraphs_enabled + and not self._rtx_native_cudagraphs + ): + logger.warning( + "Manual CUDA graph capture is not guaranteed to work on " + "TRT-RTX (lazy kernel specialization or non-capturable " + "stream). Switching to TRT-RTX native CUDA graphs. Set " + 'cuda_graph_strategy="whole_graph_capture" at compile ' + "time to avoid this warning." + ) + self._enable_rtx_native_cudagraphs() + + # ``effective_cudagraphs`` is the value the downstream record/replay + # paths should react to. When RTX native is active, TRT-RTX is + # already handling capture/replay internally, so the manual + # ``torch.cuda.CUDAGraph`` machinery must stay quiet. + effective_cudagraphs = cudagraphs_enabled and not self._rtx_native_cudagraphs + # Pick the engine stream BEFORE set_runtime_states so that any # stream-identity change observed this call flips # runtime_states.context_changed in time to trigger same-call @@ -857,7 +941,7 @@ def _execute_standard( can_use_pre_allocated_outputs, need_cudagraphs_reset, ) = self.runtime_states.set_runtime_states( - torch_tensorrt.runtime.get_cudagraphs_mode(), + effective_cudagraphs, self.use_pre_allocated_outputs, shape_changed, ) @@ -872,7 +956,7 @@ def _execute_standard( with self._profile_section("TRTEngine:ProcessInputs"): self.setup_input_tensors( contiguous_inputs, - torch_tensorrt.runtime.get_cudagraphs_mode(), + effective_cudagraphs, need_cudagraphs_record, ) if shape_changed: @@ -899,7 +983,7 @@ def _execute_standard( for o, output_name in enumerate(self.out_binding_names): if need_cudagraphs_record: self._output_buffers[o] = outputs[o].clone() - if torch_tensorrt.runtime.get_cudagraphs_mode(): + if effective_cudagraphs: self.context.set_tensor_address( output_name, self._output_buffers[o].data_ptr() ) @@ -918,7 +1002,7 @@ def _execute_standard( ) self.context.set_device_memory(self._dynamic_workspace.data_ptr()) - if torch_tensorrt.runtime.get_cudagraphs_mode(): + if effective_cudagraphs: if need_cudagraphs_record: self.cudagraph = torch.cuda.CUDAGraph() if self._profile_execution: @@ -947,7 +1031,7 @@ def _execute_standard( ): self.pre_allocated_outputs = self.create_output_tensors() - if torch_tensorrt.runtime.get_cudagraphs_mode(): + if effective_cudagraphs: for idx, output in enumerate(outputs): output.copy_(self._output_buffers[idx]) @@ -1049,7 +1133,8 @@ def execute( return self._execute_output_allocator(contiguous_inputs) logger.debug( - f"Using the standard execution runtime mode with cudagraphs={torch_tensorrt.runtime.get_cudagraphs_mode()}." + f"Using the standard execution runtime mode with cudagraphs={torch_tensorrt.runtime.get_cudagraphs_mode()}" + + (" (RTX native)" if self._rtx_native_cudagraphs else "") ) return self._execute_standard(contiguous_inputs) diff --git a/tests/py/dynamo/models/test_cuda_graph_strategy_models.py b/tests/py/dynamo/models/test_cuda_graph_strategy_models.py new file mode 100644 index 0000000000..bce596d15f --- /dev/null +++ b/tests/py/dynamo/models/test_cuda_graph_strategy_models.py @@ -0,0 +1,186 @@ +import unittest + +import torch +import torch.nn.functional as F +import torch_tensorrt as torchtrt +from torch.testing._internal.common_utils import TestCase, run_tests +from torch_tensorrt._features import ENABLED_FEATURES + + +class ConvModel(torch.nn.Module): + def __init__(self): + super().__init__() + self.conv = torch.nn.Conv2d(3, 16, 3, padding=1) + + def forward(self, x): + return F.relu(self.conv(x)) + + +@unittest.skipIf( + not ENABLED_FEATURES.tensorrt_rtx, + "CUDA graph strategy models require TensorRT-RTX", +) +class TestCudaGraphStrategyModels(TestCase): + """End-to-end model tests with cuda_graph_strategy.""" + + def _check_cosine_similarity(self, output, ref_output, threshold=0.99): + cos_sim = F.cosine_similarity( + output.flatten().unsqueeze(0), + ref_output.flatten().unsqueeze(0), + ) + self.assertTrue( + cos_sim.item() > threshold, + f"Cosine similarity {cos_sim.item():.4f} below threshold {threshold}", + ) + + def test_resnet18_whole_graph_capture(self): + try: + from torchvision.models import resnet18 + except ImportError: + self.skipTest("torchvision not available") + + model = resnet18(weights=None).eval().cuda() + input_tensor = torch.randn(4, 3, 224, 224).cuda() + ref_output = model(input_tensor) + + inputs = [ + torchtrt.Input( + min_shape=(1, 3, 224, 224), + opt_shape=(4, 3, 224, 224), + max_shape=(8, 3, 224, 224), + dtype=torch.float32, + ) + ] + compiled = torchtrt.compile( + model, + ir="dynamo", + inputs=inputs, + enabled_precisions={torch.float32}, + use_python_runtime=True, + min_block_size=1, + cuda_graph_strategy="whole_graph_capture", + ) + torch._dynamo.reset() + + output = compiled(input_tensor) + self._check_cosine_similarity(output, ref_output) + + def test_resnet18_disabled_strategy(self): + try: + from torchvision.models import resnet18 + except ImportError: + self.skipTest("torchvision not available") + + model = resnet18(weights=None).eval().cuda() + input_tensor = torch.randn(4, 3, 224, 224).cuda() + ref_output = model(input_tensor) + + inputs = [ + torchtrt.Input( + min_shape=(1, 3, 224, 224), + opt_shape=(4, 3, 224, 224), + max_shape=(8, 3, 224, 224), + dtype=torch.float32, + ) + ] + compiled = torchtrt.compile( + model, + ir="dynamo", + inputs=inputs, + enabled_precisions={torch.float32}, + use_python_runtime=True, + min_block_size=1, + cuda_graph_strategy="disabled", + ) + torch._dynamo.reset() + + output = compiled(input_tensor) + self._check_cosine_similarity(output, ref_output) + + +@unittest.skipIf( + not ENABLED_FEATURES.tensorrt_rtx, + "CUDA graph strategy models require TensorRT-RTX", +) +class TestCudaGraphStrategyDynamic(TestCase): + """Tests with dynamic batch sizes and cudagraph mode integration.""" + + def setUp(self): + torchtrt.runtime.set_cudagraphs_mode(False) + + def tearDown(self): + torchtrt.runtime.set_cudagraphs_mode(False) + + def test_dynamic_batch_whole_graph_capture(self): + model = ConvModel().eval().cuda() + inputs = [ + torchtrt.Input( + min_shape=(1, 3, 32, 32), + opt_shape=(4, 3, 32, 32), + max_shape=(8, 3, 32, 32), + dtype=torch.float32, + ) + ] + compiled = torchtrt.compile( + model, + ir="dynamo", + inputs=inputs, + enabled_precisions={torch.float32}, + use_python_runtime=True, + min_block_size=1, + cuda_graph_strategy="whole_graph_capture", + ) + torch._dynamo.reset() + + for bs in (1, 4, 8): + input_tensor = torch.randn(bs, 3, 32, 32).cuda() + ref_output = model(input_tensor) + output = compiled(input_tensor) + cos_sim = F.cosine_similarity( + output.flatten().unsqueeze(0), + ref_output.flatten().unsqueeze(0), + ) + self.assertTrue( + cos_sim.item() > 0.99, + f"Batch size {bs}: cosine similarity {cos_sim.item():.4f} too low", + ) + + def test_dynamic_batch_with_subgraph_cudagraphs(self): + model = ConvModel().eval().cuda() + inputs = [ + torchtrt.Input( + min_shape=(1, 3, 32, 32), + opt_shape=(4, 3, 32, 32), + max_shape=(8, 3, 32, 32), + dtype=torch.float32, + ) + ] + compiled = torchtrt.compile( + model, + ir="dynamo", + inputs=inputs, + enabled_precisions={torch.float32}, + use_python_runtime=True, + min_block_size=1, + cuda_graph_strategy="whole_graph_capture", + ) + torch._dynamo.reset() + + torchtrt.runtime.set_cudagraphs_mode(True) + + for bs in (1, 4, 8): + input_tensor = torch.randn(bs, 3, 32, 32).cuda() + ref_output = model(input_tensor) + output = compiled(input_tensor) + cos_sim = F.cosine_similarity( + output.flatten().unsqueeze(0), + ref_output.flatten().unsqueeze(0), + ) + self.assertTrue( + cos_sim.item() > 0.99, + f"Batch size {bs}: cosine similarity {cos_sim.item():.4f} too low", + ) + + +if __name__ == "__main__": + run_tests() diff --git a/tests/py/dynamo/runtime/test_001_cuda_graph_strategy.py b/tests/py/dynamo/runtime/test_001_cuda_graph_strategy.py new file mode 100644 index 0000000000..8b534be362 --- /dev/null +++ b/tests/py/dynamo/runtime/test_001_cuda_graph_strategy.py @@ -0,0 +1,357 @@ +import unittest + +import torch +import torch_tensorrt as torchtrt +from torch.testing._internal.common_utils import TestCase, run_tests +from torch_tensorrt._features import ENABLED_FEATURES +from torch_tensorrt.dynamo._settings import CompilationSettings + + +class SimpleModel(torch.nn.Module): + def forward(self, x): + return torch.relu(x) + 1.0 + + +def _compile_simple(**extra_kwargs): + """Helper: compile SimpleModel with dynamic shapes and Python runtime.""" + model = SimpleModel().eval().cuda() + inputs = [ + torchtrt.Input( + min_shape=(1, 3), + opt_shape=(2, 3), + max_shape=(4, 3), + dtype=torch.float32, + ) + ] + kwargs = { + "ir": "dynamo", + "inputs": inputs, + "enabled_precisions": {torch.float32}, + "use_python_runtime": True, + "min_block_size": 1, + } + kwargs.update(extra_kwargs) + compiled = torchtrt.compile(model, **kwargs) + torch._dynamo.reset() + return compiled + + +def _find_python_trt_engine(compiled): + """Walk the compiled graph module and return the Python ``TRTEngine`` instance. + + The C++ and Python runtimes are now both driven through ``TorchTensorRTModule``; + ``use_python_runtime=True`` causes ``module.engine`` to be a Python ``TRTEngine``. + """ + from torch_tensorrt.dynamo.runtime._TorchTensorRTModule import TorchTensorRTModule + from torch_tensorrt.dynamo.runtime._TRTEngine import TRTEngine + + for _, mod in compiled.named_modules(): + if isinstance(mod, TorchTensorRTModule) and isinstance(mod.engine, TRTEngine): + return mod.engine + return None + + +@unittest.skipIf( + not ENABLED_FEATURES.tensorrt_rtx, + "CUDA graph strategy requires TensorRT-RTX", +) +class TestCudaGraphStrategySetup(TestCase): + """Tests that cuda_graph_strategy is correctly applied on TRT-RTX.""" + + def test_default_strategy_is_disabled(self): + import tensorrt as trt + + compiled = _compile_simple() + engine = _find_python_trt_engine(compiled) + self.assertIsNotNone(engine, "No Python TRTEngine found") + self.assertIsNotNone( + engine.runtime_config, "runtime_config should be set for RTX" + ) + self.assertEqual( + engine.runtime_config.cuda_graph_strategy, + trt.CudaGraphStrategy.DISABLED, + ) + + def test_whole_graph_capture_strategy(self): + import tensorrt as trt + + compiled = _compile_simple(cuda_graph_strategy="whole_graph_capture") + engine = _find_python_trt_engine(compiled) + self.assertIsNotNone(engine) + self.assertEqual( + engine.runtime_config.cuda_graph_strategy, + trt.CudaGraphStrategy.WHOLE_GRAPH_CAPTURE, + ) + + def test_rtx_native_flag_set(self): + compiled = _compile_simple(cuda_graph_strategy="whole_graph_capture") + engine = _find_python_trt_engine(compiled) + self.assertIsNotNone(engine) + self.assertTrue(engine._rtx_native_cudagraphs) + + def test_rtx_native_flag_disabled(self): + compiled = _compile_simple(cuda_graph_strategy="disabled") + engine = _find_python_trt_engine(compiled) + self.assertIsNotNone(engine) + self.assertFalse(engine._rtx_native_cudagraphs) + + def test_inference_with_each_strategy(self): + for strategy in ("disabled", "whole_graph_capture"): + with self.subTest(strategy=strategy): + compiled = _compile_simple(cuda_graph_strategy=strategy) + engine = _find_python_trt_engine(compiled) + self.assertIsNotNone( + engine.context, + f"Execution context should be created for {strategy}", + ) + for bs in (1, 2, 4): + output = compiled(torch.randn(bs, 3).cuda()) + self.assertEqual(output.shape, (bs, 3)) + + def test_setting_in_compilation_settings(self): + for strategy in ("disabled", "whole_graph_capture"): + settings = CompilationSettings(cuda_graph_strategy=strategy) + self.assertEqual(settings.cuda_graph_strategy, strategy) + + def test_default_compilation_settings(self): + settings = CompilationSettings() + self.assertEqual(settings.cuda_graph_strategy, "disabled") + + +@unittest.skipIf( + not ENABLED_FEATURES.tensorrt_rtx, + "CUDA graph strategy integration requires TensorRT-RTX", +) +class TestCudaGraphStrategyWithSubgraphCudagraphs(TestCase): + """Tests integration with set_cudagraphs_mode().""" + + def setUp(self): + torchtrt.runtime.set_cudagraphs_mode(False) + + def tearDown(self): + torchtrt.runtime.set_cudagraphs_mode(False) + + def test_rtx_native_bypasses_manual_capture(self): + compiled = _compile_simple(cuda_graph_strategy="whole_graph_capture") + engine = _find_python_trt_engine(compiled) + self.assertIsNotNone(engine) + + torchtrt.runtime.set_cudagraphs_mode(True) + + # Run inference a few times to ensure capture would have happened + for _ in range(3): + compiled(torch.randn(2, 3).cuda()) + + # Manual cudagraph should NOT have been recorded (RTX handles it natively) + self.assertFalse( + isinstance(engine.cudagraph, torch.cuda.CUDAGraph), + "Manual CUDA graph should not be recorded when RTX native is active", + ) + + def test_subgraph_mode_always_uses_rtx_native(self): + """Even with cuda_graph_strategy=disabled, SUBGRAPH mode on RTX + should override to RTX-native because manual capture is not safe.""" + compiled = _compile_simple(cuda_graph_strategy="disabled") + engine = _find_python_trt_engine(compiled) + self.assertIsNotNone(engine) + # Initially, _rtx_native_cudagraphs is False (disabled strategy) + self.assertFalse(engine._rtx_native_cudagraphs) + + torchtrt.runtime.set_cudagraphs_mode(True) + + # Run inference -- should trigger override to RTX-native + for _ in range(3): + compiled(torch.randn(2, 3).cuda()) + + # Should have been overridden to RTX-native + self.assertTrue( + engine._rtx_native_cudagraphs, + "RTX-native should be enabled automatically in SUBGRAPH mode", + ) + # Manual cudagraph should NOT have been recorded + self.assertFalse( + isinstance(engine.cudagraph, torch.cuda.CUDAGraph), + "Manual CUDA graph should not be recorded on RTX", + ) + + +@unittest.skipIf( + not ENABLED_FEATURES.tensorrt_rtx, + "Monolithic capturability tests require TensorRT-RTX", +) +class TestMonolithicCapturability(TestCase): + """Tests for _is_monolithic_capturable() and related logic.""" + + def test_lazy_strategy_not_monolithic_capturable(self): + """Lazy kernel specialization makes monolithic capture unsafe.""" + compiled = _compile_simple( + cuda_graph_strategy="disabled", + dynamic_shapes_kernel_specialization_strategy="lazy", + ) + engine = _find_python_trt_engine(compiled) + self.assertIsNotNone(engine) + stream = torch.cuda.Stream() + self.assertFalse(engine._is_monolithic_capturable(stream)) + + def test_eager_strategy_monolithic_capturable(self): + """Eager strategy with capturable stream should be monolithic capturable.""" + compiled = _compile_simple( + cuda_graph_strategy="disabled", + dynamic_shapes_kernel_specialization_strategy="eager", + ) + engine = _find_python_trt_engine(compiled) + self.assertIsNotNone(engine) + stream = torch.cuda.Stream() + # is_stream_capturable depends on engine properties. + # With eager strategy, the strategy check passes. + if engine.context.is_stream_capturable(stream.cuda_stream): + self.assertTrue(engine._is_monolithic_capturable(stream)) + + def test_none_strategy_monolithic_capturable(self): + """None strategy (always fallback) should be monolithic capturable.""" + compiled = _compile_simple( + cuda_graph_strategy="disabled", + dynamic_shapes_kernel_specialization_strategy="none", + ) + engine = _find_python_trt_engine(compiled) + self.assertIsNotNone(engine) + stream = torch.cuda.Stream() + if engine.context.is_stream_capturable(stream.cuda_stream): + self.assertTrue(engine._is_monolithic_capturable(stream)) + + +@unittest.skipIf( + not ENABLED_FEATURES.tensorrt_rtx, + "Context recreation tests require TensorRT-RTX", +) +class TestContextRecreation(TestCase): + """Tests for _enable_rtx_native_cudagraphs() context recreation.""" + + def test_enable_rtx_native_recreates_context(self): + """Calling _enable_rtx_native_cudagraphs recreates the execution context.""" + import tensorrt as trt + + compiled = _compile_simple(cuda_graph_strategy="disabled") + engine = _find_python_trt_engine(compiled) + self.assertIsNotNone(engine) + self.assertFalse(engine._rtx_native_cudagraphs) + + old_context_id = id(engine.context) + engine._enable_rtx_native_cudagraphs() + + self.assertTrue(engine._rtx_native_cudagraphs) + self.assertNotEqual( + id(engine.context), + old_context_id, + "Context should be recreated", + ) + self.assertEqual( + engine.runtime_config.cuda_graph_strategy, + trt.CudaGraphStrategy.WHOLE_GRAPH_CAPTURE, + ) + + def test_explicit_whole_graph_capture_no_override_needed(self): + """With explicit whole_graph_capture, SUBGRAPH mode should not + need to override (already RTX-native).""" + compiled = _compile_simple(cuda_graph_strategy="whole_graph_capture") + engine = _find_python_trt_engine(compiled) + self.assertIsNotNone(engine) + self.assertTrue(engine._rtx_native_cudagraphs) + + old_context_id = id(engine.context) + + torchtrt.runtime.set_cudagraphs_mode(True) + compiled(torch.randn(2, 3).cuda()) + torchtrt.runtime.set_cudagraphs_mode(False) + + # Context should NOT have been recreated (was already RTX-native) + self.assertEqual( + id(engine.context), + old_context_id, + "Context should not be recreated if already RTX-native", + ) + + +@unittest.skipIf( + not ENABLED_FEATURES.tensorrt_rtx, + "Cudagraph mode toggle tests require TensorRT-RTX", +) +class TestCudagraphModeToggle(TestCase): + """Tests for toggling cudagraph mode with RTX-native.""" + + def setUp(self): + torchtrt.runtime.set_cudagraphs_mode(False) + + def tearDown(self): + torchtrt.runtime.set_cudagraphs_mode(False) + + def test_cudagraphs_off_after_rtx_native_override(self): + """After RTX-native override, disabling cudagraphs should still + produce correct results (RTX-native continues transparently).""" + compiled = _compile_simple(cuda_graph_strategy="disabled") + + torchtrt.runtime.set_cudagraphs_mode(True) + compiled(torch.randn(2, 3).cuda()) # triggers override + + torchtrt.runtime.set_cudagraphs_mode(False) + + # Should still work -- RTX-native is transparent + for bs in (1, 2, 4): + output = compiled(torch.randn(bs, 3).cuda()) + self.assertEqual(output.shape, (bs, 3)) + + def test_no_cudagraphs_with_whole_graph_capture(self): + """With cuda_graph_strategy='whole_graph_capture' but no + set_cudagraphs_mode, RTX-native runs transparently.""" + compiled = _compile_simple(cuda_graph_strategy="whole_graph_capture") + engine = _find_python_trt_engine(compiled) + self.assertTrue(engine._rtx_native_cudagraphs) + + # No set_cudagraphs_mode(True) -- RTX-native still active transparently + for bs in (1, 2, 4): + output = compiled(torch.randn(bs, 3).cuda()) + self.assertEqual(output.shape, (bs, 3)) + + def test_toggle_on_off_on(self): + """Toggle cudagraphs on -> off -> on, verify correctness each time.""" + compiled = _compile_simple(cuda_graph_strategy="disabled") + inp = torch.randn(2, 3).cuda() + + # Phase 1: on + torchtrt.runtime.set_cudagraphs_mode(True) + out1 = compiled(inp) + self.assertEqual(out1.shape, (2, 3)) + + # Phase 2: off + torchtrt.runtime.set_cudagraphs_mode(False) + out2 = compiled(inp) + self.assertEqual(out2.shape, (2, 3)) + + # Phase 3: on again + torchtrt.runtime.set_cudagraphs_mode(True) + out3 = compiled(inp) + self.assertEqual(out3.shape, (2, 3)) + + +@unittest.skipIf( + ENABLED_FEATURES.tensorrt_rtx, + "This test verifies standard TRT behavior (non-RTX)", +) +class TestCudaGraphStrategyNonRTX(TestCase): + """Tests that the setting is ignored on non-RTX builds.""" + + def test_setting_ignored_on_non_rtx(self): + compiled = _compile_simple(cuda_graph_strategy="whole_graph_capture") + engine = _find_python_trt_engine(compiled) + if engine is not None: + self.assertIsNone( + engine.runtime_config, + "runtime_config should be None for standard TRT", + ) + self.assertFalse(engine._rtx_native_cudagraphs) + output = compiled(torch.randn(2, 3).cuda()) + self.assertEqual(output.shape, (2, 3)) + + +if __name__ == "__main__": + run_tests() From 378c1bd159947f4c56db06d8cb0de8f51e15e20d Mon Sep 17 00:00:00 2001 From: tejaswinp Date: Tue, 26 May 2026 09:50:25 -0700 Subject: [PATCH 3/6] fixup: address review comments on TRTEngine RTX changes Squash candidates for the two commits on this branch; landed as a separate fixup so each individual review-suggestion is traceable. - Drop the verbose comment blocks on the runtime_config / runtime_cache and _rtx_native_cudagraphs init lines; keep only the destructor / manual-capture rationale. - Remove the self.runtime_cache_path attribute. Inline self.settings.runtime_cache_path in _load_runtime_cache and _save_runtime_cache; the test that previously asserted engine.runtime_cache_path now reads engine.settings.runtime_cache_path. - Refactor _create_execution_context to a clean if/else where the RTX branch asserts runtime_config is not None and uses it directly. Drop the getattr defensive check. - Move _setup_runtime_config to before the first execution-context creation in _setup_engine so we only create the context once. The NCCL barrier still runs against a live context; the RTX runtime config feeds straight into that single context creation. - Add _save_runtime_cache to close(). __del__ now delegates to close(). Drop the try/except in __del__ -- _save_runtime_cache already swallows exceptions internally. - Inline the alloc_strategy decision in _setup_runtime_config as a one-liner ternary. - Drop the getattr-guard in _save_runtime_cache; rely on the runtime_cache = None init in __init__ / __setstate__. - _is_monolithic_capturable now uses any(...) over a tuple of not-capturable conditions instead of multiple if/return False. - Shorten the section header to "# --- TensorRT-RTX ---". - Drop the redundant comment block above the RTX-native override in _execute_standard -- the warning message says the same thing. - Reword the effective_cudagraphs comment to "the manual torch.cuda.CUDAGraph machinery is skipped". Tested on A100 RTX (jobid 2322243, container git_trt_tejaswinp_xaajfdwx): - test_000_runtime_cache: 12 passed, 2 skipped - test_001_dynamic_shapes_kernel_strategy: 6 passed, 1 skipped, 3 subtests passed - test_001_cuda_graph_strategy: 17 passed, 1 skipped, 2 subtests passed --- .../dynamo/runtime/_TRTEngine.py | 129 +++++++----------- .../dynamo/runtime/test_000_runtime_cache.py | 4 +- 2 files changed, 53 insertions(+), 80 deletions(-) diff --git a/py/torch_tensorrt/dynamo/runtime/_TRTEngine.py b/py/torch_tensorrt/dynamo/runtime/_TRTEngine.py index 7e95d760eb..ef7251f014 100644 --- a/py/torch_tensorrt/dynamo/runtime/_TRTEngine.py +++ b/py/torch_tensorrt/dynamo/runtime/_TRTEngine.py @@ -238,19 +238,12 @@ def __init__( torch_tensorrt.runtime.get_cudagraphs_mode() ) self.resource_allocation_strategy = 0 - # TensorRT-RTX runtime cache state. ``runtime_cache_path`` is filled in - # by ``_load_serialized_info`` once compilation settings are available; - # ``runtime_config`` and ``runtime_cache`` are populated by - # ``_setup_runtime_config`` on RTX builds. Initialized to ``None`` here - # so the destructor can safely save the cache even if ``_setup_engine`` - # never runs. + # Initialized to ``None`` here so the destructor can safely save the + # cache even if ``_setup_engine`` never runs. self.runtime_config: Any = None self.runtime_cache: Any = None - # True once an IRuntimeConfig.cuda_graph_strategy other than - # ``"disabled"`` is in effect (either set at compile time or installed - # at runtime by ``_enable_rtx_native_cudagraphs``). When true, - # ``_execute_standard`` must skip manual torch.cuda.CUDAGraph capture - # because TRT-RTX handles it internally. + # When true, ``_execute_standard`` must skip manual torch.cuda.CUDAGraph + # capture because TRT-RTX handles it internally. self._rtx_native_cudagraphs: bool = False # NCCL communicator is bound lazily on the first forward pass for # engines compiled with native multi-device collective layers. @@ -260,13 +253,7 @@ def __init__( self._setup_engine() def __del__(self) -> None: - # Persist the TensorRT-RTX runtime cache before tearing the engine - # down; no-op when ``runtime_cache`` was never populated. - try: - self._save_runtime_cache() - except Exception: - pass - self.reset_captured_graph() + self.close() def __deepcopy__(self, memo: dict[int, Any]) -> "TRTEngine": """Rebuild from serialized layout so ``copy.deepcopy`` skips unpickleable TRT handles.""" @@ -386,9 +373,6 @@ def _load_serialized_info( metadata = self.decode_metadata(self.serialized_metadata) self.settings = metadata.get("settings", CompilationSettings()) - # Path used by ``_load_runtime_cache`` / ``_save_runtime_cache`` on - # TensorRT-RTX. Always set so non-RTX engines also expose it. - self.runtime_cache_path = self.settings.runtime_cache_path self.weight_name_map = metadata.get("weight_name_map") self.symbolic_shape_expressions = metadata.get("inout_symexprs") self.output_tensors_are_unowned = metadata.get( @@ -416,26 +400,21 @@ def get_serialized_metadata(self) -> str: return self.serialized_metadata def close(self) -> None: - """Release CUDA graph resources (called explicitly or via __del__).""" + """Persist the runtime cache and release CUDA graph resources.""" + self._save_runtime_cache() self.reset_captured_graph() def _create_execution_context(self) -> trt.IExecutionContext: - # On TensorRT-RTX builds the allocation strategy lives on the - # ``IRuntimeConfig`` (set by ``_setup_runtime_config``), so once the - # runtime config is built we route context creation through it. The - # first call from ``_setup_engine`` precedes ``_setup_runtime_config`` - # and falls through to the strategy-based path below. - if ( - ENABLED_FEATURES.tensorrt_rtx - and getattr(self, "runtime_config", None) is not None - ): + if ENABLED_FEATURES.tensorrt_rtx: + assert self.runtime_config is not None context = self.cuda_engine.create_execution_context(self.runtime_config) - assert context is not None, "Failed to create execution context" - return context - strategy = trt.ExecutionContextAllocationStrategy.STATIC - if self.resource_allocation_strategy: - strategy = trt.ExecutionContextAllocationStrategy.USER_MANAGED - context = self.cuda_engine.create_execution_context(strategy) + else: + strategy = ( + trt.ExecutionContextAllocationStrategy.USER_MANAGED + if self.resource_allocation_strategy + else trt.ExecutionContextAllocationStrategy.STATIC + ) + context = self.cuda_engine.create_execution_context(strategy) assert context is not None, "Failed to create execution context" return context @@ -450,6 +429,15 @@ def _setup_engine(self) -> None: logger.debug(f"Weight streaming budget set to {budget_bytes}B") self.cuda_engine.weight_streaming_budget_v2 = budget_bytes + # On TensorRT-RTX, build the IRuntimeConfig (runtime cache, + # dynamic-shape kernel specialization strategy, and CUDA graph + # strategy) up front so the one-and-only execution context picks it up. + if ENABLED_FEATURES.tensorrt_rtx: + self._setup_runtime_config() + self._rtx_native_cudagraphs = ( + self.settings.cuda_graph_strategy != "disabled" + ) + self.context = self._create_execution_context() if self._has_nccl_ops: @@ -474,18 +462,6 @@ def _setup_engine(self) -> None: ) dist.barrier() - # On TensorRT-RTX, build the IRuntimeConfig (with runtime cache, - # dynamic-shape kernel specialization strategy, and CUDA graph - # strategy) and rebuild the execution context so it picks them up. - # The NCCL barrier above runs against the initial strategy-based - # context. - if ENABLED_FEATURES.tensorrt_rtx: - self._setup_runtime_config() - self._rtx_native_cudagraphs = ( - self.settings.cuda_graph_strategy != "disabled" - ) - self.context = self._create_execution_context() - if not self.in_binding_names and not self.out_binding_names: input_names: List[str] = [] output_names: List[str] = [] @@ -522,7 +498,7 @@ def _setup_engine(self) -> None: if self.requires_output_allocator: self.create_output_allocator() - # --- TensorRT-RTX runtime cache / dynamic shapes strategy --- + # --- TensorRT-RTX --- def _setup_runtime_config(self) -> None: """Build an ``IRuntimeConfig`` with runtime cache and dynamic-shape strategy. @@ -533,9 +509,11 @@ def _setup_runtime_config(self) -> None: shape-specialized kernels (``lazy``, ``eager``, or ``none``). """ self.runtime_config = self.cuda_engine.create_runtime_config() - alloc_strategy = trt.ExecutionContextAllocationStrategy.STATIC - if self.resource_allocation_strategy: - alloc_strategy = trt.ExecutionContextAllocationStrategy.USER_MANAGED + alloc_strategy = ( + trt.ExecutionContextAllocationStrategy.USER_MANAGED + if self.resource_allocation_strategy + else trt.ExecutionContextAllocationStrategy.STATIC + ) self.runtime_config.set_execution_context_allocation_strategy(alloc_strategy) self.runtime_config.dynamic_shapes_kernel_specialization_strategy = ( _get_dynamic_shapes_kernel_strategy( @@ -559,39 +537,41 @@ def _load_runtime_cache(self) -> None: """Load runtime cache from disk if it exists (with a shared file lock).""" if self.runtime_cache is None: return - if not os.path.isfile(self.runtime_cache_path): - logger.debug(f"No existing runtime cache at {self.runtime_cache_path}") + cache_path = self.settings.runtime_cache_path + if not os.path.isfile(cache_path): + logger.debug(f"No existing runtime cache at {cache_path}") return try: from filelock import FileLock - lock = FileLock(self.runtime_cache_path + ".lock") + lock = FileLock(cache_path + ".lock") with lock.acquire(timeout=10): - with open(self.runtime_cache_path, "rb") as f: + with open(cache_path, "rb") as f: data = f.read() if data: self.runtime_cache.deserialize(data) - logger.info(f"Loaded runtime cache from {self.runtime_cache_path}") + logger.info(f"Loaded runtime cache from {cache_path}") except Exception as e: logger.warning(f"Failed to load runtime cache: {e}") def _save_runtime_cache(self) -> None: """Save runtime cache to disk (with an exclusive file lock).""" - if getattr(self, "runtime_cache", None) is None: + if self.runtime_cache is None: return try: host_mem = self.runtime_cache.serialize() if host_mem is None: return - os.makedirs(os.path.dirname(self.runtime_cache_path), exist_ok=True) + cache_path = self.settings.runtime_cache_path + os.makedirs(os.path.dirname(cache_path), exist_ok=True) from filelock import FileLock - lock = FileLock(self.runtime_cache_path + ".lock") + lock = FileLock(cache_path + ".lock") with lock.acquire(timeout=10): - with open(self.runtime_cache_path, "wb") as f: + with open(cache_path, "wb") as f: f.write(memoryview(host_mem)) - logger.info(f"Saved runtime cache to {self.runtime_cache_path}") + logger.info(f"Saved runtime cache to {cache_path}") except Exception as e: logger.warning(f"Failed to save runtime cache: {e}") @@ -606,11 +586,11 @@ def _is_monolithic_capturable(self, stream: torch.cuda.Stream) -> bool: """ if not ENABLED_FEATURES.tensorrt_rtx: return True - if not self.context.is_stream_capturable(stream.cuda_stream): - return False - if self.settings.dynamic_shapes_kernel_specialization_strategy == "lazy": - return False - return True + not_capturable = ( + not self.context.is_stream_capturable(stream.cuda_stream), + self.settings.dynamic_shapes_kernel_specialization_strategy == "lazy", + ) + return not any(not_capturable) def _enable_rtx_native_cudagraphs(self) -> None: """Switch this engine to TRT-RTX native CUDA graphs. @@ -902,12 +882,6 @@ def _prepare_streams(self, contiguous_inputs: List[torch.Tensor]) -> bool: def _execute_standard( self, contiguous_inputs: List[torch.Tensor] ) -> torch.Tensor | Tuple[torch.Tensor, ...]: - # On RTX, manual ``torch.cuda.CUDAGraph`` capture is not safe - # (lazy kernel specialization can invalidate captured graphs and - # runtime allocation can prevent stream capture). If the user - # requested SUBGRAPH cudagraphs without explicitly setting - # ``cuda_graph_strategy="whole_graph_capture"``, transparently - # switch to RTX-native CUDA graphs and warn. cudagraphs_enabled = torch_tensorrt.runtime.get_cudagraphs_mode() if ( ENABLED_FEATURES.tensorrt_rtx @@ -923,10 +897,9 @@ def _execute_standard( ) self._enable_rtx_native_cudagraphs() - # ``effective_cudagraphs`` is the value the downstream record/replay - # paths should react to. When RTX native is active, TRT-RTX is - # already handling capture/replay internally, so the manual - # ``torch.cuda.CUDAGraph`` machinery must stay quiet. + # When RTX native is active, TRT-RTX handles capture/replay + # internally so the manual ``torch.cuda.CUDAGraph`` machinery is + # skipped. effective_cudagraphs = cudagraphs_enabled and not self._rtx_native_cudagraphs # Pick the engine stream BEFORE set_runtime_states so that any diff --git a/tests/py/dynamo/runtime/test_000_runtime_cache.py b/tests/py/dynamo/runtime/test_000_runtime_cache.py index 013fd2a92a..05637a6146 100644 --- a/tests/py/dynamo/runtime/test_000_runtime_cache.py +++ b/tests/py/dynamo/runtime/test_000_runtime_cache.py @@ -90,7 +90,7 @@ def test_context_created_successfully(self): def test_runtime_cache_path_default(self): compiled, _ = _compile_simple() engine = _find_python_trt_engine(compiled) - self.assertEqual(engine.runtime_cache_path, RUNTIME_CACHE_PATH) + self.assertEqual(engine.settings.runtime_cache_path, RUNTIME_CACHE_PATH) def test_runtime_cache_path_custom(self): cache_dir = tempfile.mkdtemp() @@ -98,7 +98,7 @@ def test_runtime_cache_path_custom(self): custom_path = os.path.join(cache_dir, "my_cache.bin") compiled, _ = _compile_simple(runtime_cache_path=custom_path) engine = _find_python_trt_engine(compiled) - self.assertEqual(engine.runtime_cache_path, custom_path) + self.assertEqual(engine.settings.runtime_cache_path, custom_path) finally: shutil.rmtree(cache_dir, ignore_errors=True) From a683a6fa2b1204f200c0b6bf7a00e5d640717bb5 Mon Sep 17 00:00:00 2001 From: tejaswinp Date: Wed, 27 May 2026 14:14:42 -0700 Subject: [PATCH 4/6] test: skip enable_cudagraphs weight-streaming tests on TRT-RTX (python runtime) Both python-runtime variants in test_004_weight_streaming.py combine torchtrt.runtime.enable_cudagraphs() (manual whole-graph torch CUDA graph capture) with enable_weight_streaming=True. This combination is fundamentally unsupported on TRT-RTX: weight H2D copies run on a dedicated stream with cross-stream event synchronization, which a single-stream torch.cuda.CUDAGraph capture cannot record. A captured graph would replay against stale or uninitialized weights. The new monolithic-capturability check in the CUDA graph strategy feature already raises RuntimeError for this case at runtime and points at the supported path (cuda_graph_strategy="whole_graph_capture" with set_cudagraphs_mode(True)). The skip avoids the noisy failure during CI sweeps. Skip condition keys off ENABLED_FEATURES.tensorrt_rtx alone, since both tests live in TestWeightStreamingPython and are already python-runtime only. --- tests/py/dynamo/runtime/test_004_weight_streaming.py | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/tests/py/dynamo/runtime/test_004_weight_streaming.py b/tests/py/dynamo/runtime/test_004_weight_streaming.py index 5b8b0b94b6..0e82cd4613 100644 --- a/tests/py/dynamo/runtime/test_004_weight_streaming.py +++ b/tests/py/dynamo/runtime/test_004_weight_streaming.py @@ -6,6 +6,7 @@ import torch_tensorrt as torchtrt from parameterized import parameterized from torch.testing._internal.common_utils import TestCase, run_tests +from torch_tensorrt._features import ENABLED_FEATURES from torch_tensorrt._utils import is_orin from torch_tensorrt.dynamo.utils import prepare_inputs @@ -215,6 +216,11 @@ def test_weight_streaming_multi_rt(self): torch._dynamo.reset() def test_weight_streaming_cudagraphs(self): + if ENABLED_FEATURES.tensorrt_rtx: + self.skipTest( + "Manual whole-graph CUDA graph capture (enable_cudagraphs) is " + "incompatible with weight streaming on TRT-RTX." + ) model = SampleModel().eval().cuda() input = [torch.randn(*INPUT_SIZE, dtype=torch.float32).cuda()] exp_program = torch.export.export(model, tuple(input)) @@ -260,6 +266,12 @@ def test_weight_streaming_cudagraphs(self): is_orin(), "There is a bug on Orin platform, skip for now until bug is fixed" ) def test_runtime_state_change(self): + if ENABLED_FEATURES.tensorrt_rtx: + self.skipTest( + "Manual whole-graph CUDA graph capture (enable_cudagraphs) is " + "incompatible with weight streaming on TRT-RTX." + ) + class SampleModel(torch.nn.Module): def __init__(self): super().__init__() From 53470af90e9f4e49b286166cd3ef33c796f297f7 Mon Sep 17 00:00:00 2001 From: tejaswinp Date: Wed, 27 May 2026 14:15:32 -0700 Subject: [PATCH 5/6] refactor: gate lazy-strategy capture rejection on dynamic-shape inputs _is_monolithic_capturable previously returned False whenever the kernel specialization strategy was "lazy", regardless of whether the engine actually compiles shape-specialized kernels at runtime. For a static-shape engine (no DYNAMIC_DIM on any input binding) the lazy strategy is a no-op -- there are no further specializations possible after build, so a captured CUDA graph cannot be invalidated by mid-replay specialization. Empirically TRT-RTX's own context.is_stream_capturable() returns True for static-shape engines under lazy strategy, confirming the kernel-readiness concern does not apply. Keep the lazy clause but gate it on any(DYNAMIC_DIM in shape for shape in self.input_shapes) so it only fires when shape-specialized kernels can actually appear later. This removes a false-negative that was blocking monolithic capture for static-shape RTX users on the default ("lazy") strategy. TRTEngine did not previously cache self.input_shapes the way it caches self.output_shapes; this commit adds the parallel population in _setup_engine, mirroring the convention used elsewhere. DYNAMIC_DIM is imported from torch_tensorrt.dynamo.utils. --- py/torch_tensorrt/dynamo/runtime/_TRTEngine.py | 18 ++++++++++++------ 1 file changed, 12 insertions(+), 6 deletions(-) diff --git a/py/torch_tensorrt/dynamo/runtime/_TRTEngine.py b/py/torch_tensorrt/dynamo/runtime/_TRTEngine.py index ef7251f014..5ef2ecb287 100644 --- a/py/torch_tensorrt/dynamo/runtime/_TRTEngine.py +++ b/py/torch_tensorrt/dynamo/runtime/_TRTEngine.py @@ -45,6 +45,7 @@ deserialize_binding_names, parse_device_info, ) +from torch_tensorrt.dynamo.utils import DYNAMIC_DIM from torch_tensorrt.logging import TRT_LOGGER from torch_tensorrt.runtime._utils import ( _is_switch_required, @@ -487,6 +488,10 @@ def _setup_engine(self) -> None: dtype._from(self.cuda_engine.get_tensor_dtype(output_name)).to(torch.dtype) for output_name in self.out_binding_names ] + self.input_shapes = [ + self.cuda_engine.get_tensor_shape(input_name) + for input_name in self.in_binding_names + ] self.output_shapes = [ self.cuda_engine.get_tensor_shape(output_name) for output_name in self.out_binding_names @@ -578,17 +583,18 @@ def _save_runtime_cache(self) -> None: def _is_monolithic_capturable(self, stream: torch.cuda.Stream) -> bool: """Return True iff manual ``torch.cuda.CUDAGraph`` capture is safe. - Non-RTX builds always return True (existing behavior). On RTX, - capture is unsafe when the TRT-RTX context cannot be stream-captured - (e.g. due to runtime allocation or data-dependent shapes) or when - the dynamic-shape strategy is ``"lazy"`` -- a later lazy-compiled - specialized kernel would invalidate the captured graph. + On RTX, unsafe when the TRT-RTX context is not stream-capturable, or + when ``"lazy"`` kernel specialization can still fire (dynamic inputs). """ if not ENABLED_FEATURES.tensorrt_rtx: return True + has_dynamic_input = any(DYNAMIC_DIM in shape for shape in self.input_shapes) not_capturable = ( not self.context.is_stream_capturable(stream.cuda_stream), - self.settings.dynamic_shapes_kernel_specialization_strategy == "lazy", + ( + self.settings.dynamic_shapes_kernel_specialization_strategy == "lazy" + and has_dynamic_input + ), ) return not any(not_capturable) From b2e37c44c9e40ccddfdba5a8cfd9084ba76529e5 Mon Sep 17 00:00:00 2001 From: tejaswinp Date: Mon, 1 Jun 2026 22:45:50 -0700 Subject: [PATCH 6/6] fix(runtime): log effective_cudagraphs (not raw mode) in forward dispatcher When the RTX-native CUDA-graph override is active, the manual torch.cuda.CUDAGraph capture path is bypassed inside _execute_standard, so the debug log should reflect that rather than the raw global cudagraphs mode. Locally compute the same effective_cudagraphs expression used inside _execute_standard and log it. Addresses pytorch/TensorRT#4294 review comment from @cehongwang. --- py/torch_tensorrt/dynamo/runtime/_TRTEngine.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/py/torch_tensorrt/dynamo/runtime/_TRTEngine.py b/py/torch_tensorrt/dynamo/runtime/_TRTEngine.py index 5ef2ecb287..e8496f2d3d 100644 --- a/py/torch_tensorrt/dynamo/runtime/_TRTEngine.py +++ b/py/torch_tensorrt/dynamo/runtime/_TRTEngine.py @@ -1111,8 +1111,12 @@ def execute( logger.debug("Using the dynamic allocator runtime mode.") return self._execute_output_allocator(contiguous_inputs) + effective_cudagraphs = ( + torch_tensorrt.runtime.get_cudagraphs_mode() + and not self._rtx_native_cudagraphs + ) logger.debug( - f"Using the standard execution runtime mode with cudagraphs={torch_tensorrt.runtime.get_cudagraphs_mode()}" + f"Using the standard execution runtime mode with cudagraphs={effective_cudagraphs}" + (" (RTX native)" if self._rtx_native_cudagraphs else "") ) return self._execute_standard(contiguous_inputs)