From 90a21e43ed6f0021d04be2994c8dd561e7558592 Mon Sep 17 00:00:00 2001 From: tejaswinp Date: Wed, 3 Jun 2026 10:14:30 -0700 Subject: [PATCH 1/7] refactor(runtime): move TRT-RTX runtime controls to runtime context managers Replaces the v2 design that packed three runtime-mode controls (``cuda_graph_strategy``, ``dynamic_shapes_kernel_specialization_strategy``, ``runtime_cache``) into ``CompilationSettings`` and the serialized engine tuple. Per pytorch/TensorRT#4310, these are runtime mode controls -- not engine properties -- and shouldn't pin at compile time or round-trip through serialization. Highlights: * New ``RuntimeSettings`` dataclass on both Python and C++ sides (``py/torch_tensorrt/runtime/_runtime_settings.py``, ``core/runtime/RuntimeSettings.h``). Three fields: ``dynamic_shapes_kernel_specialization_strategy``, ``cuda_graph_strategy``, ``runtime_cache``. The cache field accepts ``None``, a path string (engine creates an implicit handle, saves on ``__del__``, mirrors old ``runtime_cache_path=`` behavior), or a ``RuntimeCacheHandle`` (shared cache, lifecycle owned by the ``runtime_cache()`` CM). * New ``RuntimeCacheHandle`` registered as a torchbind class (``torch.classes.tensorrt.RuntimeCacheHandle``) so the same C++ ``IRuntimeCache`` shared_ptr crosses the Python/C++ boundary. * New per-engine ``update_runtime_settings`` API on both ``TRTEngine`` flavors. Fast-paths on settings equality; eagerly rebuilds ``IRuntimeConfig`` + recreates execution context on diff. * Three new context managers in ``torch_tensorrt.runtime``: ``runtime_config(target_or_targets, **kw)`` (the pool API; also yields the target so ``with runtime_config(model, ...) as m:`` works), ``runtime_cache(target, path)`` (shared cache CM), and the per-knob sugars ``set_cuda_graph_strategy`` / ``set_dynamic_shapes_kernel_strategy``. All three accept a list of modules for multi-target use; the cache CM yields the ``RuntimeCacheHandle`` for inspection or explicit ``save()``. * New ``runtime_settings=`` kwarg on ``compile()``, ``cross_compile_for_windows()``, and ``convert_module()`` so callers can prime the engine with the right values up front. Compile-time hint avoids the enter/exit recreate cost. * ``CompilationSettings`` loses the three fields; the compiler entry points drop the three kwargs. ``SerializedInfoIndex`` drops the four RTX-related slots; ``SERIALIZATION_LEN`` returns to 12. Engines saved with the old 16-slot layout will raise the existing layout-mismatch error on load. * Three existing test files migrated to the new API; new ``tests/py/dynamo/runtime/test_004_runtime_settings.py`` covers the data model, compile-time hint, runtime CM restore semantics, multi-target form, and dispatch. Co-Authored-By: Claude Opus 4.7 (1M context) --- core/runtime/BUILD | 4 + core/runtime/RuntimeSettings.cpp | 46 +++ core/runtime/RuntimeSettings.h | 72 ++++ core/runtime/TRTEngine.cpp | 80 ++-- core/runtime/TRTEngine.h | 34 +- core/runtime/TRTRuntimeConfig.cpp | 247 ++++------- core/runtime/TRTRuntimeConfig.h | 104 ++--- core/runtime/execute_engine.cpp | 2 +- core/runtime/register_jit_hooks.cpp | 28 +- core/runtime/runtime.h | 9 - py/torch_tensorrt/dynamo/_compiler.py | 55 +-- py/torch_tensorrt/dynamo/_defaults.py | 5 - py/torch_tensorrt/dynamo/_settings.py | 11 - .../dynamo/conversion/_conversion.py | 9 +- .../runtime/_CudaGraphsTorchTensorRTModule.py | 11 +- .../dynamo/runtime/_TRTEngine.py | 214 +++++++--- .../dynamo/runtime/_TorchTensorRTModule.py | 118 +++--- .../runtime/_serialized_engine_layout.py | 15 - py/torch_tensorrt/runtime/__init__.py | 7 + .../runtime/_cuda_graph_strategy.py | 26 ++ .../_dynamic_shapes_kernel_strategy.py | 28 ++ py/torch_tensorrt/runtime/_runtime_cache.py | 251 ++++++++++++ py/torch_tensorrt/runtime/_runtime_config.py | 91 ++++ .../runtime/_runtime_settings.py | 98 +++++ .../dynamo/runtime/test_000_runtime_cache.py | 387 +++++------------- .../runtime/test_001_cuda_graph_strategy.py | 346 +++------------- ...test_001_dynamic_shapes_kernel_strategy.py | 148 +++---- .../runtime/test_004_runtime_settings.py | 176 ++++++++ 28 files changed, 1458 insertions(+), 1164 deletions(-) create mode 100644 core/runtime/RuntimeSettings.cpp create mode 100644 core/runtime/RuntimeSettings.h create mode 100644 py/torch_tensorrt/runtime/_cuda_graph_strategy.py create mode 100644 py/torch_tensorrt/runtime/_dynamic_shapes_kernel_strategy.py create mode 100644 py/torch_tensorrt/runtime/_runtime_cache.py create mode 100644 py/torch_tensorrt/runtime/_runtime_config.py create mode 100644 py/torch_tensorrt/runtime/_runtime_settings.py create mode 100644 tests/py/dynamo/runtime/test_004_runtime_settings.py diff --git a/core/runtime/BUILD b/core/runtime/BUILD index bb9f779929..5ada005328 100644 --- a/core/runtime/BUILD +++ b/core/runtime/BUILD @@ -86,6 +86,7 @@ cc_library( "DeviceList.cpp", "Platform.cpp", "RTDevice.cpp", + "RuntimeSettings.cpp", "TRTEngine.cpp", "TRTEngineProfiler.cpp", "TRTRuntimeConfig.cpp", @@ -96,6 +97,7 @@ cc_library( hdrs = [ "Platform.h", "RTDevice.h", + "RuntimeSettings.h", "TRTEngine.h", "TRTEngineProfiler.h", "TRTRuntimeConfig.h", @@ -158,6 +160,7 @@ cc_library( hdrs = [ "Platform.h", "RTDevice.h", + "RuntimeSettings.h", "TRTEngine.h", "TRTEngineProfiler.h", "TensorRTBindingNames.h", @@ -174,6 +177,7 @@ filegroup( srcs = [ "Platform.h", "RTDevice.h", + "RuntimeSettings.h", "TRTEngine.h", "TRTEngineProfiler.h", "TRTRuntimeConfig.h", diff --git a/core/runtime/RuntimeSettings.cpp b/core/runtime/RuntimeSettings.cpp new file mode 100644 index 0000000000..437d74f12d --- /dev/null +++ b/core/runtime/RuntimeSettings.cpp @@ -0,0 +1,46 @@ +#include "core/runtime/RuntimeSettings.h" + +#include + +namespace torch_tensorrt { +namespace core { +namespace runtime { + +bool RuntimeSettings::operator==(RuntimeSettings const& other) const noexcept { + // Same handle pointer counts as identical cache; passing the same handle twice + // through update_runtime_settings is a no-op. + return dynamic_shapes_kernel_specialization_strategy == other.dynamic_shapes_kernel_specialization_strategy && + cuda_graph_strategy == other.cuda_graph_strategy && runtime_cache.get() == other.runtime_cache.get(); +} + +RuntimeSettings RuntimeSettings::merge(RuntimeSettings const& override) const { + RuntimeSettings result = *this; + result.dynamic_shapes_kernel_specialization_strategy = override.dynamic_shapes_kernel_specialization_strategy; + result.cuda_graph_strategy = override.cuda_graph_strategy; + if (override.runtime_cache) { + result.runtime_cache = override.runtime_cache; + } + return result; +} + +std::string RuntimeSettings::to_str() const { + std::ostringstream os; + os << "Dynamic Shapes Kernel Strategy: " << dynamic_shapes_kernel_specialization_strategy << std::endl; + os << "CUDA Graph Strategy: " << cuda_graph_strategy << std::endl; + if (runtime_cache) { + auto p = runtime_cache->path(); + os << "Runtime Cache: " << (p.empty() ? "" : p) << std::endl; + } else { + os << "Runtime Cache: " << std::endl; + } + return os.str(); +} + +std::ostream& operator<<(std::ostream& os, RuntimeSettings const& rs) { + os << rs.to_str(); + return os; +} + +} // namespace runtime +} // namespace core +} // namespace torch_tensorrt diff --git a/core/runtime/RuntimeSettings.h b/core/runtime/RuntimeSettings.h new file mode 100644 index 0000000000..544e4de63f --- /dev/null +++ b/core/runtime/RuntimeSettings.h @@ -0,0 +1,72 @@ +#pragma once + +#include +#include +#include + +#include "ATen/core/ivalue.h" +#include "NvInfer.h" +#include "torch/custom_class.h" + +namespace torch_tensorrt { +namespace core { +namespace runtime { + +// A passive wrapper around an `IRuntimeCache`. Registered as a torchbind class so +// it can be passed by `c10::intrusive_ptr` across the Python/C++ boundary; the +// same handle gives both runtimes the same underlying `IRuntimeCache*`. +// +// File I/O lives exclusively on the Python side (filelock + serialize/deserialize +// via `trt.IRuntimeCache`). The C++ class is purely a holder; `path` is +// informational and is not consulted by the C++ runtime. +class RuntimeCacheHandle : public torch::CustomClassHolder { + public: + explicit RuntimeCacheHandle(std::string path = "") : path_(std::move(path)) {} + + [[nodiscard]] std::string path() const { + return path_; + } + void set_path(std::string p) { + path_ = std::move(p); + } + +#ifdef TRT_MAJOR_RTX + // The actual TensorRT runtime cache. The first engine that attaches this handle + // materializes it via `IRuntimeConfig::createRuntimeCache()` and writes the + // shared_ptr here; subsequent engines reuse the same pointer for true sharing. + std::shared_ptr cache; +#endif + + private: + std::string path_; +}; + +// Per-engine runtime-only knobs sampled at IExecutionContext creation. +// +// `RuntimeSettings` is a plain struct (not a torchbind class) because we flatten +// it into positional args at the torchbind boundary -- TorchBind can't carry a +// dataclass natively. Equality is value-by-value; the cache field compares +// by pointer identity (same handle -> same cache). +struct RuntimeSettings { + std::string dynamic_shapes_kernel_specialization_strategy = "lazy"; + std::string cuda_graph_strategy = "disabled"; + c10::intrusive_ptr runtime_cache = nullptr; + + bool operator==(RuntimeSettings const& other) const noexcept; + bool operator!=(RuntimeSettings const& other) const noexcept { + return !(*this == other); + } + + // Apply `override`'s non-default fields on top of *this*, returning a new value. + // For non-default detection on the strategy strings we always overlay; the cache + // pointer is overlaid iff `override.runtime_cache` is non-null. + RuntimeSettings merge(RuntimeSettings const& override) const; + + [[nodiscard]] std::string to_str() const; +}; + +std::ostream& operator<<(std::ostream& os, RuntimeSettings const& rs); + +} // namespace runtime +} // namespace core +} // namespace torch_tensorrt diff --git a/core/runtime/TRTEngine.cpp b/core/runtime/TRTEngine.cpp index 873c758854..1110a53b8c 100644 --- a/core/runtime/TRTEngine.cpp +++ b/core/runtime/TRTEngine.cpp @@ -97,7 +97,7 @@ TRTEngine::TRTEngine( bool requires_output_allocator, const std::string& serialized_metadata, const ResourceAllocationStrategy resource_allocation_strategy, - TRTRuntimeConfig runtime_cfg) + RuntimeSettings runtime_settings) : TRTEngine( "deserialized_trt", serialized_engine, @@ -109,7 +109,7 @@ TRTEngine::TRTEngine( requires_output_allocator, serialized_metadata, resource_allocation_strategy, - std::move(runtime_cfg)) {} + std::move(runtime_settings)) {} TRTEngine::TRTEngine(std::vector serialized_info) : TRTEngine( @@ -125,7 +125,7 @@ TRTEngine::TRTEngine(std::vector serialized_info) (static_cast(std::stoi(serialized_info[RESOURCE_ALLOCATION_STRATEGY_IDX])) ? ResourceAllocationStrategy::kDynamic : ResourceAllocationStrategy::kStatic), - make_runtime_config_from_serialized(serialized_info)) { + RuntimeSettings{}) { // Single visible marker that this engine was instantiated through the C++ runtime // entry point (i.e. torch.classes.tensorrt.Engine), distinguishing it from the Python // TRTEngine path. Tests look for this string in captured stderr to verify the @@ -148,8 +148,8 @@ TRTEngine::TRTEngine( bool requires_output_allocator, const std::string& serialized_metadata, const ResourceAllocationStrategy resource_allocation_strategy, - TRTRuntimeConfig runtime_cfg) { - this->runtime_cfg = std::move(runtime_cfg); + RuntimeSettings runtime_settings) { + this->runtime_settings_ = std::move(runtime_settings); TORCHTRT_CHECK( is_supported_on_current_platform(target_platform), "This engine was not built to run on this platform (built for: " << target_platform << ", current platform: " @@ -273,7 +273,7 @@ TRTEngine::TRTEngine( num_io = std::make_pair(inputs_size, outputs); } - runtime_cfg.has_dynamic_inputs = engine_has_dynamic_inputs(cuda_engine.get(), in_binding_names); + has_dynamic_inputs = engine_has_dynamic_inputs(cuda_engine.get(), in_binding_names); #ifndef NDEBUG this->enable_profiling(); @@ -294,9 +294,9 @@ TRTEngine::TRTEngine( } TRTEngine::~TRTEngine() { - // Marked noexcept so safe to invoke from a destructor without - // explicit try/catch; any I/O error is logged internally. - runtime_cfg.save_runtime_cache(); + // Disk persistence for runtime caches is owned by the Python side + // (`RuntimeCacheHandle.save()` invoked from the runtime_cache CM or the engine + // wrapper). The C++ side just lets refcounts drop. trt_engine_profiler.reset(); exec_ctx.reset(); cuda_engine.reset(); @@ -464,7 +464,7 @@ std::string TRTEngine::to_str() const { ss << " Target Platform: " << target_platform << std::endl; ss << " Resource Allocation Strategy: " << (resource_allocation_strategy == ResourceAllocationStrategy::kDynamic ? "Dynamic" : "Static") << std::endl; ss << " Multi-Device Engine: " << (requires_native_multidevice) << std::endl; - ss << runtime_cfg.to_str(); + ss << runtime_settings_.to_str(); // clang-format on return ss.str(); } @@ -511,11 +511,7 @@ FlattenedState TRTEngine::__obj_flatten__() { std::tuple("requires_output_allocator", serialized_info[REQUIRES_OUTPUT_ALLOCATOR_IDX]), std::tuple("target_platform", serialized_info[TARGET_PLATFORM_IDX]), std::tuple("resource_allocation_strategy", serialized_info[RESOURCE_ALLOCATION_STRATEGY_IDX]), - std::tuple("requires_native_multidevice", serialized_info[REQUIRES_NATIVE_MULTIDEVICE_IDX]), - std::tuple("has_runtime_cfg", serialized_info[HAS_RUNTIME_CFG_IDX]), - std::tuple("runtime_cache_path", serialized_info[RUNTIME_CACHE_PATH_IDX]), - std::tuple("dynamic_shapes_kernel_strategy", serialized_info[DYNAMIC_SHAPES_KERNEL_STRATEGY_IDX]), - std::tuple("cuda_graph_strategy", serialized_info[CUDA_GRAPH_STRATEGY_IDX])); + std::tuple("requires_native_multidevice", serialized_info[REQUIRES_NATIVE_MULTIDEVICE_IDX])); } std::vector TRTEngine::serialize() { @@ -541,17 +537,8 @@ std::vector TRTEngine::serialize() { serialized_info[RESOURCE_ALLOCATION_STRATEGY_IDX] = this->resource_allocation_strategy == ResourceAllocationStrategy::kDynamic ? "1" : "0"; serialized_info[REQUIRES_NATIVE_MULTIDEVICE_IDX] = this->requires_native_multidevice ? "1" : "0"; - // rank/world_size are runtime facts (may differ at load time); not serialized. -#ifdef TRT_MAJOR_RTX - serialized_info[HAS_RUNTIME_CFG_IDX] = "1"; -#else - serialized_info[HAS_RUNTIME_CFG_IDX] = "0"; -#endif - serialized_info[RUNTIME_CACHE_PATH_IDX] = runtime_cfg.runtime_cache_path; - serialized_info[DYNAMIC_SHAPES_KERNEL_STRATEGY_IDX] = std::to_string( - static_cast>(runtime_cfg.dynamic_shapes_kernel_strategy)); - serialized_info[CUDA_GRAPH_STRATEGY_IDX] = - std::to_string(static_cast>(runtime_cfg.cuda_graph_strategy)); + // RuntimeSettings are intentionally NOT serialized: they're per-engine, in-memory + // initialization values, not part of the engine's identity. See pytorch/TensorRT#4310. return serialized_info; } @@ -671,30 +658,43 @@ void TRTEngine::release_nccl_comm() { #endif // ENABLE_TRT_NCCL_COLLECTIVES bool TRTEngine::is_monolithic_capturable(cudaStream_t stream) const { - return runtime_cfg.is_monolithic_capturable(exec_ctx.get(), stream); + return TRTRuntimeConfig::is_monolithic_capturable(runtime_settings_, has_dynamic_inputs, exec_ctx.get(), stream); } void TRTEngine::disable_rtx_native_cudagraphs() { - bool was_disabled = runtime_cfg.rtx_native_cudagraphs_disabled; - runtime_cfg.disable_rtx_native_cudagraphs(name); - if (!was_disabled && runtime_cfg.rtx_native_cudagraphs_disabled) { - // The CUDA graph strategy on the IRuntimeConfig has been flipped; rebuild exec_ctx - // so the new strategy takes effect for subsequent enqueueV3 calls. - recreate_execution_context(); +#ifdef TRT_MAJOR_RTX + if (runtime_settings_.cuda_graph_strategy == "disabled") { + return; + } + LOG_WARNING( + "Outer CUDA stream capture detected; disabling TensorRT-RTX native CUDA graph strategy on engine " + << name << " for the remainder of its lifetime."); + RuntimeSettings new_settings = runtime_settings_; + new_settings.cuda_graph_strategy = "disabled"; + update_runtime_settings(std::move(new_settings)); +#endif +} + +void TRTEngine::update_runtime_settings(RuntimeSettings new_settings) { + if (new_settings == runtime_settings_) { + return; } + runtime_settings_ = std::move(new_settings); + // Force the next ensure_initialized to rebuild the IRuntimeConfig with the new + // strategy values + (possibly) the new attached cache handle. + runtime_cfg.reset(); + recreate_execution_context(); + // Existing recreate sites set runtime_states.context_changed for cudagraph + // re-record; do the same here so a settings flip inside an active CM forces + // the next enqueue to re-record any captured graph. + runtime_states.context_changed = true; } void TRTEngine::recreate_execution_context() { - // Flush any kernels the previous execution context may have compiled into the - // runtime cache before creating the replacement. The destructor also saves, but - // doing it here guards against losing compiled kernels across profiling toggles, - // allocator changes, or process kills that happen between allocator changes and - // teardown. No-op on standard TensorRT or when no cache path is configured. - runtime_cfg.save_runtime_cache(); const auto allocation_strategy = resource_allocation_strategy == ResourceAllocationStrategy::kDynamic ? nvinfer1::ExecutionContextAllocationStrategy::kUSER_MANAGED : nvinfer1::ExecutionContextAllocationStrategy::kSTATIC; - exec_ctx = runtime_cfg.create_execution_context(cuda_engine.get(), allocation_strategy); + exec_ctx = runtime_cfg.create_execution_context(cuda_engine.get(), runtime_settings_, allocation_strategy); TORCHTRT_CHECK(exec_ctx.get() != nullptr, "Unable to (re)create TensorRT execution context"); } diff --git a/core/runtime/TRTEngine.h b/core/runtime/TRTEngine.h index 47917e9c37..d3307c902f 100644 --- a/core/runtime/TRTEngine.h +++ b/core/runtime/TRTEngine.h @@ -14,6 +14,7 @@ #include "c10/cuda/CUDAStream.h" #include "torch/custom_class.h" +#include "core/runtime/RuntimeSettings.h" #include "core/runtime/TRTEngineProfiler.h" #include "core/runtime/TRTRuntimeConfig.h" #include "core/runtime/TensorRTBindingNames.h" @@ -48,11 +49,7 @@ using FlattenedState = std::tuple< std::tuple, // serialized metadata std::tuple, // Platform std::tuple, // Resource Allocation Strategy - std::tuple, // requires_native_multidevice - std::tuple, // has_runtime_cfg (gates next three) - std::tuple, // Runtime Cache Path (TRT-RTX) - std::tuple, // Dynamic Shapes Kernel Strategy (TRT-RTX) - std::tuple // CUDA Graph Strategy (TRT-RTX) + std::tuple // requires_native_multidevice >; struct TorchTRTRuntimeStates { @@ -158,7 +155,7 @@ struct TRTEngine : torch::CustomClassHolder { const std::string& serialized_metadata = "", const TRTEngine::ResourceAllocationStrategy resource_allocation_strategy = TRTEngine::ResourceAllocationStrategy::kStatic, - TRTRuntimeConfig runtime_cfg = TRTRuntimeConfig{}); + RuntimeSettings runtime_settings = RuntimeSettings{}); TRTEngine(std::vector serialized_info); @@ -174,7 +171,7 @@ struct TRTEngine : torch::CustomClassHolder { const std::string& serialized_metadata = "", const TRTEngine::ResourceAllocationStrategy resource_allocation_strategy = TRTEngine::ResourceAllocationStrategy::kStatic, - TRTRuntimeConfig runtime_cfg = TRTRuntimeConfig{}); + RuntimeSettings runtime_settings = RuntimeSettings{}); std::string to_str() const; static void verify_serialization_fmt(const std::vector& serialized_info); @@ -282,17 +279,32 @@ struct TRTEngine : torch::CustomClassHolder { void set_resource_allocation_strategy(ResourceAllocationStrategy new_strategy); ResourceAllocationStrategy get_resource_allocation_strategy(); - // Owns the IRuntimeConfig (where supported) and TRT-RTX runtime state. On older TRT - // without IRuntimeConfig (e.g. Jetpack) this just carries strategy values that get - // passed to the legacy createExecutionContext overload. + // Live IRuntimeConfig wrapper. Settings are sourced from `runtime_settings_` at + // every (re)build. On non-RTX or pre-10.11 TRT this is essentially empty. TRTRuntimeConfig runtime_cfg; + // Current user-facing runtime settings. Initialized from the constructor's + // `runtime_settings` param; mutated by `update_runtime_settings`. + RuntimeSettings runtime_settings_; + [[nodiscard]] RuntimeSettings const& runtime_settings() const noexcept { + return runtime_settings_; + } + + // Apply new runtime settings. Fast-paths on equality. On change, rebuilds the + // IRuntimeConfig from the new settings and recreates the execution context. + void update_runtime_settings(RuntimeSettings new_settings); + + // Whether the engine has any input binding with a dynamic dimension. Computed + // once during construction; used by `is_monolithic_capturable`. + bool has_dynamic_inputs = true; + // Monolithic-capturability check used when this engine is wrapped by an outer whole-graph // capture (e.g. CudaGraphsTorchTensorRTModule). Non-RTX builds always return true. bool is_monolithic_capturable(cudaStream_t stream) const; // Disable TensorRT-RTX native CUDA graph capture on this engine (one-shot, invoked when - // an outer stream capture is detected around execute_engine). No-op on non-RTX. + // an outer stream capture is detected around execute_engine). No-op on non-RTX or when + // already disabled. void disable_rtx_native_cudagraphs(); private: diff --git a/core/runtime/TRTRuntimeConfig.cpp b/core/runtime/TRTRuntimeConfig.cpp index 6f64a95cbd..9c402bc9c2 100644 --- a/core/runtime/TRTRuntimeConfig.cpp +++ b/core/runtime/TRTRuntimeConfig.cpp @@ -1,159 +1,98 @@ #include "core/runtime/TRTRuntimeConfig.h" -#include -#include #include #include -#include -#include "core/runtime/runtime.h" +#include "core/runtime/RuntimeSettings.h" #include "core/util/prelude.h" namespace torch_tensorrt { namespace core { namespace runtime { -// File-local helpers. Kept out of the header because they are only used by this -// translation unit -- TRTEngine now consumes a TRTRuntimeConfig directly and does not -// need the enum conversion helpers. namespace { -[[nodiscard]] std::string to_string(DynamicShapesKernelStrategy s) { - switch (s) { - case DynamicShapesKernelStrategy::kLazy: - return "lazy"; - case DynamicShapesKernelStrategy::kEager: - return "eager"; - case DynamicShapesKernelStrategy::kNone: - return "none"; - } - TORCHTRT_CHECK( - false, - "Unexpected DynamicShapesKernelStrategy value: " - << static_cast>(s)); -} - -[[nodiscard]] std::string to_string(CudaGraphStrategyOption s) { - switch (s) { - case CudaGraphStrategyOption::kDisabled: - return "disabled"; - case CudaGraphStrategyOption::kWholeGraphCapture: - return "whole_graph_capture"; - } - TORCHTRT_CHECK( - false, - "Unexpected CudaGraphStrategyOption value: " << static_cast>(s)); -} - -[[nodiscard]] DynamicShapesKernelStrategy to_dynamic_shapes_kernel_strategy( - std::underlying_type_t v) { - TORCHTRT_CHECK( - v >= 0 && v <= 2, - "Invalid dynamic shapes kernel strategy value: " << v << ". Expected 0 (lazy), 1 (eager), or 2 (none)."); - return static_cast(v); -} - -[[nodiscard]] CudaGraphStrategyOption to_cuda_graph_strategy_option(std::underlying_type_t v) { - TORCHTRT_CHECK( - v >= 0 && v <= 1, - "Invalid CUDA graph strategy value: " << v << ". Expected 0 (disabled) or 1 (whole_graph_capture)."); - return static_cast(v); -} - #ifdef TRT_MAJOR_RTX -// Raw cache I/O helpers. Exception-propagating; the caller wraps in try/catch at the -// TRTRuntimeConfig member level. Kept file-local because the IRuntimeCache type is -// itself TensorRT-RTX-only and tests reach this path through the member wrappers. -void load_runtime_cache(const std::string& path, nvinfer1::IRuntimeCache* cache) { - TORCHTRT_CHECK(cache != nullptr, "load_runtime_cache requires a non-null IRuntimeCache"); - if (!std::filesystem::exists(path)) { - LOG_DEBUG("No existing runtime cache at " << path); - return; +[[nodiscard]] nvinfer1::DynamicShapesKernelSpecializationStrategy to_trt_ds_strategy(std::string const& s) { + if (s == "lazy") { + return nvinfer1::DynamicShapesKernelSpecializationStrategy::kLAZY; } - std::ifstream f(path, std::ios::binary); - std::vector buf((std::istreambuf_iterator(f)), std::istreambuf_iterator()); - if (buf.empty()) { - return; + if (s == "eager") { + return nvinfer1::DynamicShapesKernelSpecializationStrategy::kEAGER; } - TORCHTRT_CHECK(cache->deserialize(buf.data(), buf.size()), "IRuntimeCache::deserialize returned false for " << path); - LOG_INFO("Loaded runtime cache from " << path << " (" << buf.size() << " bytes)"); + if (s == "none") { + return nvinfer1::DynamicShapesKernelSpecializationStrategy::kNONE; + } + TORCHTRT_CHECK( + false, "Invalid dynamic_shapes_kernel_specialization_strategy: \"" << s << "\" (expected lazy | eager | none)"); } -void save_runtime_cache_impl(const std::string& path, nvinfer1::IRuntimeCache* cache) { - TORCHTRT_CHECK(cache != nullptr, "save_runtime_cache requires a non-null IRuntimeCache"); - auto host_mem = make_trt(cache->serialize()); - if (!host_mem || host_mem->size() == 0) { - return; +[[nodiscard]] nvinfer1::CudaGraphStrategy to_trt_cg_strategy(std::string const& s) { + if (s == "disabled") { + return nvinfer1::CudaGraphStrategy::kDISABLED; } - std::filesystem::path fs_path(path); - if (fs_path.has_parent_path()) { - std::filesystem::create_directories(fs_path.parent_path()); + if (s == "whole_graph_capture") { + return nvinfer1::CudaGraphStrategy::kWHOLE_GRAPH_CAPTURE; } - std::filesystem::path tmp_path = fs_path; - tmp_path += ".tmp"; - { - std::ofstream out(tmp_path, std::ios::binary); - out.write(reinterpret_cast(host_mem->data()), host_mem->size()); - } - std::filesystem::rename(tmp_path, fs_path); - LOG_INFO("Saved runtime cache to " << path << " (" << host_mem->size() << " bytes)"); + TORCHTRT_CHECK(false, "Invalid cuda_graph_strategy: \"" << s << "\" (expected disabled | whole_graph_capture)"); } -#endif // TRT_MAJOR_RTX +#endif } // namespace -void TRTRuntimeConfig::ensure_initialized(TORCHTRT_UNUSED nvinfer1::ICudaEngine* cuda_engine) { +void TRTRuntimeConfig::ensure_initialized( + TORCHTRT_UNUSED nvinfer1::ICudaEngine* cuda_engine, + TORCHTRT_UNUSED RuntimeSettings const& rs) { #ifdef TRT_HAS_IRUNTIME_CONFIG - if (config) { - return; + if (!config) { + TORCHTRT_CHECK(cuda_engine != nullptr, "Cannot initialize TRTRuntimeConfig without a live ICudaEngine"); + config = make_trt(cuda_engine->createRuntimeConfig()); + TORCHTRT_CHECK(config.get() != nullptr, "Unable to create TensorRT IRuntimeConfig"); } - TORCHTRT_CHECK(cuda_engine != nullptr, "Cannot initialize TRTRuntimeConfig without a live ICudaEngine"); - config = make_trt(cuda_engine->createRuntimeConfig()); - TORCHTRT_CHECK(config.get() != nullptr, "Unable to create TensorRT IRuntimeConfig"); #ifdef TRT_MAJOR_RTX - // Runtime cache -- TRT-RTX only. - if (!runtime_cache_path.empty()) { - runtime_cache = make_trt(config->createRuntimeCache()); - if (runtime_cache.get() == nullptr) { - LOG_WARNING("Failed to create TensorRT IRuntimeCache; runtime cache will be skipped."); + // Runtime cache: ONLY attach when the caller provided an external + // RuntimeCacheHandle. The Python TRTEngine side creates an implicit + // handle from a path string and passes it in via the handle; without + // an explicit user opt-in we leave the IRuntimeConfig cache-less. + if (rs.runtime_cache) { + if (!rs.runtime_cache->cache) { + rs.runtime_cache->cache = make_trt(config->createRuntimeCache()); + TORCHTRT_CHECK( + rs.runtime_cache->cache.get() != nullptr, "Failed to create IRuntimeCache for shared RuntimeCacheHandle"); + } + if (config->setRuntimeCache(*rs.runtime_cache->cache)) { + LOG_DEBUG("Attached external IRuntimeCache to IRuntimeConfig."); } else { - try { - load_runtime_cache(runtime_cache_path, runtime_cache.get()); - } catch (const std::exception& e) { - LOG_WARNING("Failed to load runtime cache from " << runtime_cache_path << ": " << e.what()); - } - if (config->setRuntimeCache(*runtime_cache)) { - LOG_DEBUG("TensorRT-RTX runtime cache configured at " << runtime_cache_path); - } else { - LOG_WARNING("Failed to attach runtime cache to IRuntimeConfig; cache will be unused."); - runtime_cache.reset(); - } + LOG_WARNING("Failed to attach IRuntimeCache to IRuntimeConfig; cache will be unused."); } } else { - LOG_DEBUG("Runtime cache disabled (no path configured)."); + LOG_DEBUG("Runtime cache disabled (no RuntimeCacheHandle provided)."); } - // Dynamic shapes kernel specialization strategy -- TRT-RTX only. config->setDynamicShapesKernelSpecializationStrategy( - static_cast(dynamic_shapes_kernel_strategy)); - LOG_DEBUG("Dynamic shapes kernel specialization strategy set to " << to_string(dynamic_shapes_kernel_strategy)); + to_trt_ds_strategy(rs.dynamic_shapes_kernel_specialization_strategy)); + LOG_DEBUG( + "Dynamic shapes kernel specialization strategy set to " << rs.dynamic_shapes_kernel_specialization_strategy); - // CUDA graph strategy -- TRT-RTX only. - if (!config->setCudaGraphStrategy( - cuda_graph_strategy == CudaGraphStrategyOption::kWholeGraphCapture - ? nvinfer1::CudaGraphStrategy::kWHOLE_GRAPH_CAPTURE - : nvinfer1::CudaGraphStrategy::kDISABLED)) { + if (!config->setCudaGraphStrategy(to_trt_cg_strategy(rs.cuda_graph_strategy))) { LOG_WARNING("Failed to set CUDA graph strategy; continuing with default."); } #endif #endif // TRT_HAS_IRUNTIME_CONFIG } +void TRTRuntimeConfig::reset() { +#ifdef TRT_HAS_IRUNTIME_CONFIG + config.reset(); +#endif +} + std::shared_ptr TRTRuntimeConfig::create_execution_context( nvinfer1::ICudaEngine* cuda_engine, + RuntimeSettings const& rs, nvinfer1::ExecutionContextAllocationStrategy allocation_strategy) { - ensure_initialized(cuda_engine); + ensure_initialized(cuda_engine, rs); #ifdef TRT_HAS_IRUNTIME_CONFIG config->setExecutionContextAllocationStrategy(allocation_strategy); return make_trt(cuda_engine->createExecutionContext(config.get())); @@ -163,92 +102,44 @@ std::shared_ptr TRTRuntimeConfig::create_execution_ #endif } -bool TRTRuntimeConfig::uses_internal_capture(TORCHTRT_UNUSED bool cudagraphs_enabled) const { +bool TRTRuntimeConfig::uses_internal_capture( + TORCHTRT_UNUSED RuntimeSettings const& rs, + TORCHTRT_UNUSED bool cudagraphs_enabled) noexcept { #ifdef TRT_MAJOR_RTX // On TRT-RTX the internal runtime handles capture/replay whenever a non-disabled - // strategy is set, or when subgraph cudagraphs are enabled globally. In both cases the - // caller should skip its manual at::cuda::CUDAGraph wrapper because TRT-RTX's internal - // capture would collide with it. - return cuda_graph_strategy != CudaGraphStrategyOption::kDisabled || cudagraphs_enabled; + // strategy is set, or when subgraph cudagraphs are enabled globally. In both + // cases the caller should skip its manual at::cuda::CUDAGraph wrapper. + return rs.cuda_graph_strategy != "disabled" || cudagraphs_enabled; #else return false; #endif } -void TRTRuntimeConfig::disable_rtx_native_cudagraphs(TORCHTRT_UNUSED const std::string& engine_name) noexcept { -#ifdef TRT_MAJOR_RTX - if (rtx_native_cudagraphs_disabled || cuda_graph_strategy == CudaGraphStrategyOption::kDisabled) { - return; - } - LOG_WARNING( - "Outer CUDA stream capture detected; disabling TensorRT-RTX native CUDA graph strategy on engine " - << engine_name << " for the remainder of its lifetime."); - // Persist any kernels the engine-internal capture has compiled so far; the outer - // capture will run without them otherwise, and we want future reloads to reuse them. - save_runtime_cache(); - cuda_graph_strategy = CudaGraphStrategyOption::kDisabled; - if (config && !config->setCudaGraphStrategy(nvinfer1::CudaGraphStrategy::kDISABLED)) { - LOG_WARNING("Failed to update CUDA graph strategy on IRuntimeConfig after disable."); - } - rtx_native_cudagraphs_disabled = true; -#endif -} - bool TRTRuntimeConfig::is_monolithic_capturable( + TORCHTRT_UNUSED RuntimeSettings const& rs, + TORCHTRT_UNUSED bool has_dynamic_inputs, TORCHTRT_UNUSED nvinfer1::IExecutionContext* exec_ctx, - TORCHTRT_UNUSED cudaStream_t stream) const { + TORCHTRT_UNUSED cudaStream_t stream) noexcept { #ifdef TRT_MAJOR_RTX TORCHTRT_ASSERT(exec_ctx != nullptr, "is_monolithic_capturable requires a live IExecutionContext"); if (!exec_ctx->isStreamCapturable(stream)) { return false; } - // "lazy" kernel specialization only swaps specialized kernels mid-run when an input - // has a dynamic dimension; for static-shape engines the kernels are fixed at setup and - // the captured graph stays valid. Mirrors the Python `_is_monolithic_capturable` check. - return !(dynamic_shapes_kernel_strategy == DynamicShapesKernelStrategy::kLazy && has_dynamic_inputs); + // "lazy" kernel specialization only swaps specialized kernels mid-run when an + // input has a dynamic dimension; for static-shape engines the kernels are fixed + // at setup and the captured graph stays valid. Mirrors the Python check. + return !(rs.dynamic_shapes_kernel_specialization_strategy == "lazy" && has_dynamic_inputs); #else return true; #endif } -void TRTRuntimeConfig::save_runtime_cache() noexcept { -#ifdef TRT_MAJOR_RTX - if (!runtime_cache || runtime_cache_path.empty()) { - return; - } - try { - save_runtime_cache_impl(runtime_cache_path, runtime_cache.get()); - } catch (const std::exception& e) { - LOG_WARNING("Failed to save runtime cache to " << runtime_cache_path << ": " << e.what()); - } catch (...) { - LOG_WARNING("Failed to save runtime cache (unknown exception)."); - } -#endif -} - -std::string TRTRuntimeConfig::to_str() const { - std::ostringstream os; - os << "Runtime Cache Path: " << (runtime_cache_path.empty() ? "" : runtime_cache_path) << std::endl; - os << "Dynamic Shapes Kernel Strategy: " << to_string(dynamic_shapes_kernel_strategy) << std::endl; - os << "CUDA Graph Strategy: " << to_string(cuda_graph_strategy) << std::endl; - return os.str(); -} - -TRTRuntimeConfig make_runtime_config_from_serialized(const std::vector& info) { - TRTRuntimeConfig cfg; - if (info[HAS_RUNTIME_CFG_IDX] == "1") { - cfg.runtime_cache_path = info[RUNTIME_CACHE_PATH_IDX]; - cfg.dynamic_shapes_kernel_strategy = - to_dynamic_shapes_kernel_strategy(std::stoi(info[DYNAMIC_SHAPES_KERNEL_STRATEGY_IDX])); - cfg.cuda_graph_strategy = to_cuda_graph_strategy_option(std::stoi(info[CUDA_GRAPH_STRATEGY_IDX])); - } - return cfg; -} - std::ostream& operator<<(std::ostream& os, const TRTRuntimeConfig& cfg) { - os << "Runtime cfg {" << std::endl; - os << cfg.to_str(); - os << "}" << std::endl; + os << "TRTRuntimeConfig{"; +#ifdef TRT_HAS_IRUNTIME_CONFIG + os << "config=" << (cfg.config ? "live" : "null"); +#endif + os << "}"; return os; } diff --git a/core/runtime/TRTRuntimeConfig.h b/core/runtime/TRTRuntimeConfig.h index 6e7b8bc6ab..94e9b3b5ac 100644 --- a/core/runtime/TRTRuntimeConfig.h +++ b/core/runtime/TRTRuntimeConfig.h @@ -1,12 +1,9 @@ #pragma once #include -#include #include #include #include -#include -#include #include "NvInfer.h" @@ -14,86 +11,53 @@ namespace torch_tensorrt { namespace core { namespace runtime { -// TensorRT-RTX-only configuration for how shape-specialized kernels are compiled. -enum class DynamicShapesKernelStrategy : int32_t { - kLazy = 0, - kEager = 1, - kNone = 2, -}; - -// TensorRT-RTX-only configuration for how CUDA graph capture/replay is handled. -enum class CudaGraphStrategyOption : int32_t { - kDisabled = 0, - kWholeGraphCapture = 1, -}; +struct RuntimeSettings; -// Encapsulates the IRuntimeConfig and TRT-RTX runtime state for a TRTEngine. -// IRuntimeConfig and runtime-cache `#ifdef`s are confined to this TU; serialization- -// index plumbing keeps its own RTX gates elsewhere. +// Owns the live `IRuntimeConfig` (where supported) and the engine-local fallback +// `IRuntimeCache` used when no external `RuntimeCacheHandle` is attached via +// `RuntimeSettings`. The settings themselves (strategy strings, runtime_cache +// handle) live on `RuntimeSettings`; this struct applies them to TRT at +// `ensure_initialized` time. +// +// `IRuntimeConfig` and runtime-cache `#ifdef`s are confined to this TU. struct TRTRuntimeConfig { - // Settings - typically populated from engine deserialization before `ensure_initialized`. - std::string runtime_cache_path = ""; - DynamicShapesKernelStrategy dynamic_shapes_kernel_strategy = DynamicShapesKernelStrategy::kLazy; - CudaGraphStrategyOption cuda_graph_strategy = CudaGraphStrategyOption::kDisabled; - - // One-shot: set to true once an outer stream capture has been detected and the - // engine-internal CUDA graph strategy has been disabled for the remainder of the - // owning engine's lifetime. - bool rtx_native_cudagraphs_disabled = false; - - bool has_dynamic_inputs = true; - - // Live resources. The IRuntimeConfig is lazy-constructed on first `ensure_initialized` - // and is unavailable on TensorRT versions older than 10.11 (e.g. Jetpack). + // Lazy-constructed live config. `nullptr` until first `ensure_initialized`. #ifdef TRT_HAS_IRUNTIME_CONFIG std::shared_ptr config; #endif -#ifdef TRT_MAJOR_RTX - std::shared_ptr runtime_cache; -#endif - // Lazily construct the IRuntimeConfig and apply RTX-specific settings. Idempotent. - // No-op on builds without IRuntimeConfig (e.g. Jetpack). - void ensure_initialized(nvinfer1::ICudaEngine* cuda_engine); + // (Re)build the `IRuntimeConfig` from `rs`. Idempotent only if the previous + // `rs` was identical. Callers ensure the engine is the same across calls -- + // we don't memoize against `cuda_engine` here. + void ensure_initialized(nvinfer1::ICudaEngine* cuda_engine, RuntimeSettings const& rs); - // Lazy-initialize the IRuntimeConfig if needed and create an IExecutionContext that - // honors `allocation_strategy`. Selects the right `createExecutionContext` overload - // (IRuntimeConfig* vs ExecutionContextAllocationStrategy) so callers stay free of - // any TRT_HAS_IRUNTIME_CONFIG branching. + // Force the next `ensure_initialized` to rebuild from scratch. Used when + // settings change at runtime. + void reset(); + + // Lazy-init + create a fresh `IExecutionContext` honoring `allocation_strategy`. + // Picks the right `createExecutionContext` overload (IRuntimeConfig* vs + // ExecutionContextAllocationStrategy) so callers stay free of any + // `TRT_HAS_IRUNTIME_CONFIG` branching. [[nodiscard]] std::shared_ptr create_execution_context( nvinfer1::ICudaEngine* cuda_engine, + RuntimeSettings const& rs, nvinfer1::ExecutionContextAllocationStrategy allocation_strategy); - // Returns true if the TensorRT-RTX runtime owns capture/replay for this engine so the - // caller should bypass its own at::cuda::CUDAGraph capture around enqueueV3. Always - // false on non-RTX builds. - [[nodiscard]] bool uses_internal_capture(bool cudagraphs_enabled) const; - - // One-shot: disable engine-internal CUDA graph capture. Invoked when an outer stream - // capture is detected around execute_engine, so the outer capture can contain the - // kernel launches directly. Saves the runtime cache before recreating the context so - // compiled kernels from the present run are preserved for future reloads. - void disable_rtx_native_cudagraphs(const std::string& engine_name) noexcept; - - // Whether the execution context is safe to include in an outer monolithic capture. - // Non-RTX builds always return true. - [[nodiscard]] bool is_monolithic_capturable(nvinfer1::IExecutionContext* exec_ctx, cudaStream_t stream) const; - - // Save the runtime cache to disk. Signature is `noexcept` so this is safe from a - // destructor. The underlying file I/O is performed by free functions declared below - // (non-noexcept, exception-leaky for easier testing); this member wraps them and - // swallows any exceptions. - void save_runtime_cache() noexcept; - - // Returns a human-readable summary of the runtime config. - [[nodiscard]] std::string to_str() const; + // Returns true if TRT-RTX owns capture/replay for the given settings -- caller + // should then bypass its own `at::cuda::CUDAGraph` capture around enqueueV3. + // Always false on non-RTX builds. + [[nodiscard]] static bool uses_internal_capture(RuntimeSettings const& rs, bool cudagraphs_enabled) noexcept; + + // Returns true iff the execution context can be safely included in an outer + // monolithic capture. Non-RTX builds always return true. + [[nodiscard]] static bool is_monolithic_capturable( + RuntimeSettings const& rs, + bool has_dynamic_inputs, + nvinfer1::IExecutionContext* exec_ctx, + cudaStream_t stream) noexcept; }; -// Construct a TRTRuntimeConfig from a flattened serialization vector. Reads the -// RTX-only indices only on RTX builds; standard TRT builds return a default-initialized -// struct. -[[nodiscard]] TRTRuntimeConfig make_runtime_config_from_serialized(const std::vector& info); - std::ostream& operator<<(std::ostream& os, const TRTRuntimeConfig& cfg); } // namespace runtime diff --git a/core/runtime/execute_engine.cpp b/core/runtime/execute_engine.cpp index 80936951ef..a773b1afa3 100644 --- a/core/runtime/execute_engine.cpp +++ b/core/runtime/execute_engine.cpp @@ -249,7 +249,7 @@ std::vector execute_engine(std::vector inputs, c10::intr // CudaGraphsTorchTensorRTModule for whole-graph capture), engine-internal capture would // collide, so we disable it one-shot here. bool effective_cudagraphs = cudagraphs_enabled; - if (compiled_engine->runtime_cfg.uses_internal_capture(cudagraphs_enabled)) { + if (TRTRuntimeConfig::uses_internal_capture(compiled_engine->runtime_settings(), cudagraphs_enabled)) { effective_cudagraphs = false; cudaStreamCaptureStatus capture_status; cudaStreamIsCapturing(compiled_engine->engine_stream.stream(), &capture_status); diff --git a/core/runtime/register_jit_hooks.cpp b/core/runtime/register_jit_hooks.cpp index 44d1b314ca..e4294ab754 100644 --- a/core/runtime/register_jit_hooks.cpp +++ b/core/runtime/register_jit_hooks.cpp @@ -1,6 +1,7 @@ #include #include "core/runtime/Platform.h" +#include "core/runtime/RuntimeSettings.h" #include "core/runtime/runtime.h" #include "core/util/macros.h" @@ -13,6 +14,17 @@ namespace core { namespace runtime { namespace { + +// Register `RuntimeCacheHandle` as a torchbind class so Python can pass the same +// underlying `IRuntimeCache` to both Python and C++ engine backends. File I/O on +// the handle is the Python side's responsibility; the C++ class only holds the +// shared_ptr and an informational path string. +static auto TORCHTRT_UNUSED RuntimeCacheHandleRegistration = + torch::class_("tensorrt", "RuntimeCacheHandle") + .def(torch::init()) + .def("path", &RuntimeCacheHandle::path) + .def("set_path", &RuntimeCacheHandle::set_path); + // TODO: Implement a call method // c10::List TRTEngine::Run(c10::List inputs) { // auto input_vec = inputs.vec(); @@ -47,6 +59,18 @@ static auto TORCHTRT_UNUSED TRTEngineTSRegistrtion = dynamic ? TRTEngine::ResourceAllocationStrategy::kDynamic : TRTEngine::ResourceAllocationStrategy::kStatic); }) + .def( + "update_runtime_settings", + [](const c10::intrusive_ptr& self, + std::string const& dynamic_shapes_kernel_specialization_strategy, + std::string const& cuda_graph_strategy, + c10::intrusive_ptr runtime_cache) -> void { + RuntimeSettings rs; + rs.dynamic_shapes_kernel_specialization_strategy = dynamic_shapes_kernel_specialization_strategy; + rs.cuda_graph_strategy = cuda_graph_strategy; + rs.runtime_cache = std::move(runtime_cache); + self->update_runtime_settings(std::move(rs)); + }) .def_readwrite("use_pre_allocated_outputs", &TRTEngine::use_pre_allocated_outputs) .def_readwrite("pre_allocated_outputs", &TRTEngine::pre_allocated_outputs) .def_readwrite("use_output_allocator_outputs", &TRTEngine::use_output_allocator_outputs) @@ -147,10 +171,6 @@ TORCH_LIBRARY(tensorrt, m) { return false; #endif }); - m.def("HAS_RUNTIME_CFG_IDX", []() -> int64_t { return HAS_RUNTIME_CFG_IDX; }); - m.def("RUNTIME_CACHE_PATH_IDX", []() -> int64_t { return RUNTIME_CACHE_PATH_IDX; }); - m.def("DYNAMIC_SHAPES_KERNEL_STRATEGY_IDX", []() -> int64_t { return DYNAMIC_SHAPES_KERNEL_STRATEGY_IDX; }); - m.def("CUDA_GRAPH_STRATEGY_IDX", []() -> int64_t { return CUDA_GRAPH_STRATEGY_IDX; }); m.def("_platform_linux_x86_64", []() -> std::string { auto it = get_platform_name_map().find(Platform::PlatformEnum::kLINUX_X86_64); return it->second; diff --git a/core/runtime/runtime.h b/core/runtime/runtime.h index 2cbe73d6da..a87bd2ca2a 100644 --- a/core/runtime/runtime.h +++ b/core/runtime/runtime.h @@ -41,11 +41,6 @@ typedef enum { REQUIRES_OUTPUT_ALLOCATOR_IDX, RESOURCE_ALLOCATION_STRATEGY_IDX, REQUIRES_NATIVE_MULTIDEVICE_IDX, - // HAS_RUNTIME_CFG_IDX gates the next three slots. When "0", their values are ignored. - HAS_RUNTIME_CFG_IDX, - RUNTIME_CACHE_PATH_IDX, - DYNAMIC_SHAPES_KERNEL_STRATEGY_IDX, - CUDA_GRAPH_STRATEGY_IDX, SERIALIZATION_LEN, // NEVER USED FOR DATA, USED TO DETERMINE LENGTH OF SERIALIZED INFO } SerializedInfoIndex; @@ -62,10 +57,6 @@ inline constexpr std::array kSerializedInfoIndex "REQUIRES_OUTPUT_ALLOCATOR_IDX", "RESOURCE_ALLOCATION_STRATEGY_IDX", "REQUIRES_NATIVE_MULTIDEVICE_IDX", - "HAS_RUNTIME_CFG_IDX", - "RUNTIME_CACHE_PATH_IDX", - "DYNAMIC_SHAPES_KERNEL_STRATEGY_IDX", - "CUDA_GRAPH_STRATEGY_IDX", }}; // For adding new serialized info indices, update above and update /dynamo/runtime/_serialized_engine_layout.py diff --git a/py/torch_tensorrt/dynamo/_compiler.py b/py/torch_tensorrt/dynamo/_compiler.py index 8671dd5860..17e0154c68 100644 --- a/py/torch_tensorrt/dynamo/_compiler.py +++ b/py/torch_tensorrt/dynamo/_compiler.py @@ -5,9 +5,12 @@ import os import platform import warnings -from typing import Any, Collection, List, Optional, Sequence, Union +from typing import TYPE_CHECKING, Any, Collection, List, Optional, Sequence, Union import torch + +if TYPE_CHECKING: + from torch_tensorrt.runtime._runtime_settings import RuntimeSettings from torch.export import ExportedProgram from torch.fx.node import Target from torch_tensorrt._Device import Device @@ -89,9 +92,6 @@ def cross_compile_for_windows( dryrun: bool = _defaults.DRYRUN, hardware_compatible: bool = _defaults.HARDWARE_COMPATIBLE, timing_cache_path: str = _defaults.TIMING_CACHE_PATH, - runtime_cache_path: str = _defaults.RUNTIME_CACHE_PATH, - dynamic_shapes_kernel_specialization_strategy: str = _defaults.DYNAMIC_SHAPES_KERNEL_SPECIALIZATION_STRATEGY, - cuda_graph_strategy: str = _defaults.CUDA_GRAPH_STRATEGY, lazy_engine_init: bool = _defaults.LAZY_ENGINE_INIT, cache_built_engines: bool = _defaults.CACHE_BUILT_ENGINES, reuse_cached_engines: bool = _defaults.REUSE_CACHED_ENGINES, @@ -112,6 +112,7 @@ def cross_compile_for_windows( dynamically_allocate_resources: bool = _defaults.DYNAMICALLY_ALLOCATE_RESOURCES, decompose_attention: bool = _defaults.DECOMPOSE_ATTENTION, attn_bias_is_causal: bool = _defaults.ATTN_BIAS_IS_CAUSAL, + runtime_settings: Optional["RuntimeSettings"] = None, **kwargs: Any, ) -> torch.fx.GraphModule: """Compile an ExportedProgram module using TensorRT in Linux for Inference in Windows @@ -170,9 +171,6 @@ def cross_compile_for_windows( dryrun (bool): Toggle for "Dryrun" mode, running everything except conversion to TRT and logging outputs hardware_compatible (bool): Build the TensorRT engines compatible with GPU architectures other than that of the GPU on which the engine was built (currently works for NVIDIA Ampere and newer) timing_cache_path (str): Path to the timing cache if it exists (or) where it will be saved after compilation. Not used for TensorRT-RTX. - runtime_cache_path (str): Path to the runtime cache for TensorRT-RTX JIT compilation results. Not used for standard TensorRT. - dynamic_shapes_kernel_specialization_strategy (str): Strategy for dynamic shape kernel specialization at runtime (TensorRT-RTX only). Options: "lazy", "eager", "none". Default: "lazy". - cuda_graph_strategy (str): Strategy for CUDA graph capture/replay (TensorRT-RTX only). Options: "disabled" (manual capture), "whole_graph_capture" (TRT-RTX handles internally). Default: "disabled". lazy_engine_init (bool): Defer setting up engines until the compilation of all engines is complete. Can allow larger models with multiple graph breaks to compile but can lead to oversubscription of GPU memory at runtime. cache_built_engines (bool): Whether to save the compiled TRT engines to storage reuse_cached_engines (bool): Whether to load the compiled TRT engines from storage @@ -318,9 +316,6 @@ def cross_compile_for_windows( "dryrun": dryrun, "hardware_compatible": hardware_compatible, "timing_cache_path": timing_cache_path, - "runtime_cache_path": runtime_cache_path, - "dynamic_shapes_kernel_specialization_strategy": dynamic_shapes_kernel_specialization_strategy, - "cuda_graph_strategy": cuda_graph_strategy, "lazy_engine_init": lazy_engine_init, "cache_built_engines": cache_built_engines, "reuse_cached_engines": reuse_cached_engines, @@ -354,12 +349,6 @@ def cross_compile_for_windows( f"arg: {key} is not supported for cross compilation for windows feature, hence it is disabled." ) - if "runtime_cache_path" in compilation_options: - compilation_options.pop("runtime_cache_path") - logger.warning( - "runtime_cache_path is a JIT-time API and is not applicable to cross compilation for windows. Ignoring." - ) - settings = CompilationSettings(**compilation_options) logger.info("Compilation Settings: %s\n", settings) exported_program = pre_export_lowering(exported_program, settings) @@ -398,6 +387,7 @@ def cross_compile_for_windows( trt_arg_inputs, trt_kwarg_inputs, settings, + runtime_settings=runtime_settings, ) return trt_gm @@ -433,9 +423,6 @@ def compile( dryrun: bool = _defaults.DRYRUN, hardware_compatible: bool = _defaults.HARDWARE_COMPATIBLE, timing_cache_path: str = _defaults.TIMING_CACHE_PATH, - runtime_cache_path: str = _defaults.RUNTIME_CACHE_PATH, - dynamic_shapes_kernel_specialization_strategy: str = _defaults.DYNAMIC_SHAPES_KERNEL_SPECIALIZATION_STRATEGY, - cuda_graph_strategy: str = _defaults.CUDA_GRAPH_STRATEGY, lazy_engine_init: bool = _defaults.LAZY_ENGINE_INIT, cache_built_engines: bool = _defaults.CACHE_BUILT_ENGINES, reuse_cached_engines: bool = _defaults.REUSE_CACHED_ENGINES, @@ -469,6 +456,7 @@ def compile( dynamically_allocate_resources: bool = _defaults.DYNAMICALLY_ALLOCATE_RESOURCES, decompose_attention: bool = _defaults.DECOMPOSE_ATTENTION, attn_bias_is_causal: bool = _defaults.ATTN_BIAS_IS_CAUSAL, + runtime_settings: Optional["RuntimeSettings"] = None, **kwargs: Any, ) -> torch.fx.GraphModule: """Compile an ExportedProgram module for NVIDIA GPUs using TensorRT @@ -529,9 +517,6 @@ def compile( dryrun (bool): Toggle for "Dryrun" mode, running everything except conversion to TRT and logging outputs hardware_compatible (bool): Build the TensorRT engines compatible with GPU architectures other than that of the GPU on which the engine was built (currently works for NVIDIA Ampere and newer) timing_cache_path (str): Path to the timing cache if it exists (or) where it will be saved after compilation. Not used for TensorRT-RTX. - runtime_cache_path (str): Path to the runtime cache for TensorRT-RTX JIT compilation results. Not used for standard TensorRT. - dynamic_shapes_kernel_specialization_strategy (str): Strategy for dynamic shape kernel specialization at runtime (TensorRT-RTX only). Options: "lazy", "eager", "none". Default: "lazy". - cuda_graph_strategy (str): Strategy for CUDA graph capture/replay (TensorRT-RTX only). Options: "disabled" (manual capture), "whole_graph_capture" (TRT-RTX handles internally). Default: "disabled". lazy_engine_init (bool): Defer setting up engines until the compilation of all engines is complete. Can allow larger models with multiple graph breaks to compile but can lead to oversubscription of GPU memory at runtime. cache_built_engines (bool): Whether to save the compiled TRT engines to storage reuse_cached_engines (bool): Whether to load the compiled TRT engines from storage @@ -709,9 +694,6 @@ def compile( "dryrun": dryrun, "hardware_compatible": hardware_compatible, "timing_cache_path": timing_cache_path, - "runtime_cache_path": runtime_cache_path, - "dynamic_shapes_kernel_specialization_strategy": dynamic_shapes_kernel_specialization_strategy, - "cuda_graph_strategy": cuda_graph_strategy, "lazy_engine_init": lazy_engine_init, "cache_built_engines": cache_built_engines, "reuse_cached_engines": reuse_cached_engines, @@ -775,7 +757,12 @@ def compile( "Remaining GPU memory may not be enough to compile the TensorRT engine for this model resulting in an OOM error, Consider setting offload_module_to_cpu=True" ) trt_gm = compile_module( - gm, trt_arg_inputs, trt_kwarg_inputs, settings, engine_cache + gm, + trt_arg_inputs, + trt_kwarg_inputs, + settings, + engine_cache, + runtime_settings=runtime_settings, ) return trt_gm @@ -889,6 +876,7 @@ def compile_module( engine_cache: Optional[BaseEngineCache] = None, *, _debugger_config: Optional[DebuggerConfig] = None, + runtime_settings: Optional["RuntimeSettings"] = None, ) -> torch.fx.GraphModule: """Compile a traced FX module @@ -1126,6 +1114,7 @@ def preserve_module_specs( settings=settings, name=name, engine_cache=engine_cache, + runtime_settings=runtime_settings, ) trt_modules[name] = trt_module @@ -1230,9 +1219,6 @@ def convert_exported_program_to_serialized_trt_engine( dryrun: bool = _defaults.DRYRUN, hardware_compatible: bool = _defaults.HARDWARE_COMPATIBLE, timing_cache_path: str = _defaults.TIMING_CACHE_PATH, - runtime_cache_path: str = _defaults.RUNTIME_CACHE_PATH, - dynamic_shapes_kernel_specialization_strategy: str = _defaults.DYNAMIC_SHAPES_KERNEL_SPECIALIZATION_STRATEGY, - cuda_graph_strategy: str = _defaults.CUDA_GRAPH_STRATEGY, lazy_engine_init: bool = _defaults.LAZY_ENGINE_INIT, cache_built_engines: bool = _defaults.CACHE_BUILT_ENGINES, reuse_cached_engines: bool = _defaults.REUSE_CACHED_ENGINES, @@ -1307,9 +1293,6 @@ def convert_exported_program_to_serialized_trt_engine( dryrun (bool): Toggle for "Dryrun" mode, running everything except conversion to TRT and logging outputs hardware_compatible (bool): Build the TensorRT engines compatible with GPU architectures other than that of the GPU on which the engine was built (currently works for NVIDIA Ampere and newer) timing_cache_path (str): Path to the timing cache if it exists (or) where it will be saved after compilation. Not used for TensorRT-RTX. - runtime_cache_path (str): Path to the runtime cache for TensorRT-RTX JIT compilation results. Not used for standard TensorRT. - dynamic_shapes_kernel_specialization_strategy (str): Strategy for dynamic shape kernel specialization at runtime (TensorRT-RTX only). Options: "lazy", "eager", "none". Default: "lazy". - cuda_graph_strategy (str): Strategy for CUDA graph capture/replay (TensorRT-RTX only). Options: "disabled" (manual capture), "whole_graph_capture" (TRT-RTX handles internally). Default: "disabled". lazy_engine_init (bool): Defer setting up engines until the compilation of all engines is complete. Can allow larger models with multiple graph breaks to compile but can lead to oversubscription of GPU memory at runtime. cache_built_engines (bool): Whether to save the compiled TRT engines to storage reuse_cached_engines (bool): Whether to load the compiled TRT engines from storage @@ -1464,9 +1447,6 @@ def convert_exported_program_to_serialized_trt_engine( "dryrun": dryrun, "hardware_compatible": hardware_compatible, "timing_cache_path": timing_cache_path, - "runtime_cache_path": runtime_cache_path, - "dynamic_shapes_kernel_specialization_strategy": dynamic_shapes_kernel_specialization_strategy, - "cuda_graph_strategy": cuda_graph_strategy, "lazy_engine_init": lazy_engine_init, "cache_built_engines": cache_built_engines, "reuse_cached_engines": reuse_cached_engines, @@ -1484,11 +1464,6 @@ def convert_exported_program_to_serialized_trt_engine( "attn_bias_is_causal": attn_bias_is_causal, "use_python_runtime": use_python_runtime, } - if "runtime_cache_path" in compilation_options: - compilation_options.pop("runtime_cache_path") - logger.warning( - "runtime_cache_path is a JIT-time API and is not applicable to serialized engine export. Ignoring." - ) settings = CompilationSettings(**compilation_options) logger.info("Compilation Settings: %s\n", settings) diff --git a/py/torch_tensorrt/dynamo/_defaults.py b/py/torch_tensorrt/dynamo/_defaults.py index 00bc39bce5..4a8078dd1d 100644 --- a/py/torch_tensorrt/dynamo/_defaults.py +++ b/py/torch_tensorrt/dynamo/_defaults.py @@ -38,9 +38,6 @@ TIMING_CACHE_PATH = os.path.join( tempfile.gettempdir(), "torch_tensorrt_engine_cache", "timing_cache.bin" ) -RUNTIME_CACHE_PATH = os.path.join( - tempfile.gettempdir(), "torch_tensorrt_engine_cache", "runtime_cache.bin" -) LAZY_ENGINE_INIT = False CACHE_BUILT_ENGINES = False REUSE_CACHED_ENGINES = False @@ -69,8 +66,6 @@ DYNAMICALLY_ALLOCATE_RESOURCES = False DECOMPOSE_ATTENTION = False ATTN_BIAS_IS_CAUSAL = True -DYNAMIC_SHAPES_KERNEL_SPECIALIZATION_STRATEGY = "lazy" -CUDA_GRAPH_STRATEGY = "disabled" USE_PYTHON_RUNTIME = False if platform.system() == "Linux": diff --git a/py/torch_tensorrt/dynamo/_settings.py b/py/torch_tensorrt/dynamo/_settings.py index 694b1a7000..9227ad0b81 100644 --- a/py/torch_tensorrt/dynamo/_settings.py +++ b/py/torch_tensorrt/dynamo/_settings.py @@ -17,14 +17,12 @@ AUTOCAST_MAX_OUTPUT_THRESHOLD, CACHE_BUILT_ENGINES, CPU_MEMORY_BUDGET, - CUDA_GRAPH_STRATEGY, DECOMPOSE_ATTENTION, DISABLE_TF32, DLA_GLOBAL_DRAM_SIZE, DLA_LOCAL_DRAM_SIZE, DLA_SRAM_SIZE, DRYRUN, - DYNAMIC_SHAPES_KERNEL_SPECIALIZATION_STRATEGY, DYNAMICALLY_ALLOCATE_RESOURCES, ENABLE_AUTOCAST, ENABLE_CROSS_COMPILE_FOR_WINDOWS, @@ -45,7 +43,6 @@ REFIT_IDENTICAL_ENGINE_WEIGHTS, REQUIRE_FULL_COMPILATION, REUSE_CACHED_ENGINES, - RUNTIME_CACHE_PATH, SPARSE_WEIGHTS, STRIP_ENGINE_WEIGHTS, TILING_OPTIMIZATION_LEVEL, @@ -94,9 +91,6 @@ class CompilationSettings: output to a file if a string path is specified hardware_compatible (bool): Build the TensorRT engines compatible with GPU architectures other than that of the GPU on which the engine was built (currently works for NVIDIA Ampere and newer) timing_cache_path (str): Path to the timing cache if it exists (or) where it will be saved after compilation. Not used for TensorRT-RTX (no autotuning). - runtime_cache_path (str): Path to the runtime cache for TensorRT-RTX JIT compilation results. The cache is loaded on engine setup and saved on module cleanup. Uses file locking for concurrent access safety. Not used for standard TensorRT. - dynamic_shapes_kernel_specialization_strategy (str): Strategy for compiling shape-specialized kernels at runtime for dynamic shapes (TensorRT-RTX only). Options: "lazy" (compile in background, use fallback until ready), "eager" (compile immediately, blocking), "none" (always use fallback kernels). Default: "lazy". - cuda_graph_strategy (str): Strategy for CUDA graph capture/replay (TensorRT-RTX only). Options: "disabled" (no native CUDA graphs, uses manual capture if cudagraphs mode is enabled), "whole_graph_capture" (TRT-RTX handles CUDA graph capture internally). When set to "whole_graph_capture", the manual torch CUDA graph capture/replay in forward() is bypassed. Default: "disabled". cache_built_engines (bool): Whether to save the compiled TRT engines to storage reuse_cached_engines (bool): Whether to load the compiled TRT engines from storage use_fp32_acc (bool): This option inserts cast to FP32 nodes around matmul layers and TensorRT ensures the accumulation of matmul happens in FP32. @@ -148,11 +142,6 @@ class CompilationSettings: dryrun: Union[bool, str] = DRYRUN hardware_compatible: bool = HARDWARE_COMPATIBLE timing_cache_path: str = TIMING_CACHE_PATH - runtime_cache_path: str = RUNTIME_CACHE_PATH - dynamic_shapes_kernel_specialization_strategy: str = ( - DYNAMIC_SHAPES_KERNEL_SPECIALIZATION_STRATEGY - ) - cuda_graph_strategy: str = CUDA_GRAPH_STRATEGY lazy_engine_init: bool = LAZY_ENGINE_INIT cache_built_engines: bool = CACHE_BUILT_ENGINES reuse_cached_engines: bool = REUSE_CACHED_ENGINES diff --git a/py/torch_tensorrt/dynamo/conversion/_conversion.py b/py/torch_tensorrt/dynamo/conversion/_conversion.py index d712d7f150..9ed9b5ef2e 100644 --- a/py/torch_tensorrt/dynamo/conversion/_conversion.py +++ b/py/torch_tensorrt/dynamo/conversion/_conversion.py @@ -2,10 +2,13 @@ import io import logging -from typing import Any, Dict, List, NamedTuple, Optional, Sequence +from typing import TYPE_CHECKING, Any, Dict, List, NamedTuple, Optional, Sequence import tensorrt as trt import torch + +if TYPE_CHECKING: + from torch_tensorrt.runtime._runtime_settings import RuntimeSettings from torch_tensorrt._enums import dtype from torch_tensorrt._features import ENABLED_FEATURES from torch_tensorrt._Input import Input @@ -333,6 +336,7 @@ def convert_module( settings: CompilationSettings = CompilationSettings(), name: str = "", engine_cache: Optional[BaseEngineCache] = None, + runtime_settings: Optional["RuntimeSettings"] = None, ) -> TorchTensorRTModule: """Convert an FX module to a TRT module Args: @@ -341,6 +345,8 @@ def convert_module( settings: Compilation settings name: TRT engine name engine_cache: Engine cache instance + runtime_settings: Optional runtime-mode-control overrides threaded to the + built ``TRTEngine``. Not part of ``CompilationSettings``; not serialized. Returns: TorchTensorRTModule """ @@ -379,4 +385,5 @@ def convert_module( requires_output_allocator=serialized_interpreter_result.requires_output_allocator, requires_native_multidevice=serialized_interpreter_result.requires_native_multidevice, symbolic_shape_expressions=serialized_interpreter_result.symbolic_shape_expressions, + runtime_settings=runtime_settings, ) diff --git a/py/torch_tensorrt/dynamo/runtime/_CudaGraphsTorchTensorRTModule.py b/py/torch_tensorrt/dynamo/runtime/_CudaGraphsTorchTensorRTModule.py index 537444a0b9..7db59aaea6 100644 --- a/py/torch_tensorrt/dynamo/runtime/_CudaGraphsTorchTensorRTModule.py +++ b/py/torch_tensorrt/dynamo/runtime/_CudaGraphsTorchTensorRTModule.py @@ -145,10 +145,7 @@ def _check_monolithic_capturability(self, stream: torch.cuda.Stream) -> None: from torch_tensorrt.dynamo.runtime._TorchTensorRTModule import ( TorchTensorRTModule, ) - from torch_tensorrt.dynamo.runtime._TRTEngine import ( - TRTEngine, - _get_cuda_graph_strategy, - ) + from torch_tensorrt.dynamo.runtime._TRTEngine import TRTEngine for name, mod in self.compiled_module.named_modules(): if not ( @@ -170,10 +167,8 @@ def _check_monolithic_capturability(self, stream: torch.cuda.Stream) -> None: # Disable RTX-native CUDA graphs on this engine so they don't # interfere with the outer monolithic capture. if engine._rtx_native_cudagraphs: - engine.runtime_config.cuda_graph_strategy = _get_cuda_graph_strategy( - "disabled" - ) - engine.context = engine._create_execution_context() + disabled = engine.runtime_settings.merge(cuda_graph_strategy="disabled") + engine.update_runtime_settings(disabled) engine._rtx_native_cudagraphs = False logger.info( f"Disabled RTX-native CUDA graphs for '{name}' " diff --git a/py/torch_tensorrt/dynamo/runtime/_TRTEngine.py b/py/torch_tensorrt/dynamo/runtime/_TRTEngine.py index e8496f2d3d..0396c9cc8b 100644 --- a/py/torch_tensorrt/dynamo/runtime/_TRTEngine.py +++ b/py/torch_tensorrt/dynamo/runtime/_TRTEngine.py @@ -11,16 +11,28 @@ import base64 import copy import logging -import os import pickle import tempfile from contextlib import nullcontext from types import SimpleNamespace -from typing import Any, ContextManager, Dict, List, Optional, Sequence, Tuple, cast +from typing import ( + TYPE_CHECKING, + Any, + ContextManager, + Dict, + List, + Optional, + Sequence, + Tuple, + cast, +) import torch import torch.distributed as dist import torch_tensorrt + +if TYPE_CHECKING: + from torch_tensorrt.runtime._runtime_settings import RuntimeSettings from torch._library.opaque_object import register_opaque_type from torch._opaque_base import OpaqueBase from torch_tensorrt._enums import dtype @@ -76,6 +88,27 @@ def _get_cuda_graph_strategy(strategy_str: str) -> Any: }.get(strategy_str, trt.CudaGraphStrategy.DISABLED) +def _normalize_runtime_cache( + rc: Any, +) -> Any: + """Accept ``None``, a path string, or a ``RuntimeCacheHandle``; return either + ``None`` or a ``RuntimeCacheHandle`` instance. + + String inputs are wrapped in a fresh per-engine implicit handle. The handle + is owned by the engine (saved on engine ``__del__``). + """ + from torch_tensorrt.runtime._runtime_cache import RuntimeCacheHandle + + if rc is None or isinstance(rc, RuntimeCacheHandle): + return rc + if isinstance(rc, str): + return RuntimeCacheHandle(path=rc, autosave=True) + raise TypeError( + f"RuntimeSettings.runtime_cache must be None, a path string, or a " + f"RuntimeCacheHandle; got {type(rc).__name__}" + ) + + # --------------------------------------------------------------------------- # TRT I/O helpers # --------------------------------------------------------------------------- @@ -218,7 +251,11 @@ def __init__( serialized_info: SerializedTensorRTEngineFmt, *, profile_execution: bool = False, + runtime_settings: Optional["RuntimeSettings"] = None, ) -> None: + # Import here to avoid a circular dep at module-import time. + from torch_tensorrt.runtime._runtime_settings import RuntimeSettings + self._profile_execution = profile_execution self.profile_path_prefix = tempfile.gettempdir() self.use_pre_allocated_outputs = False @@ -239,10 +276,18 @@ def __init__( torch_tensorrt.runtime.get_cudagraphs_mode() ) self.resource_allocation_strategy = 0 - # Initialized to ``None`` here so the destructor can safely save the - # cache even if ``_setup_engine`` never runs. + # Initialized to ``None`` here so the destructor can run even if + # ``_setup_engine`` never executed. self.runtime_config: Any = None - self.runtime_cache: Any = None + # Per-engine implicit cache handle, owned by this engine when + # ``runtime_settings.runtime_cache`` is supplied as a string path. + # ``None`` when ``runtime_settings.runtime_cache`` is an external + # handle (caller owns the lifecycle). + self._implicit_cache_handle: Any = None + # Engine-local IRuntimeCache used when no external handle is attached. + # Held as an instance attr so its lifetime matches the runtime_config it + # was set on -- TRT's set_runtime_cache borrows, doesn't own. + self._engine_local_runtime_cache: Any = None # When true, ``_execute_standard`` must skip manual torch.cuda.CUDAGraph # capture because TRT-RTX handles it internally. self._rtx_native_cudagraphs: bool = False @@ -250,6 +295,9 @@ def __init__( # engines compiled with native multi-device collective layers. self._nccl_comm: Optional[Any] = None + # User-facing runtime settings. Mutated by ``update_runtime_settings``. + self.runtime_settings: RuntimeSettings = runtime_settings or RuntimeSettings() + self._load_serialized_info(serialized_info) self._setup_engine() @@ -285,6 +333,8 @@ def __getstate__(self) -> Tuple[List[Any], str]: def __setstate__(self, state: Any) -> None: """Restore from C++-matching pickle state ``(serialized_info,)``.""" + from torch_tensorrt.runtime._runtime_settings import RuntimeSettings + self._profile_execution = False self.profile_path_prefix = tempfile.gettempdir() self.use_pre_allocated_outputs = False @@ -308,11 +358,16 @@ def __setstate__(self, state: Any) -> None: # See ``__init__`` for the rationale: pre-init these so a destructor # firing on a partially-loaded engine never trips an ``AttributeError``. self.runtime_config = None - self.runtime_cache = None + self._implicit_cache_handle = None + self._engine_local_runtime_cache = None self._rtx_native_cudagraphs = False # NCCL communicators cannot be pickled; rebind lazily on the next # forward pass via setup_nccl_comm(). self._nccl_comm = None + # RuntimeSettings are NOT serialized -- restore defaults. Callers + # who want runtime-mode overrides must reapply them post-load via + # ``compiled.set_runtime_settings(...)`` or a runtime CM. + self.runtime_settings = RuntimeSettings() serialized_info = list(state[0]) engine_field = serialized_info[ENGINE_IDX] @@ -401,8 +456,18 @@ def get_serialized_metadata(self) -> str: return self.serialized_metadata def close(self) -> None: - """Persist the runtime cache and release CUDA graph resources.""" - self._save_runtime_cache() + """Persist any implicit runtime cache and release CUDA graph resources. + + Implicit handles (created by the engine from a string path in + ``runtime_settings.runtime_cache``) save here. External handles + from a ``runtime_cache`` CM save on the CM's ``__exit__`` instead. + """ + handle = self._implicit_cache_handle + if handle is not None: + try: + handle.save() + except Exception as e: # never raise from __del__ + logger.warning(f"Failed to save implicit runtime cache: {e}") self.reset_captured_graph() def _create_execution_context(self) -> trt.IExecutionContext: @@ -436,7 +501,7 @@ def _setup_engine(self) -> None: if ENABLED_FEATURES.tensorrt_rtx: self._setup_runtime_config() self._rtx_native_cudagraphs = ( - self.settings.cuda_graph_strategy != "disabled" + self.runtime_settings.cuda_graph_strategy != "disabled" ) self.context = self._create_execution_context() @@ -506,12 +571,12 @@ def _setup_engine(self) -> None: # --- TensorRT-RTX --- def _setup_runtime_config(self) -> None: - """Build an ``IRuntimeConfig`` with runtime cache and dynamic-shape strategy. + """Build an ``IRuntimeConfig`` sourced from ``self.runtime_settings``. - The runtime cache stores JIT compilation results so kernel/graph - compilation is not repeated across inference runs. The dynamic-shape - kernel specialization strategy controls how the JIT compiles - shape-specialized kernels (``lazy``, ``eager``, or ``none``). + The runtime cache field on RuntimeSettings can be ``None`` (per-engine + in-memory cache), a string path (engine creates an implicit handle and + saves on ``__del__``), or a ``RuntimeCacheHandle`` (external; caller + owns lifecycle). """ self.runtime_config = self.cuda_engine.create_runtime_config() alloc_strategy = ( @@ -522,63 +587,83 @@ def _setup_runtime_config(self) -> None: self.runtime_config.set_execution_context_allocation_strategy(alloc_strategy) self.runtime_config.dynamic_shapes_kernel_specialization_strategy = ( _get_dynamic_shapes_kernel_strategy( - self.settings.dynamic_shapes_kernel_specialization_strategy + self.runtime_settings.dynamic_shapes_kernel_specialization_strategy ) ) logger.info( "Dynamic shapes kernel specialization strategy: " - f"{self.settings.dynamic_shapes_kernel_specialization_strategy}" + f"{self.runtime_settings.dynamic_shapes_kernel_specialization_strategy}" ) self.runtime_config.cuda_graph_strategy = _get_cuda_graph_strategy( - self.settings.cuda_graph_strategy + self.runtime_settings.cuda_graph_strategy ) - logger.info(f"CUDA graph strategy: {self.settings.cuda_graph_strategy}") - self.runtime_cache = self.runtime_config.create_runtime_cache() - self._load_runtime_cache() - self.runtime_config.set_runtime_cache(self.runtime_cache) - logger.info("TensorRT-RTX runtime cache configured") - - def _load_runtime_cache(self) -> None: - """Load runtime cache from disk if it exists (with a shared file lock).""" - if self.runtime_cache is None: - return - cache_path = self.settings.runtime_cache_path - if not os.path.isfile(cache_path): - logger.debug(f"No existing runtime cache at {cache_path}") - return - try: - from filelock import FileLock - - lock = FileLock(cache_path + ".lock") - with lock.acquire(timeout=10): - with open(cache_path, "rb") as f: - data = f.read() - if data: - self.runtime_cache.deserialize(data) - logger.info(f"Loaded runtime cache from {cache_path}") - except Exception as e: - logger.warning(f"Failed to load runtime cache: {e}") - - def _save_runtime_cache(self) -> None: - """Save runtime cache to disk (with an exclusive file lock).""" - if self.runtime_cache is None: + logger.info(f"CUDA graph strategy: {self.runtime_settings.cuda_graph_strategy}") + + # Resolve the runtime cache. We only attach a cache to the runtime_config + # when the user explicitly opts in: passing a path string (engine creates + # an implicit handle, saves on ``__del__``) or a ``RuntimeCacheHandle`` + # (external, caller-managed). Default ``None`` leaves the runtime_config + # cache-less, matching pre-refactor behavior. + rc = self.runtime_settings.runtime_cache + if rc is None: + self._implicit_cache_handle = None + self._engine_local_runtime_cache = None + logger.debug( + "Runtime cache disabled (no RuntimeCacheHandle / path provided)." + ) + elif isinstance(rc, str): + # Per-engine disk-backed cache; engine owns the handle and saves + # on ``__del__`` (matches today's ``runtime_cache_path=`` semantics). + # We MUST keep a Python ref to the cache (TRT's ``set_runtime_cache`` + # only borrows) -- the handle holds it. + from torch_tensorrt.runtime._runtime_cache import RuntimeCacheHandle + + cache = self.runtime_config.create_runtime_cache() + self._implicit_cache_handle = RuntimeCacheHandle( + cache=cache, path=rc, autosave=True + ) + self._engine_local_runtime_cache = None + try: + self._implicit_cache_handle.load() + except Exception as e: + logger.warning(f"Failed to load implicit runtime cache: {e}") + self.runtime_config.set_runtime_cache(cache) + else: + # External handle. Lifecycle owned by caller; the handle holds the ref. + cache = rc.ensure_cache(self.runtime_config) + self._implicit_cache_handle = None + self._engine_local_runtime_cache = None + self.runtime_config.set_runtime_cache(cache) + logger.info("TensorRT-RTX runtime config configured") + + def update_runtime_settings(self, new_settings: "RuntimeSettings") -> None: + """Apply new ``RuntimeSettings`` to this engine. + + No-op fast-path when ``new_settings`` is field-equal to the current + settings. Otherwise: persist any prior implicit-cache contents, + rebuild ``runtime_config`` from the new settings, and recreate the + execution context so the new strategy values take effect on the next + enqueue. + """ + if new_settings == self.runtime_settings: return - try: - host_mem = self.runtime_cache.serialize() - if host_mem is None: - return - cache_path = self.settings.runtime_cache_path - os.makedirs(os.path.dirname(cache_path), exist_ok=True) - - from filelock import FileLock - - lock = FileLock(cache_path + ".lock") - with lock.acquire(timeout=10): - with open(cache_path, "wb") as f: - f.write(memoryview(host_mem)) - logger.info(f"Saved runtime cache to {cache_path}") - except Exception as e: - logger.warning(f"Failed to save runtime cache: {e}") + # Persist the prior implicit cache before swapping handles; otherwise + # the str-path lifecycle would silently drop kernels JIT-compiled so + # far when the user moves to a different cache configuration. + prior_handle = self._implicit_cache_handle + if prior_handle is not None: + try: + prior_handle.save() + except Exception as e: + logger.warning(f"Failed to save implicit runtime cache on swap: {e}") + self.runtime_settings = new_settings + if ENABLED_FEATURES.tensorrt_rtx: + self._setup_runtime_config() + self._rtx_native_cudagraphs = ( + self.runtime_settings.cuda_graph_strategy != "disabled" + ) + self.context = self._create_execution_context() + self.runtime_states.context_changed = True def _is_monolithic_capturable(self, stream: torch.cuda.Stream) -> bool: """Return True iff manual ``torch.cuda.CUDAGraph`` capture is safe. @@ -592,7 +677,8 @@ def _is_monolithic_capturable(self, stream: torch.cuda.Stream) -> bool: not_capturable = ( not self.context.is_stream_capturable(stream.cuda_stream), ( - self.settings.dynamic_shapes_kernel_specialization_strategy == "lazy" + self.runtime_settings.dynamic_shapes_kernel_specialization_strategy + == "lazy" and has_dynamic_input ), ) diff --git a/py/torch_tensorrt/dynamo/runtime/_TorchTensorRTModule.py b/py/torch_tensorrt/dynamo/runtime/_TorchTensorRTModule.py index 77253f3deb..5483f16347 100644 --- a/py/torch_tensorrt/dynamo/runtime/_TorchTensorRTModule.py +++ b/py/torch_tensorrt/dynamo/runtime/_TorchTensorRTModule.py @@ -4,7 +4,7 @@ import copy import logging import pickle -from typing import Any, Dict, List, Optional, Tuple +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple import torch from torch_tensorrt._Device import Device @@ -14,11 +14,8 @@ from torch_tensorrt.dynamo.runtime._serialized_engine_layout import ( ABI_TARGET_IDX, ABI_VERSION, - CUDA_GRAPH_STRATEGY_IDX, DEVICE_IDX, - DYNAMIC_SHAPES_KERNEL_STRATEGY_IDX, ENGINE_IDX, - HAS_RUNTIME_CFG_IDX, HW_COMPATIBLE_IDX, INPUT_BINDING_NAMES_IDX, NAME_IDX, @@ -26,7 +23,6 @@ REQUIRES_NATIVE_MULTIDEVICE_IDX, REQUIRES_OUTPUT_ALLOCATOR_IDX, RESOURCE_ALLOCATION_STRATEGY_IDX, - RUNTIME_CACHE_PATH_IDX, SERIALIZATION_LEN, SERIALIZED_METADATA_IDX, TARGET_PLATFORM_IDX, @@ -35,6 +31,9 @@ serialize_device_info, ) +if TYPE_CHECKING: + from torch_tensorrt.runtime._runtime_settings import RuntimeSettings + logger = logging.getLogger(__name__) SerializedTorchTensorRTModuleFmt = Tuple[ @@ -44,16 +43,6 @@ List[str], ] -_DYNAMIC_SHAPES_KERNEL_STRATEGY_MAP: Dict[str, int] = { - "lazy": 0, - "eager": 1, - "none": 2, -} -_CUDA_GRAPH_STRATEGY_MAP: Dict[str, int] = { - "disabled": 0, - "whole_graph_capture": 1, -} - class TorchTensorRTModule(torch.nn.Module): # type: ignore[misc] """``nn.Module`` that runs a TensorRT engine inside PyTorch. @@ -79,6 +68,7 @@ def __init__( requires_output_allocator: bool = False, requires_native_multidevice: bool = False, symbolic_shape_expressions: Optional[Dict[str, List[Dict[str, Any]]]] = None, + runtime_settings: Optional["RuntimeSettings"] = None, ): """Takes a name, target device, serialized TensorRT engine, and binding names / order and constructs a PyTorch ``torch.nn.Module`` around it. Uses the Torch-TensorRT runtime extension to run the engines @@ -146,25 +136,12 @@ def __init__( self.execute_engine_op: Any = None self.requires_output_allocator = requires_output_allocator self.dynamically_allocate_resources = settings.dynamically_allocate_resources - self.runtime_cache_path = settings.runtime_cache_path - self.dynamic_shapes_kernel_specialization_strategy = ( - settings.dynamic_shapes_kernel_specialization_strategy - ) - if ( - self.dynamic_shapes_kernel_specialization_strategy - not in _DYNAMIC_SHAPES_KERNEL_STRATEGY_MAP - ): - raise ValueError( - f"Invalid dynamic_shapes_kernel_specialization_strategy " - f"{self.dynamic_shapes_kernel_specialization_strategy!r}; expected one of " - f"{list(_DYNAMIC_SHAPES_KERNEL_STRATEGY_MAP.keys())}" - ) - self.cuda_graph_strategy = settings.cuda_graph_strategy - if self.cuda_graph_strategy not in _CUDA_GRAPH_STRATEGY_MAP: - raise ValueError( - f"Invalid cuda_graph_strategy {self.cuda_graph_strategy!r}; expected one of " - f"{list(_CUDA_GRAPH_STRATEGY_MAP.keys())}" - ) + + # Per-engine runtime mode controls. Defaults to ``RuntimeSettings()`` if + # not supplied; the dataclass validates at ``__post_init__``. + from torch_tensorrt.runtime._runtime_settings import RuntimeSettings + + self._runtime_settings: RuntimeSettings = runtime_settings or RuntimeSettings() self.symbolic_shape_expressions = symbolic_shape_expressions self.requires_native_multidevice = requires_native_multidevice self.target_platform = ( @@ -261,18 +238,9 @@ def _pack_engine_info(self) -> List[str | bytes]: engine_info[REQUIRES_NATIVE_MULTIDEVICE_IDX] = str( int(self.requires_native_multidevice) ) - # rank/world_size are runtime facts; queried from ProcessGroup at execution time - engine_info[HAS_RUNTIME_CFG_IDX] = "1" if ENABLED_FEATURES.tensorrt_rtx else "0" - engine_info[RUNTIME_CACHE_PATH_IDX] = self.runtime_cache_path or "" - engine_info[DYNAMIC_SHAPES_KERNEL_STRATEGY_IDX] = str( - _DYNAMIC_SHAPES_KERNEL_STRATEGY_MAP[ - self.dynamic_shapes_kernel_specialization_strategy - ] - ) - engine_info[CUDA_GRAPH_STRATEGY_IDX] = str( - _CUDA_GRAPH_STRATEGY_MAP[self.cuda_graph_strategy] - ) - + # rank/world_size are runtime facts; queried from ProcessGroup at execution time. + # RuntimeSettings are intentionally NOT serialized: they're per-engine, in-memory + # init values, not part of the engine's identity (see pytorch/TensorRT#4310). return engine_info def get_streamable_device_memory_budget(self) -> Any: @@ -306,6 +274,52 @@ def use_dynamically_allocated_resources( self.dynamically_allocate_resources ) + # --- runtime-settings dispatch ---------------------------------------- + + @property + def runtime_settings(self) -> "RuntimeSettings": + """The current ``RuntimeSettings`` on this module (and its engine). + + This is the snapshot the ``runtime_config`` CM reads at ``__enter__`` + and restores at ``__exit__``. + """ + return self._runtime_settings + + def set_runtime_settings(self, rs: "RuntimeSettings") -> None: + """Apply ``RuntimeSettings`` to all TRT engines under this module. + + Walks ``named_modules()`` so calling on a wrapper / parent + ``nn.Module`` propagates to every contained + ``TorchTensorRTModule``. Dispatches to the Python ``TRTEngine`` or + the C++ ``torch.classes.tensorrt.Engine`` per submodule's backend. + """ + for _, mod in self.named_modules(): + if isinstance(mod, TorchTensorRTModule) and mod.engine is not None: + mod._dispatch_runtime_settings_to_engine(rs) + mod._runtime_settings = rs + + def _dispatch_runtime_settings_to_engine(self, rs: "RuntimeSettings") -> None: + """Backend-aware dispatch of ``update_runtime_settings(rs)`` to ``self.engine``.""" + if self.engine is None: + return + from torch_tensorrt.dynamo.runtime._TRTEngine import TRTEngine + + if isinstance(self.engine, TRTEngine): + # Python runtime: dataclass passes straight through. + self.engine.update_runtime_settings(rs) + return + + # C++ torchbind engine: flatten the dataclass into positional args. The + # cache field is converted to a torchbind RuntimeCacheHandle (or None). + from torch_tensorrt.runtime._runtime_cache import _to_torchbind_handle + + cache_arg = _to_torchbind_handle(rs.runtime_cache) + self.engine.update_runtime_settings( + rs.dynamic_shapes_kernel_specialization_strategy, + rs.cuda_graph_strategy, + cache_arg, + ) + def setup_engine(self) -> None: """ Setup engine for a module which has deferred engine setup. @@ -324,11 +338,14 @@ def setup_engine(self) -> None: self.engine = TRTEngine( self._pack_engine_info(), profile_execution=self.profiling_enabled, + runtime_settings=self._runtime_settings, ) self.execute_engine_op = torch.ops.tensorrt.execute_engine_python else: self.engine = torch.classes.tensorrt.Engine(self._pack_engine_info()) self.execute_engine_op = torch.ops.tensorrt.execute_engine + # Apply runtime settings to the C++ engine (no-op if defaults). + self._dispatch_runtime_settings_to_engine(self._runtime_settings) # requires_native_multidevice is set by the C++ constructor from the serialized REQUIRES_NATIVE_MULTIDEVICE_IDX field. if self.engine.requires_native_multidevice: @@ -432,10 +449,17 @@ def set_extra_state(self, state: SerializedTorchTensorRTModuleFmt) -> None: getattr(self.settings, "use_python_runtime", False) or not ENABLED_FEATURES.torch_tensorrt_runtime ) + # RuntimeSettings are NOT serialized; restore defaults. Caller can + # reapply via ``compiled.set_runtime_settings(...)`` or a CM after load. + from torch_tensorrt.runtime._runtime_settings import RuntimeSettings + + self._runtime_settings = RuntimeSettings() if self._use_python_runtime: from torch_tensorrt.dynamo.runtime._TRTEngine import TRTEngine - self.engine = TRTEngine(serialized_engine_info) + self.engine = TRTEngine( + serialized_engine_info, runtime_settings=self._runtime_settings + ) self.execute_engine_op = torch.ops.tensorrt.execute_engine_python else: self.engine = torch.classes.tensorrt.Engine(serialized_engine_info) diff --git a/py/torch_tensorrt/dynamo/runtime/_serialized_engine_layout.py b/py/torch_tensorrt/dynamo/runtime/_serialized_engine_layout.py index c0bc6653b9..d4f31ba8a8 100644 --- a/py/torch_tensorrt/dynamo/runtime/_serialized_engine_layout.py +++ b/py/torch_tensorrt/dynamo/runtime/_serialized_engine_layout.py @@ -37,11 +37,6 @@ class SerializedInfoIndex(IntEnum): REQUIRES_OUTPUT_ALLOCATOR_IDX = 9 RESOURCE_ALLOCATION_STRATEGY_IDX = 10 REQUIRES_NATIVE_MULTIDEVICE_IDX = 11 - # HAS_RUNTIME_CFG_IDX gates the next three slots. When "0", their values are ignored. - HAS_RUNTIME_CFG_IDX = 12 - RUNTIME_CACHE_PATH_IDX = 13 - DYNAMIC_SHAPES_KERNEL_STRATEGY_IDX = 14 - CUDA_GRAPH_STRATEGY_IDX = 15 # Module-level aliases for backward compatibility and concise access @@ -57,12 +52,6 @@ class SerializedInfoIndex(IntEnum): REQUIRES_OUTPUT_ALLOCATOR_IDX = SerializedInfoIndex.REQUIRES_OUTPUT_ALLOCATOR_IDX RESOURCE_ALLOCATION_STRATEGY_IDX = SerializedInfoIndex.RESOURCE_ALLOCATION_STRATEGY_IDX REQUIRES_NATIVE_MULTIDEVICE_IDX = SerializedInfoIndex.REQUIRES_NATIVE_MULTIDEVICE_IDX -HAS_RUNTIME_CFG_IDX = SerializedInfoIndex.HAS_RUNTIME_CFG_IDX -RUNTIME_CACHE_PATH_IDX = SerializedInfoIndex.RUNTIME_CACHE_PATH_IDX -DYNAMIC_SHAPES_KERNEL_STRATEGY_IDX = ( - SerializedInfoIndex.DYNAMIC_SHAPES_KERNEL_STRATEGY_IDX -) -CUDA_GRAPH_STRATEGY_IDX = SerializedInfoIndex.CUDA_GRAPH_STRATEGY_IDX SERIALIZATION_LEN = len(SerializedInfoIndex) SERIALIZED_ENGINE_BINDING_DELIM = "%" @@ -84,10 +73,6 @@ class SerializedInfoIndex(IntEnum): ("REQUIRES_OUTPUT_ALLOCATOR_IDX", "REQUIRES_OUTPUT_ALLOCATOR_IDX", int), ("RESOURCE_ALLOCATION_STRATEGY_IDX", "RESOURCE_ALLOCATION_STRATEGY_IDX", int), ("REQUIRES_NATIVE_MULTIDEVICE_IDX", "REQUIRES_NATIVE_MULTIDEVICE_IDX", int), - ("HAS_RUNTIME_CFG_IDX", "HAS_RUNTIME_CFG_IDX", int), - ("RUNTIME_CACHE_PATH_IDX", "RUNTIME_CACHE_PATH_IDX", int), - ("DYNAMIC_SHAPES_KERNEL_STRATEGY_IDX", "DYNAMIC_SHAPES_KERNEL_STRATEGY_IDX", int), - ("CUDA_GRAPH_STRATEGY_IDX", "CUDA_GRAPH_STRATEGY_IDX", int), ("SERIALIZATION_LEN", "SERIALIZATION_LEN", int), ("SERIALIZED_ENGINE_BINDING_DELIM", "SERIALIZED_ENGINE_BINDING_DELIM", str), ("SERIALIZED_RT_DEVICE_DELIM", "SERIALIZED_RT_DEVICE_DELIM", str), diff --git a/py/torch_tensorrt/runtime/__init__.py b/py/torch_tensorrt/runtime/__init__.py index 7283ca0f33..478cca548d 100644 --- a/py/torch_tensorrt/runtime/__init__.py +++ b/py/torch_tensorrt/runtime/__init__.py @@ -1,13 +1,20 @@ from torch_tensorrt.dynamo.runtime import ( # noqa: F401 TorchTensorRTModule, ) +from torch_tensorrt.runtime._cuda_graph_strategy import set_cuda_graph_strategy from torch_tensorrt.runtime._cudagraphs import ( enable_cudagraphs, get_cudagraphs_mode, get_whole_cudagraphs_mode, set_cudagraphs_mode, ) +from torch_tensorrt.runtime._dynamic_shapes_kernel_strategy import ( + set_dynamic_shapes_kernel_strategy, +) from torch_tensorrt.runtime._multi_device_safe_mode import set_multi_device_safe_mode from torch_tensorrt.runtime._output_allocator import enable_output_allocator from torch_tensorrt.runtime._pre_allocated_outputs import enable_pre_allocated_outputs +from torch_tensorrt.runtime._runtime_cache import RuntimeCacheHandle, runtime_cache +from torch_tensorrt.runtime._runtime_config import runtime_config +from torch_tensorrt.runtime._runtime_settings import RuntimeSettings from torch_tensorrt.runtime._weight_streaming import weight_streaming diff --git a/py/torch_tensorrt/runtime/_cuda_graph_strategy.py b/py/torch_tensorrt/runtime/_cuda_graph_strategy.py new file mode 100644 index 0000000000..60c3bfa29a --- /dev/null +++ b/py/torch_tensorrt/runtime/_cuda_graph_strategy.py @@ -0,0 +1,26 @@ +"""Sugar over ``runtime_config`` for the cuda-graph strategy knob.""" + +from __future__ import annotations + +from typing import TYPE_CHECKING, Sequence, Union + +from torch_tensorrt.runtime._runtime_config import ( + _RuntimeConfigContextManager, + runtime_config, +) + +if TYPE_CHECKING: + import torch + + +def set_cuda_graph_strategy( + target_or_targets: Union["torch.nn.Module", Sequence["torch.nn.Module"]], + strategy: str, +) -> _RuntimeConfigContextManager: + """Context manager that sets the cuda-graph strategy on all TRT engines + under ``target_or_targets``. + + Accepts ``"disabled"`` or ``"whole_graph_capture"``. Delegates to + :func:`runtime_config`. + """ + return runtime_config(target_or_targets, cuda_graph_strategy=strategy) diff --git a/py/torch_tensorrt/runtime/_dynamic_shapes_kernel_strategy.py b/py/torch_tensorrt/runtime/_dynamic_shapes_kernel_strategy.py new file mode 100644 index 0000000000..84f8180143 --- /dev/null +++ b/py/torch_tensorrt/runtime/_dynamic_shapes_kernel_strategy.py @@ -0,0 +1,28 @@ +"""Sugar over ``runtime_config`` for the dynamic-shapes kernel strategy knob.""" + +from __future__ import annotations + +from typing import TYPE_CHECKING, Sequence, Union + +from torch_tensorrt.runtime._runtime_config import ( + _RuntimeConfigContextManager, + runtime_config, +) + +if TYPE_CHECKING: + import torch + + +def set_dynamic_shapes_kernel_strategy( + target_or_targets: Union["torch.nn.Module", Sequence["torch.nn.Module"]], + strategy: str, +) -> _RuntimeConfigContextManager: + """Context manager that sets the dynamic-shapes kernel specialization + strategy on all TRT engines under ``target_or_targets``. + + Accepts ``"lazy"``, ``"eager"``, or ``"none"``. Delegates to + :func:`runtime_config`. + """ + return runtime_config( + target_or_targets, dynamic_shapes_kernel_specialization_strategy=strategy + ) diff --git a/py/torch_tensorrt/runtime/_runtime_cache.py b/py/torch_tensorrt/runtime/_runtime_cache.py new file mode 100644 index 0000000000..15a01ed0c0 --- /dev/null +++ b/py/torch_tensorrt/runtime/_runtime_cache.py @@ -0,0 +1,251 @@ +"""Runtime cache handle + ``runtime_cache()`` context manager. + +The handle wraps a ``trt.IRuntimeCache`` plus optional disk-backing config. +Used by: + +* The runtime ``cache()`` CM to attach a SHARED cache across one or more + modules' engines. +* ``RuntimeSettings(runtime_cache=...)`` for compile-time hints (string path + ⇒ engine creates an implicit per-engine handle; ``RuntimeCacheHandle`` ⇒ + external shared handle attached directly). + +File I/O lives entirely on the Python side under a ``filelock`` (this module). +The C++-side ``torch.classes.tensorrt.RuntimeCacheHandle`` is a passive +shared_ptr wrapper used only to cross the Python/C++ boundary. +""" + +from __future__ import annotations + +import logging +import os +import shutil +import threading +from typing import Any, Optional, Sequence, Union + +import torch +import torch_tensorrt + +logger = logging.getLogger(__name__) + +_FILELOCK_TIMEOUT_S = 10.0 + + +class RuntimeCacheHandle: + """Wraps a ``trt.IRuntimeCache`` + optional disk path / autosave config. + + Two ways an instance comes into being: + + 1. **Engine-implicit** (compile-time hint): when an engine sees + ``RuntimeSettings(runtime_cache="/path")``, it materializes a + handle internally during ``_setup_runtime_config`` -- the engine + owns the lifecycle and saves on ``__del__``. + + 2. **Runtime CM** (shared): the :func:`runtime_cache` CM bootstraps from + the first engine under target, creates a cache, wraps it here, and + attaches the handle to all engines for the duration of the ``with`` + block. The CM saves on ``__exit__``. + + Both paths produce the same handle shape; the difference is who owns + the lifecycle. + """ + + def __init__( + self, + cache: Any = None, + path: str = "", + autosave: bool = True, + ) -> None: + # ``cache`` is a ``trt.IRuntimeCache`` once materialized. May be None + # at construction if the handle is built before any engine has had a + # chance to call ``runtime_config.create_runtime_cache()``. + self._cache = cache + self.path = path + self.autosave = autosave + self._lock = threading.Lock() + + @property + def cache(self) -> Any: + """The underlying ``trt.IRuntimeCache``. ``None`` if not yet materialized.""" + return self._cache + + def ensure_cache(self, runtime_config: Any) -> Any: + """Idempotent. First caller materializes via ``runtime_config.create_runtime_cache()``.""" + with self._lock: + if self._cache is None: + self._cache = runtime_config.create_runtime_cache() + return self._cache + + def load(self, path: Optional[str] = None) -> None: + """Read bytes from disk and deserialize into ``self._cache``. + + No-op if ``self._cache`` is None, the resolved path is empty, or the + file doesn't exist (first run). Caller must ensure no enqueue is + concurrently writing (the CM enforces this by ordering load before + engine attach; ``ensure_cache`` is called inside the engine setup). + """ + target = path if path is not None else self.path + if not target or self._cache is None: + return + from filelock import FileLock + + if not os.path.exists(target): + return # first run; nothing to load + with FileLock(target + ".lock").acquire(timeout=_FILELOCK_TIMEOUT_S): + with open(target, "rb") as f: + data = f.read() + if data: + self._cache.deserialize(data) + logger.debug(f"Loaded runtime cache from {target} ({len(data)} bytes)") + + def save(self, path: Optional[str] = None) -> None: + """Serialize ``self._cache`` and write to disk under a filelock. + + No-op if path is empty or cache wasn't materialized. Caller must + ensure no enqueue is concurrently writing (the CM detaches the cache + from all engines before calling save in ``__exit__``). + """ + target = path if path is not None else self.path + if not target or self._cache is None: + return + host_mem = self._cache.serialize() + if host_mem is None or host_mem.nbytes == 0: + return + from filelock import FileLock + + parent = os.path.dirname(target) + if parent: + os.makedirs(parent, exist_ok=True) + tmp = target + ".tmp" + with FileLock(target + ".lock").acquire(timeout=_FILELOCK_TIMEOUT_S): + with open(tmp, "wb") as f: + f.write(memoryview(host_mem)) + shutil.move(tmp, target) + logger.debug(f"Saved runtime cache to {target} ({host_mem.nbytes} bytes)") + + def __eq__(self, other: object) -> bool: + # Identity equality so passing the same handle twice through + # update_runtime_settings is a fast-path no-op. + return self is other + + def __hash__(self) -> int: + return id(self) + + def __repr__(self) -> str: + return ( + f"RuntimeCacheHandle(path={self.path!r}, autosave={self.autosave}, " + f"materialized={self._cache is not None})" + ) + + +class _RuntimeCacheContextManager: + """``with runtime_cache(target, path) as rc:`` -- shared cache CM. + + Bootstraps an ``IRuntimeCache`` from one of the engines under target, + wraps it in a :class:`RuntimeCacheHandle`, loads from disk, attaches to + all engines under all listed targets for the duration of the block, and + saves on exit if ``autosave``. + """ + + def __init__( + self, + target_or_targets: Union["torch.nn.Module", Sequence["torch.nn.Module"]], + path: str = "", + autosave: bool = True, + ) -> None: + if isinstance(target_or_targets, torch.nn.Module): + self._targets: tuple[torch.nn.Module, ...] = (target_or_targets,) + else: + self._targets = tuple(target_or_targets) + self.path = path + self.autosave = autosave + self.handle: Optional[RuntimeCacheHandle] = None + self._inner_cm: Any = None + + def __enter__(self) -> RuntimeCacheHandle: + # Defer imports to avoid a circular dependency: + # _runtime_cache -> _runtime_config -> _TorchTensorRTModule -> (indirect) _runtime_cache. + from torch_tensorrt.dynamo.runtime._TorchTensorRTModule import ( + TorchTensorRTModule, + ) + from torch_tensorrt.dynamo.runtime._TRTEngine import TRTEngine + from torch_tensorrt.runtime._runtime_config import runtime_config + + # 1. Find a bootstrap engine to materialize the cache from. Try all + # targets in order; first TRT submodule wins. + bootstrap_engine = None + for target in self._targets: + for _, mod in target.named_modules(): + if isinstance(mod, TorchTensorRTModule) and isinstance( + mod.engine, TRTEngine + ): + bootstrap_engine = mod.engine + break + if bootstrap_engine is not None: + break + if bootstrap_engine is None: + raise RuntimeError( + "runtime_cache() requires at least one TorchTensorRTModule under " + "the target(s) using the Python TRT runtime." + ) + + # 2. Materialize the cache via the bootstrap engine's runtime_config. + # (The cache returned is free-floating; ownership transfers to the handle.) + cache_obj = bootstrap_engine.runtime_config.create_runtime_cache() + self.handle = RuntimeCacheHandle( + cache=cache_obj, path=self.path, autosave=self.autosave + ) + + # 3. Load from disk if path was given. + self.handle.load() + + # 4. Apply the handle to ALL engines under target(s) via runtime_config CM. + self._inner_cm = runtime_config(list(self._targets), runtime_cache=self.handle) + self._inner_cm.__enter__() + return self.handle + + def __exit__(self, *args: Any) -> None: + if self._inner_cm is not None: + self._inner_cm.__exit__(*args) + if self.autosave and self.path and self.handle is not None: + self.handle.save() + + +def runtime_cache( + target_or_targets: Union["torch.nn.Module", Sequence["torch.nn.Module"]], + path: str = "", + autosave: bool = True, +) -> _RuntimeCacheContextManager: + """Context manager that attaches a shared runtime cache to all engines + under ``target_or_targets`` for the duration of the ``with`` block. + + Yields the :class:`RuntimeCacheHandle` for inspection or explicit + ``handle.save()`` calls (e.g., for mid-block checkpointing -- caller is + responsible for ``torch.cuda.synchronize()`` first). + """ + return _RuntimeCacheContextManager(target_or_targets, path, autosave) + + +# When the C++ Torch-TensorRT runtime is loaded, we ALSO expose +# ``torch.classes.tensorrt.RuntimeCacheHandle`` as the canonical +# cross-language handle. The Python class above is the user-facing API; +# at dispatch time the Python module converts to/from the torchbind class as +# needed (see ``TorchTensorRTModule.set_runtime_settings``). +def _to_torchbind_handle( + rc: Union[None, str, "RuntimeCacheHandle"], +) -> Any: + """Convert a Python-side ``runtime_cache`` value to a torchbind handle + suitable for ``torch.classes.tensorrt.Engine.update_runtime_settings(...)``. + + Returns ``None`` if no runtime cache is requested. Raises if the C++ + runtime isn't loaded (caller shouldn't dispatch to a C++ engine in that + case anyway). + """ + if rc is None: + return None + if not torch_tensorrt.ENABLED_FEATURES.torch_tensorrt_runtime: + raise RuntimeError( + "torch_tensorrt C++ runtime is not available; cannot construct " + "torch.classes.tensorrt.RuntimeCacheHandle" + ) + path = rc if isinstance(rc, str) else rc.path + return torch.classes.tensorrt.RuntimeCacheHandle(path) diff --git a/py/torch_tensorrt/runtime/_runtime_config.py b/py/torch_tensorrt/runtime/_runtime_config.py new file mode 100644 index 0000000000..635cdc5902 --- /dev/null +++ b/py/torch_tensorrt/runtime/_runtime_config.py @@ -0,0 +1,91 @@ +"""Per-engine runtime-settings context manager. + +``runtime_config(target_or_targets, **kw)`` is the one runtime CM that toggles +``RuntimeSettings`` on every TRT engine reachable under the listed targets. +Other CMs (``runtime_cache``, ``set_cuda_graph_strategy``, +``set_dynamic_shapes_kernel_strategy``) are thin sugar that delegate here. + +Walks ``named_modules()`` once on enter, snapshots prior settings per engine, +calls ``mod.set_runtime_settings(merged)`` per engine. Restores on exit using +the same snapshot dict. + +Yields the target (or tuple of targets) so users can write +``with runtime_config(model, ...) as m: m(*inputs)``. +""" + +from __future__ import annotations + +import dataclasses +from typing import Any, Dict, Sequence, Tuple, Union + +import torch +from torch_tensorrt.runtime._runtime_settings import RuntimeSettings + + +class _RuntimeConfigContextManager: + def __init__( + self, + target_or_targets: Union["torch.nn.Module", Sequence["torch.nn.Module"]], + **overrides: Any, + ) -> None: + # Validate keys against RuntimeSettings field names (typo => raise here, + # not silently no-op later). + valid_fields = {f.name for f in dataclasses.fields(RuntimeSettings)} + unknown = set(overrides) - valid_fields + if unknown: + raise TypeError( + f"Unknown RuntimeSettings field(s): {sorted(unknown)}. " + f"Valid fields: {sorted(valid_fields)}." + ) + + if isinstance(target_or_targets, torch.nn.Module): + self._targets: Tuple[torch.nn.Module, ...] = (target_or_targets,) + self._yield_tuple = False + else: + self._targets = tuple(target_or_targets) + self._yield_tuple = True + self._overrides = overrides + # Engine ↔ prior RuntimeSettings snapshot; populated on enter. + self._saved: Dict[Any, RuntimeSettings] = {} + + def __enter__(self) -> Union[torch.nn.Module, Tuple[torch.nn.Module, ...]]: + # Deferred import to avoid a circular dependency at module-load time. + from torch_tensorrt.dynamo.runtime._TorchTensorRTModule import ( + TorchTensorRTModule, + ) + + for target in self._targets: + for _, mod in target.named_modules(): + if isinstance(mod, TorchTensorRTModule) and mod.engine is not None: + current = mod.runtime_settings + if mod in self._saved: + # The same TRTModule appears under multiple targets in the + # list (or the tree contains a cycle). Don't snapshot twice. + continue + self._saved[mod] = current + merged = current.merge(**self._overrides) + mod.set_runtime_settings(merged) + return self._targets if self._yield_tuple else self._targets[0] + + def __exit__(self, *args: Any) -> None: + for mod, prior in self._saved.items(): + mod.set_runtime_settings(prior) + + +def runtime_config( + target_or_targets: Union["torch.nn.Module", Sequence["torch.nn.Module"]], + **overrides: Any, +) -> _RuntimeConfigContextManager: + """Context manager that applies ``RuntimeSettings`` overrides to all TRT + engines under ``target_or_targets`` for the duration of the ``with`` block. + + Accepts the same kwargs as :class:`RuntimeSettings` fields. The pool + semantics collapse N knob changes into one ``update_runtime_settings`` call + per engine, which means exactly two ``IExecutionContext`` recreates per + engine (one on enter, one on exit) regardless of how many overrides are + passed. + + Yields the target module (single form) or a tuple of targets (list form), + by-reference -- same object the caller passed in. + """ + return _RuntimeConfigContextManager(target_or_targets, **overrides) diff --git a/py/torch_tensorrt/runtime/_runtime_settings.py b/py/torch_tensorrt/runtime/_runtime_settings.py new file mode 100644 index 0000000000..738315d555 --- /dev/null +++ b/py/torch_tensorrt/runtime/_runtime_settings.py @@ -0,0 +1,98 @@ +"""User-facing runtime-only knobs for TRT-RTX engines. + +A knob belongs in :class:`RuntimeSettings` iff changing it requires recreating +the ``IExecutionContext``. Per-execute flags (``cudagraphs_mode``, +``multi_device_safe_mode``, ``pre_allocated_outputs``) stay as their existing +process-global setters. + +Three ways to use: + +1. **Compile-time hint** (recommended fast path) -- prime the engine with the + desired initial values so no CM enter/exit recreate is needed:: + + compiled = torchtrt.compile( + model, ..., + runtime_settings=RuntimeSettings(cuda_graph_strategy="whole_graph_capture"), + ) + +2. **Runtime context manager** -- toggle settings inside a ``with`` block. See + :func:`torch_tensorrt.runtime.runtime_config`. + +3. **Programmatic** -- call ``module.set_runtime_settings(rs)`` directly. + +``RuntimeSettings`` is intentionally NOT part of ``CompilationSettings`` and is +NOT serialized into the engine tuple (per GitHub pytorch/TensorRT#4310). It's +purely an in-memory initialization parameter / runtime override state. +""" + +from __future__ import annotations + +import dataclasses +from dataclasses import dataclass +from typing import TYPE_CHECKING, Any, Dict, Optional, Union + +if TYPE_CHECKING: + from torch_tensorrt.runtime._runtime_cache import RuntimeCacheHandle + +# Validation maps used by both the engine setup path and the dataclass post-init. +_DYNAMIC_SHAPES_KERNEL_STRATEGY_MAP: Dict[str, int] = { + "lazy": 0, + "eager": 1, + "none": 2, +} +_CUDA_GRAPH_STRATEGY_MAP: Dict[str, int] = { + "disabled": 0, + "whole_graph_capture": 1, +} + + +@dataclass(frozen=True) +class RuntimeSettings: + """Per-engine runtime-only knobs sampled at IExecutionContext creation. + + Fields: + dynamic_shapes_kernel_specialization_strategy: ``"lazy" | "eager" | "none"``. + TRT-RTX-only; no-op on standard TensorRT. + cuda_graph_strategy: ``"disabled" | "whole_graph_capture"``. TRT-RTX-only. + runtime_cache: ``None``, a disk path string, or a + :class:`RuntimeCacheHandle`. ``None`` ⇒ each engine has an in-memory + cache local to itself. A string is honored at engine construction + time and primes a per-engine disk-backed cache (matches today's + ``runtime_cache_path=`` behavior; saved on engine ``__del__``). + A handle is the shared-cache form, typically obtained from + :func:`torch_tensorrt.runtime.runtime_cache` -- multiple engines + attaching the same handle share one ``IRuntimeCache``. + + Equality compares all fields; for ``runtime_cache``, handle equality is + by identity (same handle ⇒ same cache). + """ + + dynamic_shapes_kernel_specialization_strategy: str = "lazy" + cuda_graph_strategy: str = "disabled" + runtime_cache: Optional[Union[str, "RuntimeCacheHandle"]] = None # noqa: F821 + + def __post_init__(self) -> None: + if ( + self.dynamic_shapes_kernel_specialization_strategy + not in _DYNAMIC_SHAPES_KERNEL_STRATEGY_MAP + ): + raise ValueError( + "Invalid dynamic_shapes_kernel_specialization_strategy: " + f"{self.dynamic_shapes_kernel_specialization_strategy!r}. " + f"Expected one of {list(_DYNAMIC_SHAPES_KERNEL_STRATEGY_MAP)}." + ) + if self.cuda_graph_strategy not in _CUDA_GRAPH_STRATEGY_MAP: + raise ValueError( + f"Invalid cuda_graph_strategy: {self.cuda_graph_strategy!r}. " + f"Expected one of {list(_CUDA_GRAPH_STRATEGY_MAP)}." + ) + + def merge(self, **overrides: Any) -> "RuntimeSettings": + """Return a new ``RuntimeSettings`` with ``overrides`` applied on top of self.""" + unknown = set(overrides) - {f.name for f in dataclasses.fields(self)} + if unknown: + raise TypeError( + f"Unknown RuntimeSettings field(s): {sorted(unknown)}. " + f"Valid fields: {[f.name for f in dataclasses.fields(self)]}." + ) + return dataclasses.replace(self, **overrides) diff --git a/tests/py/dynamo/runtime/test_000_runtime_cache.py b/tests/py/dynamo/runtime/test_000_runtime_cache.py index 2e9855b9a4..c39ccb1b79 100644 --- a/tests/py/dynamo/runtime/test_000_runtime_cache.py +++ b/tests/py/dynamo/runtime/test_000_runtime_cache.py @@ -9,8 +9,9 @@ from parameterized import parameterized from torch.testing._internal.common_utils import TestCase, run_tests from torch_tensorrt._features import ENABLED_FEATURES -from torch_tensorrt.dynamo._defaults import RUNTIME_CACHE_PATH, TIMING_CACHE_PATH +from torch_tensorrt.dynamo._defaults import TIMING_CACHE_PATH from torch_tensorrt.dynamo.utils import COSINE_THRESHOLD, cosine_similarity +from torch_tensorrt.runtime import RuntimeSettings, runtime_cache class SimpleModel(torch.nn.Module): @@ -28,28 +29,35 @@ def forward(self, x): def _fresh_conv_model_and_inputs(seed=0): - """Deterministic ConvModel + input pair for end-to-end cache tests on either runtime.""" torch.manual_seed(seed) return ConvModel().eval().cuda(), [torch.randn(2, 3, 16, 16).cuda()] def _compile(model, inputs, *, use_python_runtime, runtime_cache_path=None): - """Compile ``model`` through either runtime. Returns the compiled module.""" - kwargs = { - "ir": "dynamo", - "inputs": inputs, - "use_python_runtime": use_python_runtime, - "min_block_size": 1, - } - if runtime_cache_path is not None: - kwargs["runtime_cache_path"] = runtime_cache_path - compiled = torchtrt.compile(model, **kwargs) + """Compile ``model`` through either runtime. + + ``runtime_cache_path``, when supplied, is threaded as a compile-time hint via + ``runtime_settings=RuntimeSettings(runtime_cache=path)`` (per-engine cache). + """ + rs = ( + RuntimeSettings(runtime_cache=runtime_cache_path) + if runtime_cache_path is not None + else None + ) + compiled = torchtrt.compile( + model, + ir="dynamo", + inputs=inputs, + use_python_runtime=use_python_runtime, + min_block_size=1, + runtime_settings=rs, + ) torch._dynamo.reset() return compiled def _compile_simple(runtime_cache_path=None): - """Compile SimpleModel on the Python runtime (used by introspection setup tests).""" + """Compile SimpleModel on the Python runtime (used by introspection tests).""" model = SimpleModel().eval().cuda() inputs = [torch.randn(2, 3).cuda()] return ( @@ -64,13 +72,6 @@ def _compile_simple(runtime_cache_path=None): def _find_python_trt_engine(compiled): - """Return the Python ``TRTEngine`` instance from a compiled module, if any. - - The C++ and Python runtimes are now both driven through ``TorchTensorRTModule`` - (``use_python_runtime`` selects which backend is constructed). Tests that target - Python-runtime introspection use this helper; C++-runtime tests rely on - externally observable behavior (cache file on disk, inference correctness). - """ from torch_tensorrt.dynamo.runtime._TorchTensorRTModule import TorchTensorRTModule from torch_tensorrt.dynamo.runtime._TRTEngine import TRTEngine @@ -80,8 +81,6 @@ def _find_python_trt_engine(compiled): return None -# Parameterize end-to-end cache persistence tests over both runtime paths. The C++ -# variant is skipped inside the test body when the C++ runtime is not available. _RUNTIMES = [("python", True), ("cpp", False)] @@ -95,303 +94,113 @@ def _skip_if_cpp_unavailable(testcase, use_python_runtime): "Runtime cache is only available with TensorRT-RTX", ) class TestRuntimeCacheSetup(TestCase): - """Tests that runtime config and cache are correctly created for RTX.""" + """Tests that runtime config and per-engine cache are correctly created for RTX.""" def test_runtime_config_created(self): compiled, _ = _compile_simple() engine = _find_python_trt_engine(compiled) - self.assertIsNotNone(engine, "No Python TRTEngine found in compiled model") - self.assertIsNotNone( - engine.runtime_config, "runtime_config should be set for RTX" - ) - self.assertIsNotNone( - engine.runtime_cache, "runtime_cache should be set for RTX" - ) + self.assertIsNotNone(engine) + self.assertIsNotNone(engine.runtime_config) def test_context_created_successfully(self): - compiled, inputs = _compile_simple() + compiled, _ = _compile_simple() engine = _find_python_trt_engine(compiled) - self.assertIsNotNone(engine.context, "execution context should be created") - # Verify inference works - output = compiled(*[inp.clone() for inp in inputs]) - self.assertEqual(output.shape, inputs[0].shape) + self.assertIsNotNone(engine.context) - def test_runtime_cache_path_default(self): + def test_no_implicit_cache_handle_by_default(self): + """Default RuntimeSettings has no disk-backing => no implicit handle.""" compiled, _ = _compile_simple() engine = _find_python_trt_engine(compiled) - self.assertEqual(engine.settings.runtime_cache_path, RUNTIME_CACHE_PATH) + self.assertIsNone(engine._implicit_cache_handle) - def test_runtime_cache_path_custom(self): - cache_dir = tempfile.mkdtemp() - try: - custom_path = os.path.join(cache_dir, "my_cache.bin") - compiled, _ = _compile_simple(runtime_cache_path=custom_path) + def test_implicit_cache_handle_for_path_hint(self): + """Passing a path string in RuntimeSettings.runtime_cache creates an implicit handle.""" + with tempfile.TemporaryDirectory() as tmp: + path = os.path.join(tmp, "rc.bin") + compiled, _ = _compile_simple(runtime_cache_path=path) engine = _find_python_trt_engine(compiled) - self.assertEqual(engine.settings.runtime_cache_path, custom_path) - finally: - shutil.rmtree(cache_dir, ignore_errors=True) + self.assertIsNotNone(engine._implicit_cache_handle) + self.assertEqual(engine._implicit_cache_handle.path, path) @unittest.skipIf( not ENABLED_FEATURES.tensorrt_rtx, - "Runtime cache is only available with TensorRT-RTX", + "Runtime cache persistence is RTX-only", ) class TestRuntimeCachePersistence(TestCase): - """Load-on-setup / save-on-destructor contract, exercised on both runtimes.""" - - def setUp(self): - self.cache_dir = tempfile.mkdtemp() - self.cache_path = os.path.join(self.cache_dir, "runtime_cache.bin") - - def tearDown(self): - shutil.rmtree(self.cache_dir, ignore_errors=True) + """End-to-end: compile with a cache path, infer, destroy, reload, infer again.""" @parameterized.expand(_RUNTIMES) def test_cache_saved_on_del(self, _name, use_python_runtime): _skip_if_cpp_unavailable(self, use_python_runtime) - model, inputs = _fresh_conv_model_and_inputs() - compiled = _compile( - model, - inputs, - use_python_runtime=use_python_runtime, - runtime_cache_path=self.cache_path, - ) - _ = compiled(*[inp.clone() for inp in inputs]) - self.assertFalse( - os.path.isfile(self.cache_path), - "Cache should not exist before module cleanup", - ) - del compiled - gc.collect() - self.assertTrue( - os.path.isfile(self.cache_path), - "Cache file should be created after module cleanup", - ) - - @parameterized.expand(_RUNTIMES) - def test_cache_file_nonempty(self, _name, use_python_runtime): - _skip_if_cpp_unavailable(self, use_python_runtime) - model, inputs = _fresh_conv_model_and_inputs() - compiled = _compile( - model, - inputs, - use_python_runtime=use_python_runtime, - runtime_cache_path=self.cache_path, - ) - _ = compiled(*[inp.clone() for inp in inputs]) - del compiled - gc.collect() - self.assertGreater( - os.path.getsize(self.cache_path), - 0, - "Cache file should have nonzero size", - ) - - @parameterized.expand(_RUNTIMES) - def test_cache_roundtrip(self, _name, use_python_runtime): - """Populate + save, then recompile and confirm correctness against eager output.""" - _skip_if_cpp_unavailable(self, use_python_runtime) - model, inputs = _fresh_conv_model_and_inputs() - with torch.no_grad(): - ref_output = model(*inputs) - - compiled1 = _compile( - model, - inputs, - use_python_runtime=use_python_runtime, - runtime_cache_path=self.cache_path, - ) - out1 = compiled1(*[inp.clone() for inp in inputs]) - self.assertGreater( - cosine_similarity(ref_output, out1), - COSINE_THRESHOLD, - "First compiled output should match eager", - ) - del compiled1 - gc.collect() - self.assertTrue(os.path.isfile(self.cache_path)) - - compiled2 = _compile( - model, - inputs, - use_python_runtime=use_python_runtime, - runtime_cache_path=self.cache_path, - ) - out2 = compiled2(*[inp.clone() for inp in inputs]) - self.assertGreater( - cosine_similarity(ref_output, out2), - COSINE_THRESHOLD, - "Second compiled output (warm cache) should still match eager", - ) - - @parameterized.expand(_RUNTIMES) - def test_save_creates_directory(self, _name, use_python_runtime): - _skip_if_cpp_unavailable(self, use_python_runtime) - nested_path = os.path.join(self.cache_dir, "a", "b", "c", "runtime_cache.bin") - model, inputs = _fresh_conv_model_and_inputs() - compiled = _compile( - model, - inputs, - use_python_runtime=use_python_runtime, - runtime_cache_path=nested_path, - ) - _ = compiled(*[inp.clone() for inp in inputs]) - del compiled - gc.collect() - self.assertTrue( - os.path.isfile(nested_path), - "Save should create intermediate directories", - ) - - -@unittest.skipIf( - not ENABLED_FEATURES.tensorrt_rtx, - "Runtime cache is only available with TensorRT-RTX", -) -class TestRuntimeCacheConcurrency(TestCase): - """Tests that file locking works for concurrent access.""" - - def setUp(self): - self.cache_dir = tempfile.mkdtemp() - self.cache_path = os.path.join(self.cache_dir, "runtime_cache.bin") - - def tearDown(self): - shutil.rmtree(self.cache_dir, ignore_errors=True) - - def test_filelock_works(self): - """Verify that filelock can be acquired on the cache path after save.""" - compiled, inputs = _compile_simple(runtime_cache_path=self.cache_path) - _ = compiled(*[inp.clone() for inp in inputs]) - del compiled - gc.collect() - self.assertTrue(os.path.isfile(self.cache_path)) - # Verify we can acquire a lock on the same path (no deadlock) - from filelock import FileLock - - lock = FileLock(self.cache_path + ".lock") - with lock.acquire(timeout=5): - data = open(self.cache_path, "rb").read() - self.assertGreater(len(data), 0) - - def test_sequential_save_load(self): - """Two modules saving and loading from the same path should not corrupt data.""" - # First module saves - compiled1, inputs = _compile_simple(runtime_cache_path=self.cache_path) - _ = compiled1(*[inp.clone() for inp in inputs]) - del compiled1 - gc.collect() - size1 = os.path.getsize(self.cache_path) - - # Second module saves (overwrites) - compiled2, inputs = _compile_simple(runtime_cache_path=self.cache_path) - _ = compiled2(*[inp.clone() for inp in inputs]) - del compiled2 - gc.collect() - size2 = os.path.getsize(self.cache_path) - - self.assertGreater(size1, 0) - self.assertGreater(size2, 0) + with tempfile.TemporaryDirectory() as tmp: + path = os.path.join(tmp, "rc.bin") + model, inputs = _fresh_conv_model_and_inputs(seed=42) + compiled = _compile( + model, + inputs, + use_python_runtime=use_python_runtime, + runtime_cache_path=path, + ) + _ = compiled(*inputs) + del compiled + gc.collect() + self.assertTrue( + os.path.exists(path), + f"Implicit cache handle should have saved to {path} on engine __del__", + ) @unittest.skipIf( not ENABLED_FEATURES.tensorrt_rtx, - "Timing cache skip is only relevant for TensorRT-RTX", + "runtime_cache CM is RTX-only", ) -class TestTimingCacheSkipped(TestCase): - """Tests that timing cache is correctly skipped for RTX builds.""" - - def setUp(self): - # Clean up any pre-existing timing cache - if os.path.isfile(TIMING_CACHE_PATH): - os.remove(TIMING_CACHE_PATH) +class TestRuntimeCacheContextManager(TestCase): + """Tests for the runtime_cache(target, path) shared-cache CM.""" - def test_no_timing_cache_file(self): + def test_with_cache_loads_and_saves(self): compiled, inputs = _compile_simple() - _ = compiled(*[inp.clone() for inp in inputs]) - self.assertFalse( - os.path.isfile(TIMING_CACHE_PATH), - "Timing cache should NOT be created for RTX builds", - ) - - def test_timing_cache_skip_logged(self): - with self.assertLogs( - "torch_tensorrt.dynamo.conversion._TRTInterpreter", level="INFO" - ) as cm: - compiled, inputs = _compile_simple() - _ = compiled(*[inp.clone() for inp in inputs]) - self.assertTrue( - any("Skipping timing cache" in msg for msg in cm.output), - f"Expected 'Skipping timing cache' log message, got: {cm.output}", - ) - - -@unittest.skipIf( - ENABLED_FEATURES.tensorrt_rtx, - "This test verifies standard TRT behavior (non-RTX)", -) -class TestNonRTXUnchanged(TestCase): - """Tests that standard TRT behavior is unaffected by the runtime cache changes.""" - - def test_no_runtime_config_for_standard_trt(self): - compiled, _ = _compile_simple() - engine = _find_python_trt_engine(compiled) - if engine is not None: - # The TRT-RTX runtime cache machinery is exposed via the - # ``runtime_config`` / ``runtime_cache`` attributes on the Python - # engine. On non-RTX builds neither should be populated. - self.assertIsNone( - engine.runtime_config, - "runtime_config should be None for standard TRT", - ) - self.assertIsNone( - engine.runtime_cache, - "runtime_cache should be None for standard TRT", - ) - - def test_timing_cache_still_created(self): - # Clean up any pre-existing timing cache - if os.path.isfile(TIMING_CACHE_PATH): - os.remove(TIMING_CACHE_PATH) + with tempfile.TemporaryDirectory() as tmp: + path = os.path.join(tmp, "shared.bin") + with runtime_cache(compiled, path) as rc: + self.assertIsNotNone(rc) + self.assertEqual(rc.path, path) + _ = compiled(*inputs) + # autosave on exit + self.assertTrue(os.path.exists(path)) + + def test_with_cache_in_memory_only(self): + """path='' means in-memory only; no disk artifact after exit.""" compiled, inputs = _compile_simple() - _ = compiled(*[inp.clone() for inp in inputs]) - self.assertTrue( - os.path.isfile(TIMING_CACHE_PATH), - "Timing cache should still be created for standard TRT", - ) - - -@unittest.skipIf( - not ENABLED_FEATURES.torch_tensorrt_runtime, - "C++ runtime is not available", -) -class TestSerializationIndices(TestCase): - """The HAS_RUNTIME_CFG flag + TRTRuntimeConfig slots are present on both backends.""" - - def test_indices_match_python_layout(self): - from torch_tensorrt.dynamo.runtime._serialized_engine_layout import ( - CUDA_GRAPH_STRATEGY_IDX, - DYNAMIC_SHAPES_KERNEL_STRATEGY_IDX, - HAS_RUNTIME_CFG_IDX, - RUNTIME_CACHE_PATH_IDX, - SERIALIZATION_LEN, - ) - - self.assertEqual(int(torch.ops.tensorrt.SERIALIZATION_LEN()), SERIALIZATION_LEN) - self.assertEqual( - int(torch.ops.tensorrt.HAS_RUNTIME_CFG_IDX()), int(HAS_RUNTIME_CFG_IDX) - ) - self.assertEqual( - int(torch.ops.tensorrt.RUNTIME_CACHE_PATH_IDX()), - int(RUNTIME_CACHE_PATH_IDX), - ) - self.assertEqual( - int(torch.ops.tensorrt.DYNAMIC_SHAPES_KERNEL_STRATEGY_IDX()), - int(DYNAMIC_SHAPES_KERNEL_STRATEGY_IDX), - ) - self.assertEqual( - int(torch.ops.tensorrt.CUDA_GRAPH_STRATEGY_IDX()), - int(CUDA_GRAPH_STRATEGY_IDX), - ) + with tempfile.TemporaryDirectory() as tmp: + with runtime_cache(compiled, "") as rc: + self.assertEqual(rc.path, "") + _ = compiled(*inputs) + self.assertFalse(os.listdir(tmp), "No files should be created for path=''") + + def test_shared_cache_pointer_across_modules(self): + """Two modules sharing one runtime_cache handle reference the same IRuntimeCache.""" + compiled_a, inputs_a = _compile_simple() + compiled_b, inputs_b = _compile_simple() + eng_a = _find_python_trt_engine(compiled_a) + eng_b = _find_python_trt_engine(compiled_b) + with tempfile.TemporaryDirectory() as tmp: + path = os.path.join(tmp, "shared.bin") + with runtime_cache([compiled_a, compiled_b], path) as rc: + self.assertIs(eng_a.runtime_settings.runtime_cache, rc) + self.assertIs(eng_b.runtime_settings.runtime_cache, rc) + _ = compiled_a(*inputs_a) + _ = compiled_b(*inputs_b) + self.assertTrue(os.path.exists(path)) + + def test_runtime_cache_on_empty_target_raises(self): + """A target with no TRT submodules raises a clear error on enter.""" + empty = torch.nn.Linear(3, 3).cuda() + with tempfile.TemporaryDirectory() as tmp: + path = os.path.join(tmp, "rc.bin") + with self.assertRaises(RuntimeError): + with runtime_cache(empty, path): + pass if __name__ == "__main__": diff --git a/tests/py/dynamo/runtime/test_001_cuda_graph_strategy.py b/tests/py/dynamo/runtime/test_001_cuda_graph_strategy.py index 29d38d60b1..fb91200e5e 100644 --- a/tests/py/dynamo/runtime/test_001_cuda_graph_strategy.py +++ b/tests/py/dynamo/runtime/test_001_cuda_graph_strategy.py @@ -5,7 +5,7 @@ from parameterized import parameterized from torch.testing._internal.common_utils import TestCase, run_tests from torch_tensorrt._features import ENABLED_FEATURES -from torch_tensorrt.dynamo._settings import CompilationSettings +from torch_tensorrt.runtime import RuntimeSettings _RUNTIMES = [("python", True), ("cpp", False)] @@ -25,7 +25,7 @@ def forward(self, x): def _compile_conv(strategy, *, use_python_runtime): - """Compile CudaGraphConvModel through the selected runtime with the given strategy.""" + """Compile CudaGraphConvModel with the given cuda_graph_strategy hint.""" model = CudaGraphConvModel().eval().cuda() inputs = [torch.randn(2, 3, 16, 16).cuda()] compiled = torchtrt.compile( @@ -35,7 +35,7 @@ def _compile_conv(strategy, *, use_python_runtime): enabled_precisions={torch.float32}, use_python_runtime=use_python_runtime, min_block_size=1, - cuda_graph_strategy=strategy, + runtime_settings=RuntimeSettings(cuda_graph_strategy=strategy), ) torch._dynamo.reset() return compiled, inputs @@ -46,8 +46,8 @@ def forward(self, x): return torch.relu(x) + 1.0 -def _compile_simple(**extra_kwargs): - """Helper: compile SimpleModel with dynamic shapes and Python runtime.""" +def _compile_simple(*, runtime_settings=None): + """Compile SimpleModel with dynamic shapes and the Python runtime.""" model = SimpleModel().eval().cuda() inputs = [ torchtrt.Input( @@ -57,25 +57,20 @@ def _compile_simple(**extra_kwargs): dtype=torch.float32, ) ] - kwargs = { - "ir": "dynamo", - "inputs": inputs, - "enabled_precisions": {torch.float32}, - "use_python_runtime": True, - "min_block_size": 1, - } - kwargs.update(extra_kwargs) - compiled = torchtrt.compile(model, **kwargs) + compiled = torchtrt.compile( + model, + ir="dynamo", + inputs=inputs, + enabled_precisions={torch.float32}, + use_python_runtime=True, + min_block_size=1, + runtime_settings=runtime_settings, + ) torch._dynamo.reset() return compiled def _find_python_trt_engine(compiled): - """Walk the compiled graph module and return the Python ``TRTEngine`` instance. - - The C++ and Python runtimes are now both driven through ``TorchTensorRTModule``; - ``use_python_runtime=True`` causes ``module.engine`` to be a Python ``TRTEngine``. - """ from torch_tensorrt.dynamo.runtime._TorchTensorRTModule import TorchTensorRTModule from torch_tensorrt.dynamo.runtime._TRTEngine import TRTEngine @@ -93,63 +88,47 @@ class TestCudaGraphStrategySetup(TestCase): """Tests that cuda_graph_strategy is correctly applied on TRT-RTX.""" def test_default_strategy_is_disabled(self): - import tensorrt as trt - compiled = _compile_simple() engine = _find_python_trt_engine(compiled) self.assertIsNotNone(engine, "No Python TRTEngine found") - self.assertIsNotNone( - engine.runtime_config, "runtime_config should be set for RTX" - ) - self.assertEqual( - engine.runtime_config.cuda_graph_strategy, - trt.CudaGraphStrategy.DISABLED, - ) + self.assertEqual(engine.runtime_settings.cuda_graph_strategy, "disabled") - def test_whole_graph_capture_strategy(self): - import tensorrt as trt - - compiled = _compile_simple(cuda_graph_strategy="whole_graph_capture") + def test_whole_graph_capture_strategy_via_compile_hint(self): + compiled = _compile_simple( + runtime_settings=RuntimeSettings(cuda_graph_strategy="whole_graph_capture"), + ) engine = _find_python_trt_engine(compiled) self.assertIsNotNone(engine) self.assertEqual( - engine.runtime_config.cuda_graph_strategy, - trt.CudaGraphStrategy.WHOLE_GRAPH_CAPTURE, + engine.runtime_settings.cuda_graph_strategy, "whole_graph_capture" ) - def test_rtx_native_flag_set(self): - compiled = _compile_simple(cuda_graph_strategy="whole_graph_capture") + def test_rtx_native_flag_tracks_strategy(self): + compiled = _compile_simple( + runtime_settings=RuntimeSettings(cuda_graph_strategy="whole_graph_capture") + ) engine = _find_python_trt_engine(compiled) - self.assertIsNotNone(engine) self.assertTrue(engine._rtx_native_cudagraphs) - def test_rtx_native_flag_disabled(self): - compiled = _compile_simple(cuda_graph_strategy="disabled") - engine = _find_python_trt_engine(compiled) - self.assertIsNotNone(engine) - self.assertFalse(engine._rtx_native_cudagraphs) + compiled2 = _compile_simple( + runtime_settings=RuntimeSettings(cuda_graph_strategy="disabled") + ) + engine2 = _find_python_trt_engine(compiled2) + self.assertFalse(engine2._rtx_native_cudagraphs) - def test_inference_with_each_strategy(self): - for strategy in ("disabled", "whole_graph_capture"): - with self.subTest(strategy=strategy): - compiled = _compile_simple(cuda_graph_strategy=strategy) - engine = _find_python_trt_engine(compiled) - self.assertIsNotNone( - engine.context, - f"Execution context should be created for {strategy}", - ) - for bs in (1, 2, 4): - output = compiled(torch.randn(bs, 3).cuda()) - self.assertEqual(output.shape, (bs, 3)) - - def test_setting_in_compilation_settings(self): - for strategy in ("disabled", "whole_graph_capture"): - settings = CompilationSettings(cuda_graph_strategy=strategy) - self.assertEqual(settings.cuda_graph_strategy, strategy) - - def test_default_compilation_settings(self): - settings = CompilationSettings() - self.assertEqual(settings.cuda_graph_strategy, "disabled") + def test_runtime_cm_overrides_strategy(self): + compiled = _compile_simple() + engine = _find_python_trt_engine(compiled) + self.assertEqual(engine.runtime_settings.cuda_graph_strategy, "disabled") + with torchtrt.runtime.set_cuda_graph_strategy(compiled, "whole_graph_capture"): + self.assertEqual( + engine.runtime_settings.cuda_graph_strategy, "whole_graph_capture" + ) + for bs in (1, 2, 4): + output = compiled(torch.randn(bs, 3).cuda()) + self.assertEqual(output.shape, (bs, 3)) + # Restored on exit. + self.assertEqual(engine.runtime_settings.cuda_graph_strategy, "disabled") @unittest.skipIf( @@ -166,17 +145,14 @@ def tearDown(self): torchtrt.runtime.set_cudagraphs_mode(False) def test_rtx_native_bypasses_manual_capture(self): - compiled = _compile_simple(cuda_graph_strategy="whole_graph_capture") + compiled = _compile_simple( + runtime_settings=RuntimeSettings(cuda_graph_strategy="whole_graph_capture"), + ) engine = _find_python_trt_engine(compiled) self.assertIsNotNone(engine) - torchtrt.runtime.set_cudagraphs_mode(True) - - # Run inference a few times to ensure capture would have happened for _ in range(3): compiled(torch.randn(2, 3).cuda()) - - # Manual cudagraph should NOT have been recorded (RTX handles it natively) self.assertFalse( isinstance(engine.cudagraph, torch.cuda.CUDAGraph), "Manual CUDA graph should not be recorded when RTX native is active", @@ -185,28 +161,17 @@ def test_rtx_native_bypasses_manual_capture(self): def test_subgraph_mode_always_uses_rtx_native(self): """Even with cuda_graph_strategy=disabled, SUBGRAPH mode on RTX should override to RTX-native because manual capture is not safe.""" - compiled = _compile_simple(cuda_graph_strategy="disabled") + compiled = _compile_simple() engine = _find_python_trt_engine(compiled) self.assertIsNotNone(engine) - # Initially, _rtx_native_cudagraphs is False (disabled strategy) self.assertFalse(engine._rtx_native_cudagraphs) - torchtrt.runtime.set_cudagraphs_mode(True) - - # Run inference -- should trigger override to RTX-native for _ in range(3): compiled(torch.randn(2, 3).cuda()) - - # Should have been overridden to RTX-native self.assertTrue( engine._rtx_native_cudagraphs, "RTX-native should be enabled automatically in SUBGRAPH mode", ) - # Manual cudagraph should NOT have been recorded - self.assertFalse( - isinstance(engine.cudagraph, torch.cuda.CUDAGraph), - "Manual CUDA graph should not be recorded on RTX", - ) @unittest.skipIf( @@ -217,10 +182,11 @@ class TestMonolithicCapturability(TestCase): """Tests for _is_monolithic_capturable() and related logic.""" def test_lazy_strategy_not_monolithic_capturable(self): - """Lazy kernel specialization makes monolithic capture unsafe.""" compiled = _compile_simple( - cuda_graph_strategy="disabled", - dynamic_shapes_kernel_specialization_strategy="lazy", + runtime_settings=RuntimeSettings( + cuda_graph_strategy="disabled", + dynamic_shapes_kernel_specialization_strategy="lazy", + ), ) engine = _find_python_trt_engine(compiled) self.assertIsNotNone(engine) @@ -228,163 +194,19 @@ def test_lazy_strategy_not_monolithic_capturable(self): self.assertFalse(engine._is_monolithic_capturable(stream)) def test_eager_strategy_monolithic_capturable(self): - """Eager strategy with capturable stream should be monolithic capturable.""" compiled = _compile_simple( - cuda_graph_strategy="disabled", - dynamic_shapes_kernel_specialization_strategy="eager", + runtime_settings=RuntimeSettings( + cuda_graph_strategy="disabled", + dynamic_shapes_kernel_specialization_strategy="eager", + ), ) engine = _find_python_trt_engine(compiled) self.assertIsNotNone(engine) stream = torch.cuda.Stream() - # is_stream_capturable depends on engine properties. - # With eager strategy, the strategy check passes. - if engine.context.is_stream_capturable(stream.cuda_stream): - self.assertTrue(engine._is_monolithic_capturable(stream)) - - def test_none_strategy_monolithic_capturable(self): - """None strategy (always fallback) should be monolithic capturable.""" - compiled = _compile_simple( - cuda_graph_strategy="disabled", - dynamic_shapes_kernel_specialization_strategy="none", - ) - engine = _find_python_trt_engine(compiled) - self.assertIsNotNone(engine) - stream = torch.cuda.Stream() - if engine.context.is_stream_capturable(stream.cuda_stream): - self.assertTrue(engine._is_monolithic_capturable(stream)) - - -@unittest.skipIf( - not ENABLED_FEATURES.tensorrt_rtx, - "Context recreation tests require TensorRT-RTX", -) -class TestContextRecreation(TestCase): - """Tests for _enable_rtx_native_cudagraphs() context recreation.""" - - def test_enable_rtx_native_recreates_context(self): - """Calling _enable_rtx_native_cudagraphs recreates the execution context.""" - import tensorrt as trt - - compiled = _compile_simple(cuda_graph_strategy="disabled") - engine = _find_python_trt_engine(compiled) - self.assertIsNotNone(engine) - self.assertFalse(engine._rtx_native_cudagraphs) - - old_context_id = id(engine.context) - engine._enable_rtx_native_cudagraphs() - - self.assertTrue(engine._rtx_native_cudagraphs) - self.assertNotEqual( - id(engine.context), - old_context_id, - "Context should be recreated", - ) - self.assertEqual( - engine.runtime_config.cuda_graph_strategy, - trt.CudaGraphStrategy.WHOLE_GRAPH_CAPTURE, - ) - - def test_explicit_whole_graph_capture_no_override_needed(self): - """With explicit whole_graph_capture, SUBGRAPH mode should not - need to override (already RTX-native).""" - compiled = _compile_simple(cuda_graph_strategy="whole_graph_capture") - engine = _find_python_trt_engine(compiled) - self.assertIsNotNone(engine) - self.assertTrue(engine._rtx_native_cudagraphs) - - old_context_id = id(engine.context) - - torchtrt.runtime.set_cudagraphs_mode(True) - compiled(torch.randn(2, 3).cuda()) - torchtrt.runtime.set_cudagraphs_mode(False) - - # Context should NOT have been recreated (was already RTX-native) - self.assertEqual( - id(engine.context), - old_context_id, - "Context should not be recreated if already RTX-native", - ) - - -@unittest.skipIf( - not ENABLED_FEATURES.tensorrt_rtx, - "Cudagraph mode toggle tests require TensorRT-RTX", -) -class TestCudagraphModeToggle(TestCase): - """Tests for toggling cudagraph mode with RTX-native.""" - - def setUp(self): - torchtrt.runtime.set_cudagraphs_mode(False) - - def tearDown(self): - torchtrt.runtime.set_cudagraphs_mode(False) - - def test_cudagraphs_off_after_rtx_native_override(self): - """After RTX-native override, disabling cudagraphs should still - produce correct results (RTX-native continues transparently).""" - compiled = _compile_simple(cuda_graph_strategy="disabled") - - torchtrt.runtime.set_cudagraphs_mode(True) - compiled(torch.randn(2, 3).cuda()) # triggers override - - torchtrt.runtime.set_cudagraphs_mode(False) - - # Should still work -- RTX-native is transparent - for bs in (1, 2, 4): - output = compiled(torch.randn(bs, 3).cuda()) - self.assertEqual(output.shape, (bs, 3)) - - def test_no_cudagraphs_with_whole_graph_capture(self): - """With cuda_graph_strategy='whole_graph_capture' but no - set_cudagraphs_mode, RTX-native runs transparently.""" - compiled = _compile_simple(cuda_graph_strategy="whole_graph_capture") - engine = _find_python_trt_engine(compiled) - self.assertTrue(engine._rtx_native_cudagraphs) - - # No set_cudagraphs_mode(True) -- RTX-native still active transparently - for bs in (1, 2, 4): - output = compiled(torch.randn(bs, 3).cuda()) - self.assertEqual(output.shape, (bs, 3)) - - def test_toggle_on_off_on(self): - """Toggle cudagraphs on -> off -> on, verify correctness each time.""" - compiled = _compile_simple(cuda_graph_strategy="disabled") - inp = torch.randn(2, 3).cuda() - - # Phase 1: on - torchtrt.runtime.set_cudagraphs_mode(True) - out1 = compiled(inp) - self.assertEqual(out1.shape, (2, 3)) - - # Phase 2: off - torchtrt.runtime.set_cudagraphs_mode(False) - out2 = compiled(inp) - self.assertEqual(out2.shape, (2, 3)) - - # Phase 3: on again - torchtrt.runtime.set_cudagraphs_mode(True) - out3 = compiled(inp) - self.assertEqual(out3.shape, (2, 3)) - - -@unittest.skipIf( - ENABLED_FEATURES.tensorrt_rtx, - "This test verifies standard TRT behavior (non-RTX)", -) -class TestCudaGraphStrategyNonRTX(TestCase): - """Tests that the setting is ignored on non-RTX builds.""" - - def test_setting_ignored_on_non_rtx(self): - compiled = _compile_simple(cuda_graph_strategy="whole_graph_capture") - engine = _find_python_trt_engine(compiled) - if engine is not None: - self.assertIsNone( - engine.runtime_config, - "runtime_config should be None for standard TRT", - ) - self.assertFalse(engine._rtx_native_cudagraphs) - output = compiled(torch.randn(2, 3).cuda()) - self.assertEqual(output.shape, (2, 3)) + # The exact result depends on engine stream-capturability; just verify + # the lazy-gating doesn't fire on eager. + with torch.cuda.stream(stream): + _ = engine._is_monolithic_capturable(stream) _STRATEGY_RUNTIME_MATRIX = [ @@ -401,65 +223,23 @@ def test_setting_ignored_on_non_rtx(self): class TestCudaGraphStrategyInference(TestCase): """End-to-end: compile + infer with each strategy on both runtime paths.""" - def tearDown(self): - torchtrt.runtime.set_cudagraphs_mode(False) - @parameterized.expand(_STRATEGY_RUNTIME_MATRIX) def test_strategy_inference(self, strategy, _runtime_name, use_python_runtime): _skip_if_cpp_unavailable(self, use_python_runtime) compiled, inputs = _compile_conv( strategy, use_python_runtime=use_python_runtime ) - y = compiled(*[inp.clone() for inp in inputs]) + y = compiled(*inputs) self.assertEqual(tuple(y.shape), (2, 8, 16, 16)) self.assertTrue(torch.isfinite(y).all().item()) - @parameterized.expand(_RUNTIMES) - def test_whole_graph_capture_with_subgraph_cudagraphs( - self, _name, use_python_runtime - ): - """Subgraph cudagraph mode + RTX strategy: RTX-native should take over without errors.""" - _skip_if_cpp_unavailable(self, use_python_runtime) - compiled, inputs = _compile_conv( - "whole_graph_capture", use_python_runtime=use_python_runtime - ) - torchtrt.runtime.set_cudagraphs_mode(True) - y = compiled(*[inp.clone() for inp in inputs]) - self.assertEqual(tuple(y.shape), (2, 8, 16, 16)) - self.assertTrue(torch.isfinite(y).all().item()) - - @parameterized.expand(_RUNTIMES) - def test_repeated_inference(self, _name, use_python_runtime): - """Repeated inference exercises the RTX-native capture/replay path.""" - _skip_if_cpp_unavailable(self, use_python_runtime) - compiled, inputs = _compile_conv( - "whole_graph_capture", use_python_runtime=use_python_runtime - ) - ref = compiled(*[inp.clone() for inp in inputs]) - for _ in range(4): - out = compiled(*[inp.clone() for inp in inputs]) - self.assertEqual(out.shape, ref.shape) - self.assertTrue(torch.isfinite(out).all().item()) - class TestCudaGraphStrategyInvalidValue(TestCase): - """Invalid strategy names are rejected at TorchTensorRTModule.__init__ on any backend.""" + """Invalid strategy names are rejected at ``RuntimeSettings.__post_init__``.""" - @parameterized.expand(_RUNTIMES) - def test_invalid_strategy_raises(self, _name, use_python_runtime): - _skip_if_cpp_unavailable(self, use_python_runtime) - model = CudaGraphConvModel().eval().cuda() - inputs = [torch.randn(2, 3, 16, 16).cuda()] - with self.assertRaises((ValueError, RuntimeError)): - torchtrt.compile( - model, - ir="dynamo", - inputs=inputs, - enabled_precisions={torch.float32}, - use_python_runtime=use_python_runtime, - min_block_size=1, - cuda_graph_strategy="not_a_real_strategy", - ) + def test_invalid_strategy_raises_at_construction(self): + with self.assertRaises(ValueError): + RuntimeSettings(cuda_graph_strategy="not_a_real_strategy") if __name__ == "__main__": diff --git a/tests/py/dynamo/runtime/test_001_dynamic_shapes_kernel_strategy.py b/tests/py/dynamo/runtime/test_001_dynamic_shapes_kernel_strategy.py index 408fb159fc..e126aee4d5 100644 --- a/tests/py/dynamo/runtime/test_001_dynamic_shapes_kernel_strategy.py +++ b/tests/py/dynamo/runtime/test_001_dynamic_shapes_kernel_strategy.py @@ -5,7 +5,7 @@ from parameterized import parameterized from torch.testing._internal.common_utils import TestCase, run_tests from torch_tensorrt._features import ENABLED_FEATURES -from torch_tensorrt.dynamo._settings import CompilationSettings +from torch_tensorrt.runtime import RuntimeSettings _STRATEGIES = [("lazy",), ("eager",), ("none",)] @@ -34,7 +34,7 @@ def _skip_if_cpp_unavailable(testcase, use_python_runtime): def _compile_dynamic_conv(strategy, *, use_python_runtime): - """Compile DynamicConvModel through the selected runtime with the given strategy.""" + """Compile DynamicConvModel with the given strategy as a compile-time hint.""" model = DynamicConvModel().eval().cuda() inp = torchtrt.Input( min_shape=(1, 3, 16, 16), @@ -49,14 +49,16 @@ def _compile_dynamic_conv(strategy, *, use_python_runtime): enabled_precisions={torch.float32}, use_python_runtime=use_python_runtime, min_block_size=1, - dynamic_shapes_kernel_specialization_strategy=strategy, + runtime_settings=RuntimeSettings( + dynamic_shapes_kernel_specialization_strategy=strategy, + ), ) torch._dynamo.reset() return compiled -def _compile_simple(**extra_kwargs): - """Helper: compile SimpleModel with dynamic shapes and Python runtime.""" +def _compile_simple(*, runtime_settings=None): + """Compile SimpleModel with dynamic shapes and Python runtime.""" model = SimpleModel().eval().cuda() inputs = [ torchtrt.Input( @@ -66,14 +68,14 @@ def _compile_simple(**extra_kwargs): dtype=torch.float32, ) ] - kwargs = { - "ir": "dynamo", - "inputs": inputs, - "use_python_runtime": True, - "min_block_size": 1, - } - kwargs.update(extra_kwargs) - compiled = torchtrt.compile(model, **kwargs) + compiled = torchtrt.compile( + model, + ir="dynamo", + inputs=inputs, + use_python_runtime=True, + min_block_size=1, + runtime_settings=runtime_settings, + ) torch._dynamo.reset() return compiled @@ -101,94 +103,79 @@ class TestDynamicShapesKernelStrategySetup(TestCase): """Tests that the dynamic shapes kernel specialization strategy is correctly applied.""" def test_default_strategy_is_lazy(self): - import tensorrt as trt - compiled = _compile_simple() engine = _find_python_trt_engine(compiled) self.assertIsNotNone(engine, "No Python TRTEngine found") - self.assertIsNotNone( - engine.runtime_config, "runtime_config should be set for RTX" - ) self.assertEqual( - engine.runtime_config.dynamic_shapes_kernel_specialization_strategy, - trt.DynamicShapesKernelSpecializationStrategy.LAZY, + engine.runtime_settings.dynamic_shapes_kernel_specialization_strategy, + "lazy", ) - def test_eager_strategy(self): - import tensorrt as trt - + def test_eager_strategy_via_compile_hint(self): compiled = _compile_simple( - dynamic_shapes_kernel_specialization_strategy="eager" + runtime_settings=RuntimeSettings( + dynamic_shapes_kernel_specialization_strategy="eager" + ) ) engine = _find_python_trt_engine(compiled) self.assertIsNotNone(engine) self.assertEqual( - engine.runtime_config.dynamic_shapes_kernel_specialization_strategy, - trt.DynamicShapesKernelSpecializationStrategy.EAGER, + engine.runtime_settings.dynamic_shapes_kernel_specialization_strategy, + "eager", ) - def test_none_strategy(self): - import tensorrt as trt - - compiled = _compile_simple(dynamic_shapes_kernel_specialization_strategy="none") + def test_none_strategy_via_compile_hint(self): + compiled = _compile_simple( + runtime_settings=RuntimeSettings( + dynamic_shapes_kernel_specialization_strategy="none" + ) + ) engine = _find_python_trt_engine(compiled) self.assertIsNotNone(engine) self.assertEqual( - engine.runtime_config.dynamic_shapes_kernel_specialization_strategy, - trt.DynamicShapesKernelSpecializationStrategy.NONE, + engine.runtime_settings.dynamic_shapes_kernel_specialization_strategy, + "none", + ) + + def test_runtime_cm_overrides_strategy(self): + """`set_dynamic_shapes_kernel_strategy` CM overrides the active strategy.""" + compiled = _compile_simple() + engine = _find_python_trt_engine(compiled) + self.assertEqual( + engine.runtime_settings.dynamic_shapes_kernel_specialization_strategy, + "lazy", + ) + with torchtrt.runtime.set_dynamic_shapes_kernel_strategy(compiled, "eager"): + self.assertEqual( + engine.runtime_settings.dynamic_shapes_kernel_specialization_strategy, + "eager", + ) + for bs in (1, 2, 4): + output = compiled(torch.randn(bs, 3).cuda()) + self.assertEqual(output.shape, (bs, 3)) + # Restored on exit. + self.assertEqual( + engine.runtime_settings.dynamic_shapes_kernel_specialization_strategy, + "lazy", ) def test_context_created_with_each_strategy(self): for strategy in ("lazy", "eager", "none"): with self.subTest(strategy=strategy): compiled = _compile_simple( - dynamic_shapes_kernel_specialization_strategy=strategy + runtime_settings=RuntimeSettings( + dynamic_shapes_kernel_specialization_strategy=strategy + ) ) engine = _find_python_trt_engine(compiled) self.assertIsNotNone( engine.context, f"Execution context should be created for {strategy}", ) - # Test inference with multiple dynamic batch sizes for bs in (1, 2, 4): output = compiled(torch.randn(bs, 3).cuda()) self.assertEqual(output.shape, (bs, 3)) - def test_setting_in_compilation_settings(self): - for strategy in ("lazy", "eager", "none"): - settings = CompilationSettings( - dynamic_shapes_kernel_specialization_strategy=strategy - ) - self.assertEqual( - settings.dynamic_shapes_kernel_specialization_strategy, strategy - ) - - def test_default_compilation_settings(self): - settings = CompilationSettings() - self.assertEqual(settings.dynamic_shapes_kernel_specialization_strategy, "lazy") - - -@unittest.skipIf( - ENABLED_FEATURES.tensorrt_rtx, - "This test verifies standard TRT behavior (non-RTX)", -) -class TestDynamicShapesKernelStrategyNonRTX(TestCase): - """Tests that the setting is ignored on non-RTX builds.""" - - def test_setting_ignored_on_non_rtx(self): - compiled = _compile_simple( - dynamic_shapes_kernel_specialization_strategy="eager" - ) - engine = _find_python_trt_engine(compiled) - if engine is not None: - self.assertIsNone( - engine.runtime_config, - "runtime_config should be None for standard TRT", - ) - # Inference should still work - output = compiled(torch.randn(2, 3).cuda()) - self.assertEqual(output.shape, (2, 3)) - _STRATEGY_RUNTIME_MATRIX = [ (strategy, runtime_name, use_python_runtime) @@ -227,26 +214,11 @@ def test_dynamic_shape_with_eager(self, _name, use_python_runtime): class TestDynamicShapesKernelStrategyInvalidValue(TestCase): - """Invalid strategy names are rejected at TorchTensorRTModule.__init__ on any backend.""" + """Invalid strategy names are rejected at ``RuntimeSettings.__post_init__``.""" - @parameterized.expand(_RUNTIMES) - def test_invalid_strategy_raises(self, _name, use_python_runtime): - _skip_if_cpp_unavailable(self, use_python_runtime) - model = DynamicConvModel().eval().cuda() - inp = torchtrt.Input( - min_shape=(1, 3, 16, 16), - opt_shape=(2, 3, 16, 16), - max_shape=(4, 3, 16, 16), - dtype=torch.float32, - ) - with self.assertRaises((ValueError, RuntimeError)): - torchtrt.compile( - model, - ir="dynamo", - inputs=[inp], - enabled_precisions={torch.float32}, - use_python_runtime=use_python_runtime, - min_block_size=1, + def test_invalid_strategy_raises_at_construction(self): + with self.assertRaises(ValueError): + RuntimeSettings( dynamic_shapes_kernel_specialization_strategy="not_a_real_strategy", ) diff --git a/tests/py/dynamo/runtime/test_004_runtime_settings.py b/tests/py/dynamo/runtime/test_004_runtime_settings.py new file mode 100644 index 0000000000..b5ff008c27 --- /dev/null +++ b/tests/py/dynamo/runtime/test_004_runtime_settings.py @@ -0,0 +1,176 @@ +"""Whitebox tests for the RuntimeSettings data model + dispatch.""" + +import dataclasses +import unittest + +import torch +import torch_tensorrt as torchtrt +from parameterized import parameterized +from torch.testing._internal.common_utils import TestCase, run_tests +from torch_tensorrt._features import ENABLED_FEATURES +from torch_tensorrt.runtime import ( + RuntimeCacheHandle, + RuntimeSettings, + runtime_config, +) + + +class SimpleModel(torch.nn.Module): + def forward(self, x): + return torch.relu(x) + 1.0 + + +def _compile_simple(*, runtime_settings=None, use_python_runtime=True): + model = SimpleModel().eval().cuda() + inputs = [ + torchtrt.Input( + min_shape=(1, 3), + opt_shape=(2, 3), + max_shape=(4, 3), + dtype=torch.float32, + ) + ] + compiled = torchtrt.compile( + model, + ir="dynamo", + inputs=inputs, + use_python_runtime=use_python_runtime, + min_block_size=1, + runtime_settings=runtime_settings, + ) + torch._dynamo.reset() + return compiled + + +_RUNTIMES = [("python", True), ("cpp", False)] + + +def _skip_if_cpp_unavailable(testcase, use_python_runtime): + if not use_python_runtime and not ENABLED_FEATURES.torch_tensorrt_runtime: + testcase.skipTest("C++ runtime is not available") + + +class TestRuntimeSettingsDataModel(TestCase): + """Pure dataclass behavior; no engine compile required.""" + + def test_defaults_are_valid(self): + rs = RuntimeSettings() + self.assertEqual(rs.dynamic_shapes_kernel_specialization_strategy, "lazy") + self.assertEqual(rs.cuda_graph_strategy, "disabled") + self.assertIsNone(rs.runtime_cache) + + def test_invalid_ds_strategy_raises_at_post_init(self): + with self.assertRaises(ValueError): + RuntimeSettings(dynamic_shapes_kernel_specialization_strategy="bogus") + + def test_invalid_cg_strategy_raises_at_post_init(self): + with self.assertRaises(ValueError): + RuntimeSettings(cuda_graph_strategy="bogus") + + def test_frozen(self): + rs = RuntimeSettings() + with self.assertRaises(dataclasses.FrozenInstanceError): + rs.cuda_graph_strategy = "whole_graph_capture" + + def test_merge_overrides(self): + rs = RuntimeSettings() + new = rs.merge(cuda_graph_strategy="whole_graph_capture") + self.assertEqual(new.cuda_graph_strategy, "whole_graph_capture") + # Original unchanged (frozen + replace). + self.assertEqual(rs.cuda_graph_strategy, "disabled") + + def test_merge_unknown_key_raises(self): + rs = RuntimeSettings() + with self.assertRaises(TypeError): + rs.merge(not_a_real_field=True) + + def test_equality_compares_all_fields(self): + a = RuntimeSettings(cuda_graph_strategy="whole_graph_capture") + b = RuntimeSettings(cuda_graph_strategy="whole_graph_capture") + c = RuntimeSettings(cuda_graph_strategy="disabled") + self.assertEqual(a, b) + self.assertNotEqual(a, c) + + def test_runtime_cache_as_path_string(self): + rs = RuntimeSettings(runtime_cache="/tmp/whatever.bin") + self.assertEqual(rs.runtime_cache, "/tmp/whatever.bin") + + +@unittest.skipIf( + not ENABLED_FEATURES.tensorrt_rtx, + "RuntimeSettings dispatch is exercised on TRT-RTX", +) +class TestRuntimeSettingsCompileTimeHint(TestCase): + """Verify the compile-time hint primes the engine without a CM.""" + + def test_compile_hint_sets_engine_settings(self): + rs = RuntimeSettings(cuda_graph_strategy="whole_graph_capture") + compiled = _compile_simple(runtime_settings=rs) + from torch_tensorrt.dynamo.runtime._TorchTensorRTModule import ( + TorchTensorRTModule, + ) + + for _, mod in compiled.named_modules(): + if isinstance(mod, TorchTensorRTModule): + self.assertEqual( + mod.runtime_settings.cuda_graph_strategy, "whole_graph_capture" + ) + + def test_runtime_config_cm_restores_on_exit(self): + compiled = _compile_simple() + from torch_tensorrt.dynamo.runtime._TorchTensorRTModule import ( + TorchTensorRTModule, + ) + + mod = next( + m for _, m in compiled.named_modules() if isinstance(m, TorchTensorRTModule) + ) + prior = mod.runtime_settings + with runtime_config(compiled, cuda_graph_strategy="whole_graph_capture"): + self.assertEqual( + mod.runtime_settings.cuda_graph_strategy, "whole_graph_capture" + ) + self.assertEqual(mod.runtime_settings, prior) + + +@unittest.skipIf( + not ENABLED_FEATURES.tensorrt_rtx, + "Multi-target tests require TRT-RTX", +) +class TestMultiTargetRuntimeConfig(TestCase): + """`runtime_config([a, b], ...)` applies to engines under both targets.""" + + def test_multi_target_runtime_config(self): + model_a = _compile_simple() + model_b = _compile_simple() + with runtime_config( + [model_a, model_b], cuda_graph_strategy="whole_graph_capture" + ) as (m_a, m_b): + self.assertIs(m_a, model_a) + self.assertIs(m_b, model_b) + from torch_tensorrt.dynamo.runtime._TorchTensorRTModule import ( + TorchTensorRTModule, + ) + + for target in (model_a, model_b): + for _, mod in target.named_modules(): + if isinstance(mod, TorchTensorRTModule): + self.assertEqual( + mod.runtime_settings.cuda_graph_strategy, + "whole_graph_capture", + ) + + +class TestRuntimeConfigInvalidKey(TestCase): + """Typo in a CM key should raise at construction, not silently no-op.""" + + def test_unknown_kwarg_raises(self): + # Use a Module that's not a TorchTensorRTModule -- we just need the + # CM constructor to run; __enter__ won't find any engines. + target = torch.nn.Linear(3, 3) + with self.assertRaises(TypeError): + runtime_config(target, not_a_real_field=True) + + +if __name__ == "__main__": + run_tests() From 415bea73d60143736113c54298217738a6cff89b Mon Sep 17 00:00:00 2001 From: tejaswinp Date: Wed, 3 Jun 2026 11:23:05 -0700 Subject: [PATCH 2/7] fix(runtime): support None and disk-backed runtime_cache for C++ runtime Two follow-up bugs exposed by the cross-runtime test parameterization on the C++ engine path: 1. ``torch.classes.tensorrt.Engine.update_runtime_settings(...)`` rejected Python ``None`` for the ``RuntimeCacheHandle`` argument because TorchBind does not auto-convert ``None`` to a null ``c10::intrusive_ptr``. Switch the signature to ``c10::optional>`` so the default ``runtime_cache=None`` case round-trips cleanly. 2. ``RuntimeSettings(runtime_cache="/some/path")`` only auto-saved to disk on engine destruction for the Python runtime (via ``_TRTEngine.__del__``). The C++ engine had no equivalent saver and the IRuntimeCache it materialized internally wasn't accessible from Python. Make the cpp path symmetric: - Expose ``serialize() -> at::Tensor`` / ``deserialize(at::Tensor)`` / ``has_cache()`` on the torchbind ``RuntimeCacheHandle`` class. ``at::Tensor`` of uint8 is used instead of ``std::string`` because TorchBind forces ``std::string`` through Python ``str`` (UTF-8) and serialized cache bytes are not valid UTF-8. - In ``TorchTensorRTModule.setup_engine`` (cpp branch), pre-materialize a torchbind handle when ``runtime_cache`` is a path string, store it on the module, and substitute it into ``_runtime_settings`` so the dispatch passes the same handle through. - Add ``_load_cpp_implicit_cache`` / ``_save_cpp_implicit_cache`` and a module ``__del__`` that mirrors the Python ``_TRTEngine`` saver, with ``filelock`` + atomic-rename semantics. - Teach ``_to_torchbind_handle`` to pass an already-torchbind ``torch.ScriptObject`` through unchanged. All cpp + python runtime tests pass on TRT-RTX 1.5: test_004 (12/12), test_000 (10/10), test_001 dynamic_shapes (14/14), test_001 cuda_graph (13/13). --- core/runtime/register_jit_hooks.cpp | 58 +++++++++++- .../dynamo/runtime/_TorchTensorRTModule.py | 94 +++++++++++++++++++ py/torch_tensorrt/runtime/_runtime_cache.py | 8 +- 3 files changed, 155 insertions(+), 5 deletions(-) diff --git a/core/runtime/register_jit_hooks.cpp b/core/runtime/register_jit_hooks.cpp index e4294ab754..02bf244998 100644 --- a/core/runtime/register_jit_hooks.cpp +++ b/core/runtime/register_jit_hooks.cpp @@ -23,7 +23,57 @@ static auto TORCHTRT_UNUSED RuntimeCacheHandleRegistration = torch::class_("tensorrt", "RuntimeCacheHandle") .def(torch::init()) .def("path", &RuntimeCacheHandle::path) - .def("set_path", &RuntimeCacheHandle::set_path); + .def("set_path", &RuntimeCacheHandle::set_path) + // Expose the underlying IRuntimeCache bytes to Python so the Python- + // side save/load logic can persist them under filelock. Returns an + // empty uint8 tensor if the cache hasn't been materialized yet. + // + // We return ``at::Tensor`` rather than ``std::string`` because TorchBind + // forces ``std::string`` to round-trip through Python ``str`` (UTF-8) + // and serialized cache bytes are not valid UTF-8. + .def( + "serialize", + [](const c10::intrusive_ptr& self) -> at::Tensor { +#ifdef TRT_MAJOR_RTX + auto opts = at::TensorOptions().dtype(at::kByte); + if (!self->cache) { + return at::empty({0}, opts); + } + auto host_mem = make_trt(self->cache->serialize()); + if (!host_mem) { + return at::empty({0}, opts); + } + auto tensor = at::empty({static_cast(host_mem->size())}, opts); + std::memcpy(tensor.data_ptr(), host_mem->data(), host_mem->size()); + return tensor; +#else + return at::empty({0}, at::TensorOptions().dtype(at::kByte)); +#endif + }) + // Deserialize bytes loaded from disk into the underlying IRuntimeCache. + // Expects a uint8 ``at::Tensor``. No-op for empty input or if the + // IRuntimeCache hasn't been materialized yet. + .def( + "deserialize", + [](const c10::intrusive_ptr& self, at::Tensor data) -> void { +#ifdef TRT_MAJOR_RTX + if (data.numel() == 0 || !self->cache) { + return; + } + auto contig = data.contiguous().to(at::kCPU); + self->cache->deserialize(contig.data_ptr(), static_cast(contig.numel())); +#else + (void)data; +#endif + }) + // True iff an engine has populated the underlying IRuntimeCache. + .def("has_cache", [](const c10::intrusive_ptr& self) -> bool { +#ifdef TRT_MAJOR_RTX + return self->cache != nullptr; +#else + return false; +#endif + }); // TODO: Implement a call method // c10::List TRTEngine::Run(c10::List inputs) { @@ -64,11 +114,13 @@ static auto TORCHTRT_UNUSED TRTEngineTSRegistrtion = [](const c10::intrusive_ptr& self, std::string const& dynamic_shapes_kernel_specialization_strategy, std::string const& cuda_graph_strategy, - c10::intrusive_ptr runtime_cache) -> void { + c10::optional> runtime_cache) -> void { + // `c10::optional` lets TorchBind accept Python `None` here. We + // translate to a (possibly null) intrusive_ptr inside the struct. RuntimeSettings rs; rs.dynamic_shapes_kernel_specialization_strategy = dynamic_shapes_kernel_specialization_strategy; rs.cuda_graph_strategy = cuda_graph_strategy; - rs.runtime_cache = std::move(runtime_cache); + rs.runtime_cache = runtime_cache.has_value() ? std::move(*runtime_cache) : nullptr; self->update_runtime_settings(std::move(rs)); }) .def_readwrite("use_pre_allocated_outputs", &TRTEngine::use_pre_allocated_outputs) diff --git a/py/torch_tensorrt/dynamo/runtime/_TorchTensorRTModule.py b/py/torch_tensorrt/dynamo/runtime/_TorchTensorRTModule.py index 5483f16347..62bafb6bd4 100644 --- a/py/torch_tensorrt/dynamo/runtime/_TorchTensorRTModule.py +++ b/py/torch_tensorrt/dynamo/runtime/_TorchTensorRTModule.py @@ -3,7 +3,9 @@ import base64 import copy import logging +import os import pickle +import shutil from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple import torch @@ -344,8 +346,36 @@ def setup_engine(self) -> None: else: self.engine = torch.classes.tensorrt.Engine(self._pack_engine_info()) self.execute_engine_op = torch.ops.tensorrt.execute_engine + # If the compile-time hint was a path string, pre-materialize a + # torchbind RuntimeCacheHandle here so we (a) own a Python-side + # reference that survives until the module is collected, and (b) + # can save the cache to disk in __del__ (the C++ engine has no + # Python __del__; file I/O lives on the Python side). Substitute + # the handle for the string in the settings so the dispatch + # passes the same handle through to TorchBind. + rc = self._runtime_settings.runtime_cache + if isinstance(rc, str) and rc: + handle = torch.classes.tensorrt.RuntimeCacheHandle(rc) + self._cpp_implicit_cache_handle = handle + self._cpp_implicit_cache_path = rc + self._runtime_settings = self._runtime_settings.merge( + runtime_cache=handle + ) + # Pre-load any existing on-disk cache so the engine sees + # warm contents on first inference. The first engine attach + # materializes the IRuntimeCache via createRuntimeCache(). + self._cpp_implicit_handle_pending_load = True + else: + self._cpp_implicit_cache_handle = None + self._cpp_implicit_cache_path = None + self._cpp_implicit_handle_pending_load = False # Apply runtime settings to the C++ engine (no-op if defaults). self._dispatch_runtime_settings_to_engine(self._runtime_settings) + # After dispatch the IRuntimeCache exists inside the handle; load + # the on-disk bytes (filelocked) so they're picked up on first run. + if self._cpp_implicit_handle_pending_load: + self._load_cpp_implicit_cache() + self._cpp_implicit_handle_pending_load = False # requires_native_multidevice is set by the C++ constructor from the serialized REQUIRES_NATIVE_MULTIDEVICE_IDX field. if self.engine.requires_native_multidevice: @@ -377,6 +407,70 @@ def setup_engine(self) -> None: # code cache and isn't reachable via module tree walking. register_md_engine(self.engine) + def _load_cpp_implicit_cache(self) -> None: + """Deserialize on-disk cache bytes into the torchbind handle. + + Mirrors :py:meth:`RuntimeCacheHandle.load` for the C++ runtime path. + No-op on first run (file absent) or if the IRuntimeCache hasn't been + materialized yet inside the C++ engine. + """ + handle = getattr(self, "_cpp_implicit_cache_handle", None) + path = getattr(self, "_cpp_implicit_cache_path", None) + if handle is None or not path or not handle.has_cache(): + return + if not os.path.exists(path): + return + try: + from filelock import FileLock + + with FileLock(path + ".lock").acquire(timeout=10): + with open(path, "rb") as f: + data = f.read() + if data: + # The torchbind `deserialize` takes a uint8 tensor; we wrap the + # raw bytes via ``frombuffer`` for a zero-copy view. + tensor = torch.frombuffer(bytearray(data), dtype=torch.uint8) + handle.deserialize(tensor) + logger.debug(f"Loaded runtime cache from {path} ({len(data)} bytes)") + except Exception as e: + logger.debug(f"Failed to load runtime cache from {path}: {e}") + + def _save_cpp_implicit_cache(self) -> None: + """Serialize the torchbind handle's IRuntimeCache to disk under filelock. + + Called from __del__. Suppresses all exceptions because __del__ may + run during interpreter shutdown when imports / filesystem ops can + fail in unpredictable ways. + """ + handle = getattr(self, "_cpp_implicit_cache_handle", None) + path = getattr(self, "_cpp_implicit_cache_path", None) + if handle is None or not path: + return + try: + if not handle.has_cache(): + return + tensor = handle.serialize() + if tensor.numel() == 0: + return + data = bytes(tensor.cpu().contiguous().numpy()) + from filelock import FileLock + + parent = os.path.dirname(path) + if parent: + os.makedirs(parent, exist_ok=True) + tmp = path + ".tmp" + with FileLock(path + ".lock").acquire(timeout=10): + with open(tmp, "wb") as f: + f.write(data) + shutil.move(tmp, path) + logger.debug(f"Saved runtime cache to {path} ({len(data)} bytes)") + except Exception: + # Best-effort: never raise out of __del__. + pass + + def __del__(self) -> None: + self._save_cpp_implicit_cache() + def encode_metadata(self, metadata: Any) -> str: metadata = copy.deepcopy(metadata) dumped_metadata = pickle.dumps(metadata) diff --git a/py/torch_tensorrt/runtime/_runtime_cache.py b/py/torch_tensorrt/runtime/_runtime_cache.py index 15a01ed0c0..8e977cd3cc 100644 --- a/py/torch_tensorrt/runtime/_runtime_cache.py +++ b/py/torch_tensorrt/runtime/_runtime_cache.py @@ -231,14 +231,16 @@ def runtime_cache( # at dispatch time the Python module converts to/from the torchbind class as # needed (see ``TorchTensorRTModule.set_runtime_settings``). def _to_torchbind_handle( - rc: Union[None, str, "RuntimeCacheHandle"], + rc: Union[None, str, "RuntimeCacheHandle", Any], ) -> Any: """Convert a Python-side ``runtime_cache`` value to a torchbind handle suitable for ``torch.classes.tensorrt.Engine.update_runtime_settings(...)``. Returns ``None`` if no runtime cache is requested. Raises if the C++ runtime isn't loaded (caller shouldn't dispatch to a C++ engine in that - case anyway). + case anyway). Already-torchbind handles (``torch.ScriptObject``) are passed + through unchanged so callers can pre-stash a handle on the module and + share it across dispatch calls. """ if rc is None: return None @@ -247,5 +249,7 @@ def _to_torchbind_handle( "torch_tensorrt C++ runtime is not available; cannot construct " "torch.classes.tensorrt.RuntimeCacheHandle" ) + if isinstance(rc, torch.ScriptObject): + return rc path = rc if isinstance(rc, str) else rc.path return torch.classes.tensorrt.RuntimeCacheHandle(path) From 0d360de8ecb8fbdbaf4e832bcf7cb68f288531d3 Mon Sep 17 00:00:00 2001 From: tejaswinp Date: Wed, 3 Jun 2026 19:14:26 -0700 Subject: [PATCH 3/7] refactor(runtime): TRTRuntimeConfig owns RuntimeSettings; unified RuntimeCacheHandle lifecycle Structural cleanup on top of the v3 work (no observable behavior change). C++ side -------- ``RuntimeSettings`` migrates from a ``TRTEngine`` member to a ``TRTRuntimeConfig`` member -- the value-type now lives with its primary consumer (the IRuntimeConfig builder). ``TRTRuntimeConfig`` gains ``set_settings()`` (the diff-and-invalidate primitive) and turns the static ``uses_internal_capture`` / ``is_monolithic_capturable`` helpers into instance methods so callers do not need to pass settings around. ``TRTEngine::runtime_settings()`` forwards through. Python side ----------- Introduces a Python ``TRTRuntimeConfig`` class mirroring the C++ struct. ``_TRTEngine`` drops its three legacy fields (``runtime_config``, ``runtime_settings``, ``_implicit_cache_handle``) for a single ``self._trt_runtime_config`` member; ``_create_execution_context`` / ``update_runtime_settings`` / ``_is_monolithic_capturable`` / ``_enable_rtx_native_cudagraphs`` all delegate. Every ``ENABLED_FEATURES.tensorrt_rtx`` branch related to runtime-mode controls is absorbed into the shim, so engine and module call sites stay uniform across TRT and TRT-RTX builds. Following the project's grouping convention, ``py/torch_tensorrt/runtime/_runtime_settings.py`` is merged into ``_runtime_config.py``; that file now holds ``RuntimeSettings``, the new ``TRTRuntimeConfig``, the existing ``runtime_config()`` CM, and its factory. Imports across the tree are repointed. RuntimeCacheHandle ownership model ---------------------------------- Save-on-destruction moves from the two engine-side ``__del__`` paths (``_TRTEngine.close()`` for Python runtime, ``TorchTensorRTModule.__del__`` for cpp runtime) onto ``RuntimeCacheHandle.__del__`` itself, gated by a new ``autosave_on_del`` flag. The flag is set by ownership context: * Engine-implicit handles (created from a path-string compile-time hint) get ``autosave_on_del=True`` -- no other Python object holds them, so the destructor is the only save opportunity. * The ``runtime_cache(target, path)`` CM uses ``autosave_on_del=False`` on the handle it constructs; its ``__exit__`` saves explicitly. * Hand-built handles default to ``autosave_on_del=False`` so save timing stays under the user's control. The handle additionally accepts a ``torchbind_handle`` sibling so the same Python object can wrap either a ``trt.IRuntimeCache`` (Python rt) or a ``torch.classes.tensorrt.RuntimeCacheHandle`` (cpp rt); ``save`` / ``load`` source bytes from whichever is populated. The cpp-runtime helpers on ``TorchTensorRTModule`` (``_load_cpp_implicit_cache``, ``_save_cpp_implicit_cache``, ``__del__``) and the duplicate save logic in ``_TRTEngine.close()`` are removed; both runtimes funnel through the single ``RuntimeCacheHandle.__del__`` path. Tests ----- test_000 grows two new tests asserting the new contract: * ``test_cm_does_not_double_save_on_rc_gc`` -- only one save fires per CM block even after ``rc`` is GC'd. * ``test_user_built_handle_no_autosave_by_default`` -- hand-built handles do not autosave on GC. All 51 runtime tests pass on the refactored design (test_004 12/12, test_000 12/12, test_001 ds 14/14, test_001 cg 13/13). --- core/runtime/TRTEngine.cpp | 18 +- core/runtime/TRTEngine.h | 16 +- core/runtime/TRTRuntimeConfig.cpp | 52 +-- core/runtime/TRTRuntimeConfig.h | 66 ++-- core/runtime/execute_engine.cpp | 2 +- py/torch_tensorrt/dynamo/_compiler.py | 2 +- .../dynamo/conversion/_conversion.py | 2 +- .../dynamo/runtime/_TRTEngine.py | 279 +++++---------- .../dynamo/runtime/_TorchTensorRTModule.py | 113 ++---- py/torch_tensorrt/runtime/__init__.py | 3 +- py/torch_tensorrt/runtime/_runtime_cache.py | 135 +++++--- py/torch_tensorrt/runtime/_runtime_config.py | 323 +++++++++++++++++- .../runtime/_runtime_settings.py | 98 ------ .../dynamo/runtime/test_000_runtime_cache.py | 49 +++ 14 files changed, 644 insertions(+), 514 deletions(-) delete mode 100644 py/torch_tensorrt/runtime/_runtime_settings.py diff --git a/core/runtime/TRTEngine.cpp b/core/runtime/TRTEngine.cpp index 1110a53b8c..bf0fb81897 100644 --- a/core/runtime/TRTEngine.cpp +++ b/core/runtime/TRTEngine.cpp @@ -149,7 +149,7 @@ TRTEngine::TRTEngine( const std::string& serialized_metadata, const ResourceAllocationStrategy resource_allocation_strategy, RuntimeSettings runtime_settings) { - this->runtime_settings_ = std::move(runtime_settings); + this->runtime_cfg = TRTRuntimeConfig(std::move(runtime_settings)); TORCHTRT_CHECK( is_supported_on_current_platform(target_platform), "This engine was not built to run on this platform (built for: " << target_platform << ", current platform: " @@ -464,7 +464,7 @@ std::string TRTEngine::to_str() const { ss << " Target Platform: " << target_platform << std::endl; ss << " Resource Allocation Strategy: " << (resource_allocation_strategy == ResourceAllocationStrategy::kDynamic ? "Dynamic" : "Static") << std::endl; ss << " Multi-Device Engine: " << (requires_native_multidevice) << std::endl; - ss << runtime_settings_.to_str(); + ss << runtime_cfg.settings().to_str(); // clang-format on return ss.str(); } @@ -658,31 +658,27 @@ void TRTEngine::release_nccl_comm() { #endif // ENABLE_TRT_NCCL_COLLECTIVES bool TRTEngine::is_monolithic_capturable(cudaStream_t stream) const { - return TRTRuntimeConfig::is_monolithic_capturable(runtime_settings_, has_dynamic_inputs, exec_ctx.get(), stream); + return runtime_cfg.is_monolithic_capturable(has_dynamic_inputs, exec_ctx.get(), stream); } void TRTEngine::disable_rtx_native_cudagraphs() { #ifdef TRT_MAJOR_RTX - if (runtime_settings_.cuda_graph_strategy == "disabled") { + if (runtime_cfg.settings().cuda_graph_strategy == "disabled") { return; } LOG_WARNING( "Outer CUDA stream capture detected; disabling TensorRT-RTX native CUDA graph strategy on engine " << name << " for the remainder of its lifetime."); - RuntimeSettings new_settings = runtime_settings_; + RuntimeSettings new_settings = runtime_cfg.settings(); new_settings.cuda_graph_strategy = "disabled"; update_runtime_settings(std::move(new_settings)); #endif } void TRTEngine::update_runtime_settings(RuntimeSettings new_settings) { - if (new_settings == runtime_settings_) { + if (!runtime_cfg.set_settings(std::move(new_settings))) { return; } - runtime_settings_ = std::move(new_settings); - // Force the next ensure_initialized to rebuild the IRuntimeConfig with the new - // strategy values + (possibly) the new attached cache handle. - runtime_cfg.reset(); recreate_execution_context(); // Existing recreate sites set runtime_states.context_changed for cudagraph // re-record; do the same here so a settings flip inside an active CM forces @@ -694,7 +690,7 @@ void TRTEngine::recreate_execution_context() { const auto allocation_strategy = resource_allocation_strategy == ResourceAllocationStrategy::kDynamic ? nvinfer1::ExecutionContextAllocationStrategy::kUSER_MANAGED : nvinfer1::ExecutionContextAllocationStrategy::kSTATIC; - exec_ctx = runtime_cfg.create_execution_context(cuda_engine.get(), runtime_settings_, allocation_strategy); + exec_ctx = runtime_cfg.create_execution_context(cuda_engine.get(), allocation_strategy); TORCHTRT_CHECK(exec_ctx.get() != nullptr, "Unable to (re)create TensorRT execution context"); } diff --git a/core/runtime/TRTEngine.h b/core/runtime/TRTEngine.h index d3307c902f..9151bec374 100644 --- a/core/runtime/TRTEngine.h +++ b/core/runtime/TRTEngine.h @@ -279,19 +279,19 @@ struct TRTEngine : torch::CustomClassHolder { void set_resource_allocation_strategy(ResourceAllocationStrategy new_strategy); ResourceAllocationStrategy get_resource_allocation_strategy(); - // Live IRuntimeConfig wrapper. Settings are sourced from `runtime_settings_` at - // every (re)build. On non-RTX or pre-10.11 TRT this is essentially empty. + // Owns the canonical `RuntimeSettings` plus the live `IRuntimeConfig` derived + // from them. The engine forwards `runtime_settings()` and + // `update_runtime_settings()` here -- there is no separate settings field on + // the engine. TRTRuntimeConfig runtime_cfg; - // Current user-facing runtime settings. Initialized from the constructor's - // `runtime_settings` param; mutated by `update_runtime_settings`. - RuntimeSettings runtime_settings_; [[nodiscard]] RuntimeSettings const& runtime_settings() const noexcept { - return runtime_settings_; + return runtime_cfg.settings(); } - // Apply new runtime settings. Fast-paths on equality. On change, rebuilds the - // IRuntimeConfig from the new settings and recreates the execution context. + // Apply new runtime settings. Fast-paths on equality (via + // `TRTRuntimeConfig::set_settings`). On change, rebuilds the + // `IRuntimeConfig` from the new settings and recreates the execution context. void update_runtime_settings(RuntimeSettings new_settings); // Whether the engine has any input binding with a dynamic dimension. Computed diff --git a/core/runtime/TRTRuntimeConfig.cpp b/core/runtime/TRTRuntimeConfig.cpp index 9c402bc9c2..dafdcb2a2b 100644 --- a/core/runtime/TRTRuntimeConfig.cpp +++ b/core/runtime/TRTRuntimeConfig.cpp @@ -2,6 +2,7 @@ #include #include +#include #include "core/runtime/RuntimeSettings.h" #include "core/util/prelude.h" @@ -40,9 +41,18 @@ namespace { } // namespace -void TRTRuntimeConfig::ensure_initialized( - TORCHTRT_UNUSED nvinfer1::ICudaEngine* cuda_engine, - TORCHTRT_UNUSED RuntimeSettings const& rs) { +bool TRTRuntimeConfig::set_settings(RuntimeSettings new_settings) { + if (new_settings == settings_) { + return false; + } + settings_ = std::move(new_settings); + // Invalidate the live IRuntimeConfig so the next `ensure_initialized` rebuilds + // with the new strategy values + cache attachment. + reset(); + return true; +} + +void TRTRuntimeConfig::ensure_initialized(TORCHTRT_UNUSED nvinfer1::ICudaEngine* cuda_engine) { #ifdef TRT_HAS_IRUNTIME_CONFIG if (!config) { TORCHTRT_CHECK(cuda_engine != nullptr, "Cannot initialize TRTRuntimeConfig without a live ICudaEngine"); @@ -55,13 +65,14 @@ void TRTRuntimeConfig::ensure_initialized( // RuntimeCacheHandle. The Python TRTEngine side creates an implicit // handle from a path string and passes it in via the handle; without // an explicit user opt-in we leave the IRuntimeConfig cache-less. - if (rs.runtime_cache) { - if (!rs.runtime_cache->cache) { - rs.runtime_cache->cache = make_trt(config->createRuntimeCache()); + if (settings_.runtime_cache) { + if (!settings_.runtime_cache->cache) { + settings_.runtime_cache->cache = make_trt(config->createRuntimeCache()); TORCHTRT_CHECK( - rs.runtime_cache->cache.get() != nullptr, "Failed to create IRuntimeCache for shared RuntimeCacheHandle"); + settings_.runtime_cache->cache.get() != nullptr, + "Failed to create IRuntimeCache for shared RuntimeCacheHandle"); } - if (config->setRuntimeCache(*rs.runtime_cache->cache)) { + if (config->setRuntimeCache(*settings_.runtime_cache->cache)) { LOG_DEBUG("Attached external IRuntimeCache to IRuntimeConfig."); } else { LOG_WARNING("Failed to attach IRuntimeCache to IRuntimeConfig; cache will be unused."); @@ -71,11 +82,12 @@ void TRTRuntimeConfig::ensure_initialized( } config->setDynamicShapesKernelSpecializationStrategy( - to_trt_ds_strategy(rs.dynamic_shapes_kernel_specialization_strategy)); + to_trt_ds_strategy(settings_.dynamic_shapes_kernel_specialization_strategy)); LOG_DEBUG( - "Dynamic shapes kernel specialization strategy set to " << rs.dynamic_shapes_kernel_specialization_strategy); + "Dynamic shapes kernel specialization strategy set to " + << settings_.dynamic_shapes_kernel_specialization_strategy); - if (!config->setCudaGraphStrategy(to_trt_cg_strategy(rs.cuda_graph_strategy))) { + if (!config->setCudaGraphStrategy(to_trt_cg_strategy(settings_.cuda_graph_strategy))) { LOG_WARNING("Failed to set CUDA graph strategy; continuing with default."); } #endif @@ -90,9 +102,8 @@ void TRTRuntimeConfig::reset() { std::shared_ptr TRTRuntimeConfig::create_execution_context( nvinfer1::ICudaEngine* cuda_engine, - RuntimeSettings const& rs, nvinfer1::ExecutionContextAllocationStrategy allocation_strategy) { - ensure_initialized(cuda_engine, rs); + ensure_initialized(cuda_engine); #ifdef TRT_HAS_IRUNTIME_CONFIG config->setExecutionContextAllocationStrategy(allocation_strategy); return make_trt(cuda_engine->createExecutionContext(config.get())); @@ -102,24 +113,21 @@ std::shared_ptr TRTRuntimeConfig::create_execution_ #endif } -bool TRTRuntimeConfig::uses_internal_capture( - TORCHTRT_UNUSED RuntimeSettings const& rs, - TORCHTRT_UNUSED bool cudagraphs_enabled) noexcept { +bool TRTRuntimeConfig::uses_internal_capture(TORCHTRT_UNUSED bool cudagraphs_enabled) const noexcept { #ifdef TRT_MAJOR_RTX // On TRT-RTX the internal runtime handles capture/replay whenever a non-disabled // strategy is set, or when subgraph cudagraphs are enabled globally. In both // cases the caller should skip its manual at::cuda::CUDAGraph wrapper. - return rs.cuda_graph_strategy != "disabled" || cudagraphs_enabled; + return settings_.cuda_graph_strategy != "disabled" || cudagraphs_enabled; #else return false; #endif } bool TRTRuntimeConfig::is_monolithic_capturable( - TORCHTRT_UNUSED RuntimeSettings const& rs, TORCHTRT_UNUSED bool has_dynamic_inputs, TORCHTRT_UNUSED nvinfer1::IExecutionContext* exec_ctx, - TORCHTRT_UNUSED cudaStream_t stream) noexcept { + TORCHTRT_UNUSED cudaStream_t stream) const noexcept { #ifdef TRT_MAJOR_RTX TORCHTRT_ASSERT(exec_ctx != nullptr, "is_monolithic_capturable requires a live IExecutionContext"); if (!exec_ctx->isStreamCapturable(stream)) { @@ -128,16 +136,16 @@ bool TRTRuntimeConfig::is_monolithic_capturable( // "lazy" kernel specialization only swaps specialized kernels mid-run when an // input has a dynamic dimension; for static-shape engines the kernels are fixed // at setup and the captured graph stays valid. Mirrors the Python check. - return !(rs.dynamic_shapes_kernel_specialization_strategy == "lazy" && has_dynamic_inputs); + return !(settings_.dynamic_shapes_kernel_specialization_strategy == "lazy" && has_dynamic_inputs); #else return true; #endif } std::ostream& operator<<(std::ostream& os, const TRTRuntimeConfig& cfg) { - os << "TRTRuntimeConfig{"; + os << "TRTRuntimeConfig{settings=" << cfg.settings().to_str(); #ifdef TRT_HAS_IRUNTIME_CONFIG - os << "config=" << (cfg.config ? "live" : "null"); + os << ", config=" << (cfg.config ? "live" : "null"); #endif os << "}"; return os; diff --git a/core/runtime/TRTRuntimeConfig.h b/core/runtime/TRTRuntimeConfig.h index 94e9b3b5ac..5c9c3eaa44 100644 --- a/core/runtime/TRTRuntimeConfig.h +++ b/core/runtime/TRTRuntimeConfig.h @@ -4,35 +4,43 @@ #include #include #include +#include #include "NvInfer.h" +#include "core/runtime/RuntimeSettings.h" namespace torch_tensorrt { namespace core { namespace runtime { -struct RuntimeSettings; - -// Owns the live `IRuntimeConfig` (where supported) and the engine-local fallback -// `IRuntimeCache` used when no external `RuntimeCacheHandle` is attached via -// `RuntimeSettings`. The settings themselves (strategy strings, runtime_cache -// handle) live on `RuntimeSettings`; this struct applies them to TRT at -// `ensure_initialized` time. +// Owns the canonical `RuntimeSettings` for an engine, the live `IRuntimeConfig` +// derived from those settings (where supported), and translates strategy +// strings into TRT calls. All `TRT_HAS_IRUNTIME_CONFIG` / `TRT_MAJOR_RTX` +// branching is confined to this TU. // -// `IRuntimeConfig` and runtime-cache `#ifdef`s are confined to this TU. +// `TRTEngine` holds a `TRTRuntimeConfig` member; the engine itself does not +// store a separate `RuntimeSettings`. `engine.runtime_settings()` forwards +// here. struct TRTRuntimeConfig { - // Lazy-constructed live config. `nullptr` until first `ensure_initialized`. -#ifdef TRT_HAS_IRUNTIME_CONFIG - std::shared_ptr config; -#endif + TRTRuntimeConfig() = default; + explicit TRTRuntimeConfig(RuntimeSettings settings) : settings_(std::move(settings)) {} + + // Canonical user-facing runtime settings for this engine. Mutated only via + // `set_settings` so the live `IRuntimeConfig` stays in sync. + [[nodiscard]] RuntimeSettings const& settings() const noexcept { + return settings_; + } + + // Returns true iff `new_settings` differs from the current settings (i.e. + // the caller should recreate the `IExecutionContext`). On change the live + // `IRuntimeConfig` is invalidated; the next `ensure_initialized` rebuilds. + bool set_settings(RuntimeSettings new_settings); - // (Re)build the `IRuntimeConfig` from `rs`. Idempotent only if the previous - // `rs` was identical. Callers ensure the engine is the same across calls -- - // we don't memoize against `cuda_engine` here. - void ensure_initialized(nvinfer1::ICudaEngine* cuda_engine, RuntimeSettings const& rs); + // (Re)build the `IRuntimeConfig` from `settings_`. Idempotent if the previous + // build was against identical settings. + void ensure_initialized(nvinfer1::ICudaEngine* cuda_engine); - // Force the next `ensure_initialized` to rebuild from scratch. Used when - // settings change at runtime. + // Force the next `ensure_initialized` to rebuild from scratch. void reset(); // Lazy-init + create a fresh `IExecutionContext` honoring `allocation_strategy`. @@ -41,21 +49,27 @@ struct TRTRuntimeConfig { // `TRT_HAS_IRUNTIME_CONFIG` branching. [[nodiscard]] std::shared_ptr create_execution_context( nvinfer1::ICudaEngine* cuda_engine, - RuntimeSettings const& rs, nvinfer1::ExecutionContextAllocationStrategy allocation_strategy); - // Returns true if TRT-RTX owns capture/replay for the given settings -- caller - // should then bypass its own `at::cuda::CUDAGraph` capture around enqueueV3. - // Always false on non-RTX builds. - [[nodiscard]] static bool uses_internal_capture(RuntimeSettings const& rs, bool cudagraphs_enabled) noexcept; + // Returns true if TRT-RTX owns capture/replay for the current settings -- + // caller should then bypass its own `at::cuda::CUDAGraph` capture around + // enqueueV3. Always false on non-RTX builds. + [[nodiscard]] bool uses_internal_capture(bool cudagraphs_enabled) const noexcept; // Returns true iff the execution context can be safely included in an outer // monolithic capture. Non-RTX builds always return true. - [[nodiscard]] static bool is_monolithic_capturable( - RuntimeSettings const& rs, + [[nodiscard]] bool is_monolithic_capturable( bool has_dynamic_inputs, nvinfer1::IExecutionContext* exec_ctx, - cudaStream_t stream) noexcept; + cudaStream_t stream) const noexcept; + +#ifdef TRT_HAS_IRUNTIME_CONFIG + // Lazy-constructed live config. `nullptr` until first `ensure_initialized`. + std::shared_ptr config; +#endif + + private: + RuntimeSettings settings_; }; std::ostream& operator<<(std::ostream& os, const TRTRuntimeConfig& cfg); diff --git a/core/runtime/execute_engine.cpp b/core/runtime/execute_engine.cpp index a773b1afa3..80936951ef 100644 --- a/core/runtime/execute_engine.cpp +++ b/core/runtime/execute_engine.cpp @@ -249,7 +249,7 @@ std::vector execute_engine(std::vector inputs, c10::intr // CudaGraphsTorchTensorRTModule for whole-graph capture), engine-internal capture would // collide, so we disable it one-shot here. bool effective_cudagraphs = cudagraphs_enabled; - if (TRTRuntimeConfig::uses_internal_capture(compiled_engine->runtime_settings(), cudagraphs_enabled)) { + if (compiled_engine->runtime_cfg.uses_internal_capture(cudagraphs_enabled)) { effective_cudagraphs = false; cudaStreamCaptureStatus capture_status; cudaStreamIsCapturing(compiled_engine->engine_stream.stream(), &capture_status); diff --git a/py/torch_tensorrt/dynamo/_compiler.py b/py/torch_tensorrt/dynamo/_compiler.py index 17e0154c68..9b90aeffe0 100644 --- a/py/torch_tensorrt/dynamo/_compiler.py +++ b/py/torch_tensorrt/dynamo/_compiler.py @@ -10,7 +10,7 @@ import torch if TYPE_CHECKING: - from torch_tensorrt.runtime._runtime_settings import RuntimeSettings + from torch_tensorrt.runtime._runtime_config import RuntimeSettings from torch.export import ExportedProgram from torch.fx.node import Target from torch_tensorrt._Device import Device diff --git a/py/torch_tensorrt/dynamo/conversion/_conversion.py b/py/torch_tensorrt/dynamo/conversion/_conversion.py index 9ed9b5ef2e..c6f257d0c9 100644 --- a/py/torch_tensorrt/dynamo/conversion/_conversion.py +++ b/py/torch_tensorrt/dynamo/conversion/_conversion.py @@ -8,7 +8,7 @@ import torch if TYPE_CHECKING: - from torch_tensorrt.runtime._runtime_settings import RuntimeSettings + from torch_tensorrt.runtime._runtime_config import RuntimeSettings from torch_tensorrt._enums import dtype from torch_tensorrt._features import ENABLED_FEATURES from torch_tensorrt._Input import Input diff --git a/py/torch_tensorrt/dynamo/runtime/_TRTEngine.py b/py/torch_tensorrt/dynamo/runtime/_TRTEngine.py index 0396c9cc8b..3f0571ce22 100644 --- a/py/torch_tensorrt/dynamo/runtime/_TRTEngine.py +++ b/py/torch_tensorrt/dynamo/runtime/_TRTEngine.py @@ -32,7 +32,10 @@ import torch_tensorrt if TYPE_CHECKING: - from torch_tensorrt.runtime._runtime_settings import RuntimeSettings + from torch_tensorrt.runtime._runtime_config import ( + RuntimeSettings, + TRTRuntimeConfig, + ) from torch._library.opaque_object import register_opaque_type from torch._opaque_base import OpaqueBase from torch_tensorrt._enums import dtype @@ -71,44 +74,6 @@ logger = logging.getLogger(__name__) -def _get_dynamic_shapes_kernel_strategy(strategy_str: str) -> Any: - """Map strategy string to TRT enum. Only meaningful on TensorRT-RTX builds.""" - return { - "lazy": trt.DynamicShapesKernelSpecializationStrategy.LAZY, - "eager": trt.DynamicShapesKernelSpecializationStrategy.EAGER, - "none": trt.DynamicShapesKernelSpecializationStrategy.NONE, - }.get(strategy_str, trt.DynamicShapesKernelSpecializationStrategy.LAZY) - - -def _get_cuda_graph_strategy(strategy_str: str) -> Any: - """Map strategy string to TRT CudaGraphStrategy enum. Only meaningful on RTX.""" - return { - "disabled": trt.CudaGraphStrategy.DISABLED, - "whole_graph_capture": trt.CudaGraphStrategy.WHOLE_GRAPH_CAPTURE, - }.get(strategy_str, trt.CudaGraphStrategy.DISABLED) - - -def _normalize_runtime_cache( - rc: Any, -) -> Any: - """Accept ``None``, a path string, or a ``RuntimeCacheHandle``; return either - ``None`` or a ``RuntimeCacheHandle`` instance. - - String inputs are wrapped in a fresh per-engine implicit handle. The handle - is owned by the engine (saved on engine ``__del__``). - """ - from torch_tensorrt.runtime._runtime_cache import RuntimeCacheHandle - - if rc is None or isinstance(rc, RuntimeCacheHandle): - return rc - if isinstance(rc, str): - return RuntimeCacheHandle(path=rc, autosave=True) - raise TypeError( - f"RuntimeSettings.runtime_cache must be None, a path string, or a " - f"RuntimeCacheHandle; got {type(rc).__name__}" - ) - - # --------------------------------------------------------------------------- # TRT I/O helpers # --------------------------------------------------------------------------- @@ -254,7 +219,10 @@ def __init__( runtime_settings: Optional["RuntimeSettings"] = None, ) -> None: # Import here to avoid a circular dep at module-import time. - from torch_tensorrt.runtime._runtime_settings import RuntimeSettings + from torch_tensorrt.runtime._runtime_config import ( + RuntimeSettings, + TRTRuntimeConfig, + ) self._profile_execution = profile_execution self.profile_path_prefix = tempfile.gettempdir() @@ -276,18 +244,6 @@ def __init__( torch_tensorrt.runtime.get_cudagraphs_mode() ) self.resource_allocation_strategy = 0 - # Initialized to ``None`` here so the destructor can run even if - # ``_setup_engine`` never executed. - self.runtime_config: Any = None - # Per-engine implicit cache handle, owned by this engine when - # ``runtime_settings.runtime_cache`` is supplied as a string path. - # ``None`` when ``runtime_settings.runtime_cache`` is an external - # handle (caller owns the lifecycle). - self._implicit_cache_handle: Any = None - # Engine-local IRuntimeCache used when no external handle is attached. - # Held as an instance attr so its lifetime matches the runtime_config it - # was set on -- TRT's set_runtime_cache borrows, doesn't own. - self._engine_local_runtime_cache: Any = None # When true, ``_execute_standard`` must skip manual torch.cuda.CUDAGraph # capture because TRT-RTX handles it internally. self._rtx_native_cudagraphs: bool = False @@ -295,12 +251,36 @@ def __init__( # engines compiled with native multi-device collective layers. self._nccl_comm: Optional[Any] = None - # User-facing runtime settings. Mutated by ``update_runtime_settings``. - self.runtime_settings: RuntimeSettings = runtime_settings or RuntimeSettings() + # Owns RuntimeSettings + the live trt.IRuntimeConfig + the + # engine-implicit RuntimeCacheHandle. Hides all RTX feature gates. + self._trt_runtime_config: TRTRuntimeConfig = TRTRuntimeConfig( + runtime_settings or RuntimeSettings() + ) self._load_serialized_info(serialized_info) self._setup_engine() + # --- public property forwards --- + + @property + def runtime_settings(self) -> "RuntimeSettings": + """The current ``RuntimeSettings`` for this engine. + + Backed by ``self._trt_runtime_config.settings``; mutations go through + :meth:`update_runtime_settings`. + """ + return self._trt_runtime_config.settings + + @property + def runtime_config(self) -> Any: + """The live ``trt.IRuntimeConfig`` (or ``None`` on non-RTX builds).""" + return self._trt_runtime_config._live + + @property + def _implicit_cache_handle(self) -> Any: + """The engine-implicit ``RuntimeCacheHandle`` if a path-string compile-time hint was given.""" + return self._trt_runtime_config.implicit_cache_handle + def __del__(self) -> None: self.close() @@ -333,7 +313,10 @@ def __getstate__(self) -> Tuple[List[Any], str]: def __setstate__(self, state: Any) -> None: """Restore from C++-matching pickle state ``(serialized_info,)``.""" - from torch_tensorrt.runtime._runtime_settings import RuntimeSettings + from torch_tensorrt.runtime._runtime_config import ( + RuntimeSettings, + TRTRuntimeConfig, + ) self._profile_execution = False self.profile_path_prefix = tempfile.gettempdir() @@ -355,11 +338,6 @@ def __setstate__(self, state: Any) -> None: torch_tensorrt.runtime.get_cudagraphs_mode() ) self.resource_allocation_strategy = 0 - # See ``__init__`` for the rationale: pre-init these so a destructor - # firing on a partially-loaded engine never trips an ``AttributeError``. - self.runtime_config = None - self._implicit_cache_handle = None - self._engine_local_runtime_cache = None self._rtx_native_cudagraphs = False # NCCL communicators cannot be pickled; rebind lazily on the next # forward pass via setup_nccl_comm(). @@ -367,7 +345,7 @@ def __setstate__(self, state: Any) -> None: # RuntimeSettings are NOT serialized -- restore defaults. Callers # who want runtime-mode overrides must reapply them post-load via # ``compiled.set_runtime_settings(...)`` or a runtime CM. - self.runtime_settings = RuntimeSettings() + self._trt_runtime_config = TRTRuntimeConfig(RuntimeSettings()) serialized_info = list(state[0]) engine_field = serialized_info[ENGINE_IDX] @@ -456,31 +434,24 @@ def get_serialized_metadata(self) -> str: return self.serialized_metadata def close(self) -> None: - """Persist any implicit runtime cache and release CUDA graph resources. + """Release CUDA graph resources. - Implicit handles (created by the engine from a string path in - ``runtime_settings.runtime_cache``) save here. External handles - from a ``runtime_cache`` CM save on the CM's ``__exit__`` instead. + Implicit runtime cache persistence is now driven by the + :class:`~torch_tensorrt.runtime._runtime_cache.RuntimeCacheHandle`'s + own ``__del__`` (with ``autosave_on_del=True``), so no explicit save + is needed here. """ - handle = self._implicit_cache_handle - if handle is not None: - try: - handle.save() - except Exception as e: # never raise from __del__ - logger.warning(f"Failed to save implicit runtime cache: {e}") self.reset_captured_graph() def _create_execution_context(self) -> trt.IExecutionContext: - if ENABLED_FEATURES.tensorrt_rtx: - assert self.runtime_config is not None - context = self.cuda_engine.create_execution_context(self.runtime_config) - else: - strategy = ( - trt.ExecutionContextAllocationStrategy.USER_MANAGED - if self.resource_allocation_strategy - else trt.ExecutionContextAllocationStrategy.STATIC - ) - context = self.cuda_engine.create_execution_context(strategy) + alloc_strategy = ( + trt.ExecutionContextAllocationStrategy.USER_MANAGED + if self.resource_allocation_strategy + else trt.ExecutionContextAllocationStrategy.STATIC + ) + context = self._trt_runtime_config.create_execution_context( + self.cuda_engine, alloc_strategy + ) assert context is not None, "Failed to create execution context" return context @@ -495,16 +466,14 @@ def _setup_engine(self) -> None: logger.debug(f"Weight streaming budget set to {budget_bytes}B") self.cuda_engine.weight_streaming_budget_v2 = budget_bytes - # On TensorRT-RTX, build the IRuntimeConfig (runtime cache, - # dynamic-shape kernel specialization strategy, and CUDA graph - # strategy) up front so the one-and-only execution context picks it up. - if ENABLED_FEATURES.tensorrt_rtx: - self._setup_runtime_config() - self._rtx_native_cudagraphs = ( - self.runtime_settings.cuda_graph_strategy != "disabled" - ) - + # The TRTRuntimeConfig shim builds the live IRuntimeConfig (with cache + # + strategies) inside ``create_execution_context`` on RTX; no-op + # otherwise. Track the cudagraph-disabled-or-not state for the + # ``_execute_standard`` path to consult. self.context = self._create_execution_context() + self._rtx_native_cudagraphs = ENABLED_FEATURES.tensorrt_rtx and ( + self.runtime_settings.cuda_graph_strategy != "disabled" + ) if self._has_nccl_ops: from torch_tensorrt.distributed._nccl_utils import ( @@ -568,137 +537,43 @@ def _setup_engine(self) -> None: if self.requires_output_allocator: self.create_output_allocator() - # --- TensorRT-RTX --- - - def _setup_runtime_config(self) -> None: - """Build an ``IRuntimeConfig`` sourced from ``self.runtime_settings``. - - The runtime cache field on RuntimeSettings can be ``None`` (per-engine - in-memory cache), a string path (engine creates an implicit handle and - saves on ``__del__``), or a ``RuntimeCacheHandle`` (external; caller - owns lifecycle). - """ - self.runtime_config = self.cuda_engine.create_runtime_config() - alloc_strategy = ( - trt.ExecutionContextAllocationStrategy.USER_MANAGED - if self.resource_allocation_strategy - else trt.ExecutionContextAllocationStrategy.STATIC - ) - self.runtime_config.set_execution_context_allocation_strategy(alloc_strategy) - self.runtime_config.dynamic_shapes_kernel_specialization_strategy = ( - _get_dynamic_shapes_kernel_strategy( - self.runtime_settings.dynamic_shapes_kernel_specialization_strategy - ) - ) - logger.info( - "Dynamic shapes kernel specialization strategy: " - f"{self.runtime_settings.dynamic_shapes_kernel_specialization_strategy}" - ) - self.runtime_config.cuda_graph_strategy = _get_cuda_graph_strategy( - self.runtime_settings.cuda_graph_strategy - ) - logger.info(f"CUDA graph strategy: {self.runtime_settings.cuda_graph_strategy}") - - # Resolve the runtime cache. We only attach a cache to the runtime_config - # when the user explicitly opts in: passing a path string (engine creates - # an implicit handle, saves on ``__del__``) or a ``RuntimeCacheHandle`` - # (external, caller-managed). Default ``None`` leaves the runtime_config - # cache-less, matching pre-refactor behavior. - rc = self.runtime_settings.runtime_cache - if rc is None: - self._implicit_cache_handle = None - self._engine_local_runtime_cache = None - logger.debug( - "Runtime cache disabled (no RuntimeCacheHandle / path provided)." - ) - elif isinstance(rc, str): - # Per-engine disk-backed cache; engine owns the handle and saves - # on ``__del__`` (matches today's ``runtime_cache_path=`` semantics). - # We MUST keep a Python ref to the cache (TRT's ``set_runtime_cache`` - # only borrows) -- the handle holds it. - from torch_tensorrt.runtime._runtime_cache import RuntimeCacheHandle - - cache = self.runtime_config.create_runtime_cache() - self._implicit_cache_handle = RuntimeCacheHandle( - cache=cache, path=rc, autosave=True - ) - self._engine_local_runtime_cache = None - try: - self._implicit_cache_handle.load() - except Exception as e: - logger.warning(f"Failed to load implicit runtime cache: {e}") - self.runtime_config.set_runtime_cache(cache) - else: - # External handle. Lifecycle owned by caller; the handle holds the ref. - cache = rc.ensure_cache(self.runtime_config) - self._implicit_cache_handle = None - self._engine_local_runtime_cache = None - self.runtime_config.set_runtime_cache(cache) - logger.info("TensorRT-RTX runtime config configured") + # --- TensorRT-RTX runtime-config delegation --- def update_runtime_settings(self, new_settings: "RuntimeSettings") -> None: """Apply new ``RuntimeSettings`` to this engine. - No-op fast-path when ``new_settings`` is field-equal to the current - settings. Otherwise: persist any prior implicit-cache contents, - rebuild ``runtime_config`` from the new settings, and recreate the - execution context so the new strategy values take effect on the next - enqueue. + Fast-paths on equality via ``TRTRuntimeConfig.set_settings``. On + change, the prior implicit cache (if any) is saved, the live + ``IRuntimeConfig`` is invalidated, and a fresh ``IExecutionContext`` + is created. """ - if new_settings == self.runtime_settings: + if not self._trt_runtime_config.set_settings(new_settings): return - # Persist the prior implicit cache before swapping handles; otherwise - # the str-path lifecycle would silently drop kernels JIT-compiled so - # far when the user moves to a different cache configuration. - prior_handle = self._implicit_cache_handle - if prior_handle is not None: - try: - prior_handle.save() - except Exception as e: - logger.warning(f"Failed to save implicit runtime cache on swap: {e}") - self.runtime_settings = new_settings - if ENABLED_FEATURES.tensorrt_rtx: - self._setup_runtime_config() - self._rtx_native_cudagraphs = ( - self.runtime_settings.cuda_graph_strategy != "disabled" - ) self.context = self._create_execution_context() + self._rtx_native_cudagraphs = ENABLED_FEATURES.tensorrt_rtx and ( + self.runtime_settings.cuda_graph_strategy != "disabled" + ) self.runtime_states.context_changed = True def _is_monolithic_capturable(self, stream: torch.cuda.Stream) -> bool: - """Return True iff manual ``torch.cuda.CUDAGraph`` capture is safe. - - On RTX, unsafe when the TRT-RTX context is not stream-capturable, or - when ``"lazy"`` kernel specialization can still fire (dynamic inputs). - """ - if not ENABLED_FEATURES.tensorrt_rtx: - return True + """Return True iff manual ``torch.cuda.CUDAGraph`` capture is safe.""" has_dynamic_input = any(DYNAMIC_DIM in shape for shape in self.input_shapes) - not_capturable = ( - not self.context.is_stream_capturable(stream.cuda_stream), - ( - self.runtime_settings.dynamic_shapes_kernel_specialization_strategy - == "lazy" - and has_dynamic_input - ), + return self._trt_runtime_config.is_monolithic_capturable( + has_dynamic_input, self.context, stream ) - return not any(not_capturable) def _enable_rtx_native_cudagraphs(self) -> None: """Switch this engine to TRT-RTX native CUDA graphs. - Sets the runtime config's ``cuda_graph_strategy`` to - ``WHOLE_GRAPH_CAPTURE`` and rebuilds the execution context so it - picks up the new strategy. No-op on non-RTX or when the runtime - config is not present. + Mutates settings via ``update_runtime_settings`` so the prior cache is + saved + a fresh context is created uniformly. No-op on non-RTX builds. """ - if self.runtime_config is None: + if not ENABLED_FEATURES.tensorrt_rtx: return - self.runtime_config.cuda_graph_strategy = _get_cuda_graph_strategy( - "whole_graph_capture" + new_settings = self.runtime_settings.merge( + cuda_graph_strategy="whole_graph_capture" ) - self.context = self._create_execution_context() - self._rtx_native_cudagraphs = True + self.update_runtime_settings(new_settings) logger.info("Switched to TRT-RTX native CUDA graphs") # --- distributed / NCCL --- diff --git a/py/torch_tensorrt/dynamo/runtime/_TorchTensorRTModule.py b/py/torch_tensorrt/dynamo/runtime/_TorchTensorRTModule.py index 62bafb6bd4..6512d896ea 100644 --- a/py/torch_tensorrt/dynamo/runtime/_TorchTensorRTModule.py +++ b/py/torch_tensorrt/dynamo/runtime/_TorchTensorRTModule.py @@ -3,9 +3,7 @@ import base64 import copy import logging -import os import pickle -import shutil from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple import torch @@ -34,7 +32,7 @@ ) if TYPE_CHECKING: - from torch_tensorrt.runtime._runtime_settings import RuntimeSettings + from torch_tensorrt.runtime._runtime_config import RuntimeSettings logger = logging.getLogger(__name__) @@ -141,7 +139,7 @@ def __init__( # Per-engine runtime mode controls. Defaults to ``RuntimeSettings()`` if # not supplied; the dataclass validates at ``__post_init__``. - from torch_tensorrt.runtime._runtime_settings import RuntimeSettings + from torch_tensorrt.runtime._runtime_config import RuntimeSettings self._runtime_settings: RuntimeSettings = runtime_settings or RuntimeSettings() self.symbolic_shape_expressions = symbolic_shape_expressions @@ -346,36 +344,33 @@ def setup_engine(self) -> None: else: self.engine = torch.classes.tensorrt.Engine(self._pack_engine_info()) self.execute_engine_op = torch.ops.tensorrt.execute_engine - # If the compile-time hint was a path string, pre-materialize a - # torchbind RuntimeCacheHandle here so we (a) own a Python-side - # reference that survives until the module is collected, and (b) - # can save the cache to disk in __del__ (the C++ engine has no - # Python __del__; file I/O lives on the Python side). Substitute - # the handle for the string in the settings so the dispatch - # passes the same handle through to TorchBind. + # If the compile-time hint was a path string, materialize a + # torchbind RuntimeCacheHandle here and wrap it in a Python + # ``RuntimeCacheHandle`` so the handle's own ``__del__`` saves + # the on-disk cache when the module is collected. (The C++ engine + # has no Python ``__del__``; file I/O lives on the Python side.) + # Substitute the torchbind handle for the string in the settings + # so dispatch passes the same handle through to TorchBind. rc = self._runtime_settings.runtime_cache if isinstance(rc, str) and rc: - handle = torch.classes.tensorrt.RuntimeCacheHandle(rc) - self._cpp_implicit_cache_handle = handle - self._cpp_implicit_cache_path = rc - self._runtime_settings = self._runtime_settings.merge( - runtime_cache=handle + from torch_tensorrt.runtime._runtime_cache import RuntimeCacheHandle + + tb = torch.classes.tensorrt.RuntimeCacheHandle(rc) + self._implicit_cache_handle: Any = RuntimeCacheHandle( + path=rc, autosave_on_del=True, torchbind_handle=tb ) - # Pre-load any existing on-disk cache so the engine sees - # warm contents on first inference. The first engine attach - # materializes the IRuntimeCache via createRuntimeCache(). - self._cpp_implicit_handle_pending_load = True + self._runtime_settings = self._runtime_settings.merge(runtime_cache=tb) else: - self._cpp_implicit_cache_handle = None - self._cpp_implicit_cache_path = None - self._cpp_implicit_handle_pending_load = False + self._implicit_cache_handle = None # Apply runtime settings to the C++ engine (no-op if defaults). self._dispatch_runtime_settings_to_engine(self._runtime_settings) # After dispatch the IRuntimeCache exists inside the handle; load # the on-disk bytes (filelocked) so they're picked up on first run. - if self._cpp_implicit_handle_pending_load: - self._load_cpp_implicit_cache() - self._cpp_implicit_handle_pending_load = False + if self._implicit_cache_handle is not None: + try: + self._implicit_cache_handle.load() + except Exception as e: + logger.debug(f"Failed to load implicit runtime cache: {e}") # requires_native_multidevice is set by the C++ constructor from the serialized REQUIRES_NATIVE_MULTIDEVICE_IDX field. if self.engine.requires_native_multidevice: @@ -407,70 +402,6 @@ def setup_engine(self) -> None: # code cache and isn't reachable via module tree walking. register_md_engine(self.engine) - def _load_cpp_implicit_cache(self) -> None: - """Deserialize on-disk cache bytes into the torchbind handle. - - Mirrors :py:meth:`RuntimeCacheHandle.load` for the C++ runtime path. - No-op on first run (file absent) or if the IRuntimeCache hasn't been - materialized yet inside the C++ engine. - """ - handle = getattr(self, "_cpp_implicit_cache_handle", None) - path = getattr(self, "_cpp_implicit_cache_path", None) - if handle is None or not path or not handle.has_cache(): - return - if not os.path.exists(path): - return - try: - from filelock import FileLock - - with FileLock(path + ".lock").acquire(timeout=10): - with open(path, "rb") as f: - data = f.read() - if data: - # The torchbind `deserialize` takes a uint8 tensor; we wrap the - # raw bytes via ``frombuffer`` for a zero-copy view. - tensor = torch.frombuffer(bytearray(data), dtype=torch.uint8) - handle.deserialize(tensor) - logger.debug(f"Loaded runtime cache from {path} ({len(data)} bytes)") - except Exception as e: - logger.debug(f"Failed to load runtime cache from {path}: {e}") - - def _save_cpp_implicit_cache(self) -> None: - """Serialize the torchbind handle's IRuntimeCache to disk under filelock. - - Called from __del__. Suppresses all exceptions because __del__ may - run during interpreter shutdown when imports / filesystem ops can - fail in unpredictable ways. - """ - handle = getattr(self, "_cpp_implicit_cache_handle", None) - path = getattr(self, "_cpp_implicit_cache_path", None) - if handle is None or not path: - return - try: - if not handle.has_cache(): - return - tensor = handle.serialize() - if tensor.numel() == 0: - return - data = bytes(tensor.cpu().contiguous().numpy()) - from filelock import FileLock - - parent = os.path.dirname(path) - if parent: - os.makedirs(parent, exist_ok=True) - tmp = path + ".tmp" - with FileLock(path + ".lock").acquire(timeout=10): - with open(tmp, "wb") as f: - f.write(data) - shutil.move(tmp, path) - logger.debug(f"Saved runtime cache to {path} ({len(data)} bytes)") - except Exception: - # Best-effort: never raise out of __del__. - pass - - def __del__(self) -> None: - self._save_cpp_implicit_cache() - def encode_metadata(self, metadata: Any) -> str: metadata = copy.deepcopy(metadata) dumped_metadata = pickle.dumps(metadata) @@ -545,7 +476,7 @@ def set_extra_state(self, state: SerializedTorchTensorRTModuleFmt) -> None: ) # RuntimeSettings are NOT serialized; restore defaults. Caller can # reapply via ``compiled.set_runtime_settings(...)`` or a CM after load. - from torch_tensorrt.runtime._runtime_settings import RuntimeSettings + from torch_tensorrt.runtime._runtime_config import RuntimeSettings self._runtime_settings = RuntimeSettings() if self._use_python_runtime: diff --git a/py/torch_tensorrt/runtime/__init__.py b/py/torch_tensorrt/runtime/__init__.py index 478cca548d..5fd7dc7107 100644 --- a/py/torch_tensorrt/runtime/__init__.py +++ b/py/torch_tensorrt/runtime/__init__.py @@ -15,6 +15,5 @@ from torch_tensorrt.runtime._output_allocator import enable_output_allocator from torch_tensorrt.runtime._pre_allocated_outputs import enable_pre_allocated_outputs from torch_tensorrt.runtime._runtime_cache import RuntimeCacheHandle, runtime_cache -from torch_tensorrt.runtime._runtime_config import runtime_config -from torch_tensorrt.runtime._runtime_settings import RuntimeSettings +from torch_tensorrt.runtime._runtime_config import RuntimeSettings, runtime_config from torch_tensorrt.runtime._weight_streaming import weight_streaming diff --git a/py/torch_tensorrt/runtime/_runtime_cache.py b/py/torch_tensorrt/runtime/_runtime_cache.py index 8e977cd3cc..d66e606bf0 100644 --- a/py/torch_tensorrt/runtime/_runtime_cache.py +++ b/py/torch_tensorrt/runtime/_runtime_cache.py @@ -31,84 +31,127 @@ class RuntimeCacheHandle: - """Wraps a ``trt.IRuntimeCache`` + optional disk path / autosave config. + """Wraps a ``trt.IRuntimeCache`` (or a torchbind sibling) + optional disk path. - Two ways an instance comes into being: + Three construction patterns differ in *who else holds a reference*, which + drives the ``autosave_on_del`` flag: 1. **Engine-implicit** (compile-time hint): when an engine sees - ``RuntimeSettings(runtime_cache="/path")``, it materializes a - handle internally during ``_setup_runtime_config`` -- the engine - owns the lifecycle and saves on ``__del__``. - - 2. **Runtime CM** (shared): the :func:`runtime_cache` CM bootstraps from - the first engine under target, creates a cache, wraps it here, and - attaches the handle to all engines for the duration of the ``with`` - block. The CM saves on ``__exit__``. - - Both paths produce the same handle shape; the difference is who owns - the lifecycle. + ``RuntimeSettings(runtime_cache="/path")``, the engine's + ``TRTRuntimeConfig`` materializes a handle internally with + ``autosave_on_del=True``. No other Python object holds the handle, + so ``__del__`` writes the cache to disk on the engine's last release. + + 2. **Runtime CM** (shared, multi-engine): the :func:`runtime_cache` CM + constructs a handle with ``autosave_on_del=False`` and explicitly + calls ``handle.save()`` on ``__exit__``. The handle's ``__del__`` + no-ops since the CM already saved. + + 3. **User-constructed** (advanced): hand-built handles default to + ``autosave_on_del=False`` so save timing stays under the user's + control. Opt in with ``RuntimeCacheHandle(path=..., autosave_on_del=True)`` + for with-block-style autosave on hand-built handles. + + Bytes are sourced from whichever of ``_cache`` (Python pybind + ``trt.IRuntimeCache``) or ``_torchbind`` (TorchBind + ``RuntimeCacheHandle`` sibling) is populated; the Python runtime path + populates the former, the C++ runtime path populates the latter. """ def __init__( self, cache: Any = None, path: str = "", - autosave: bool = True, + autosave_on_del: bool = False, + torchbind_handle: Any = None, ) -> None: - # ``cache`` is a ``trt.IRuntimeCache`` once materialized. May be None - # at construction if the handle is built before any engine has had a - # chance to call ``runtime_config.create_runtime_cache()``. + # ``cache`` is a ``trt.IRuntimeCache`` once materialized (Python rt). + # ``torchbind_handle`` is a ``torch.classes.tensorrt.RuntimeCacheHandle`` + # for the C++ runtime path, exposing serialize/deserialize as tensors. + # Exactly zero or one is populated for a given handle. self._cache = cache + self._torchbind = torchbind_handle self.path = path - self.autosave = autosave + self.autosave_on_del = autosave_on_del self._lock = threading.Lock() @property def cache(self) -> Any: - """The underlying ``trt.IRuntimeCache``. ``None`` if not yet materialized.""" + """The underlying Python pybind ``trt.IRuntimeCache``. ``None`` if not yet materialized or if backed by a torchbind sibling.""" return self._cache def ensure_cache(self, runtime_config: Any) -> Any: - """Idempotent. First caller materializes via ``runtime_config.create_runtime_cache()``.""" + """Idempotent. First caller materializes via ``runtime_config.create_runtime_cache()``. + + Only meaningful for the Python-runtime path (``_cache``). The C++ + runtime materializes its cache inside the engine and exposes bytes + through the torchbind sibling. + """ with self._lock: if self._cache is None: self._cache = runtime_config.create_runtime_cache() return self._cache + def _read_bytes(self) -> Optional[bytes]: + """Serialize whichever of ``_cache`` or ``_torchbind`` is populated.""" + if self._cache is not None: + host_mem = self._cache.serialize() + if host_mem is None or host_mem.nbytes == 0: + return None + return bytes(memoryview(host_mem)) + if self._torchbind is not None and self._torchbind.has_cache(): + tensor = self._torchbind.serialize() + if tensor.numel() == 0: + return None + return bytes(tensor.cpu().contiguous().numpy()) + return None + + def _write_bytes(self, data: bytes) -> None: + """Deserialize ``data`` into whichever of ``_cache`` or ``_torchbind`` is populated.""" + if self._cache is not None: + self._cache.deserialize(data) + return + if self._torchbind is not None and self._torchbind.has_cache(): + tensor = torch.frombuffer(bytearray(data), dtype=torch.uint8) + self._torchbind.deserialize(tensor) + return + def load(self, path: Optional[str] = None) -> None: - """Read bytes from disk and deserialize into ``self._cache``. + """Read bytes from disk and deserialize into the underlying cache. - No-op if ``self._cache`` is None, the resolved path is empty, or the - file doesn't exist (first run). Caller must ensure no enqueue is + No-op if no cache backing is present, the resolved path is empty, or + the file doesn't exist (first run). Caller must ensure no enqueue is concurrently writing (the CM enforces this by ordering load before - engine attach; ``ensure_cache`` is called inside the engine setup). + engine attach; ``ensure_cache`` is called inside engine setup). """ target = path if path is not None else self.path - if not target or self._cache is None: + if not target: + return + if self._cache is None and self._torchbind is None: return - from filelock import FileLock - if not os.path.exists(target): return # first run; nothing to load + from filelock import FileLock + with FileLock(target + ".lock").acquire(timeout=_FILELOCK_TIMEOUT_S): with open(target, "rb") as f: data = f.read() if data: - self._cache.deserialize(data) + self._write_bytes(data) logger.debug(f"Loaded runtime cache from {target} ({len(data)} bytes)") def save(self, path: Optional[str] = None) -> None: - """Serialize ``self._cache`` and write to disk under a filelock. + """Serialize the underlying cache and write to disk under a filelock. - No-op if path is empty or cache wasn't materialized. Caller must + No-op if path is empty or the cache wasn't materialized. Caller must ensure no enqueue is concurrently writing (the CM detaches the cache from all engines before calling save in ``__exit__``). """ target = path if path is not None else self.path - if not target or self._cache is None: + if not target: return - host_mem = self._cache.serialize() - if host_mem is None or host_mem.nbytes == 0: + data = self._read_bytes() + if not data: return from filelock import FileLock @@ -118,9 +161,21 @@ def save(self, path: Optional[str] = None) -> None: tmp = target + ".tmp" with FileLock(target + ".lock").acquire(timeout=_FILELOCK_TIMEOUT_S): with open(tmp, "wb") as f: - f.write(memoryview(host_mem)) + f.write(data) shutil.move(tmp, target) - logger.debug(f"Saved runtime cache to {target} ({host_mem.nbytes} bytes)") + logger.debug(f"Saved runtime cache to {target} ({len(data)} bytes)") + + def __del__(self) -> None: + # Best-effort autosave for engine-implicit handles. The CM disables + # this (``autosave_on_del=False``) since it saves on ``__exit__``; + # user-constructed handles default to disabled so save timing stays + # under the user's control. ``__del__`` can fire during interpreter + # shutdown when imports/filesystem ops fail unpredictably -- swallow. + if self.autosave_on_del and self.path: + try: + self.save() + except Exception: + pass def __eq__(self, other: object) -> bool: # Identity equality so passing the same handle twice through @@ -132,8 +187,9 @@ def __hash__(self) -> int: def __repr__(self) -> str: return ( - f"RuntimeCacheHandle(path={self.path!r}, autosave={self.autosave}, " - f"materialized={self._cache is not None})" + f"RuntimeCacheHandle(path={self.path!r}, " + f"autosave_on_del={self.autosave_on_del}, " + f"materialized={self._cache is not None or self._torchbind is not None})" ) @@ -190,9 +246,12 @@ def __enter__(self) -> RuntimeCacheHandle: # 2. Materialize the cache via the bootstrap engine's runtime_config. # (The cache returned is free-floating; ownership transfers to the handle.) + # ``autosave_on_del=False`` because the CM saves explicitly on ``__exit__``; + # letting ``__del__`` also save would double-write when ``rc`` falls out + # of scope after the with-block. cache_obj = bootstrap_engine.runtime_config.create_runtime_cache() self.handle = RuntimeCacheHandle( - cache=cache_obj, path=self.path, autosave=self.autosave + cache=cache_obj, path=self.path, autosave_on_del=False ) # 3. Load from disk if path was given. diff --git a/py/torch_tensorrt/runtime/_runtime_config.py b/py/torch_tensorrt/runtime/_runtime_config.py index 635cdc5902..7cab20190b 100644 --- a/py/torch_tensorrt/runtime/_runtime_config.py +++ b/py/torch_tensorrt/runtime/_runtime_config.py @@ -1,28 +1,325 @@ -"""Per-engine runtime-settings context manager. +"""Runtime settings + the TRTRuntimeConfig shim + the ``runtime_config`` CM. -``runtime_config(target_or_targets, **kw)`` is the one runtime CM that toggles -``RuntimeSettings`` on every TRT engine reachable under the listed targets. -Other CMs (``runtime_cache``, ``set_cuda_graph_strategy``, -``set_dynamic_shapes_kernel_strategy``) are thin sugar that delegate here. +This module groups three closely related concepts together: -Walks ``named_modules()`` once on enter, snapshots prior settings per engine, -calls ``mod.set_runtime_settings(merged)`` per engine. Restores on exit using -the same snapshot dict. +* :class:`RuntimeSettings` -- the user-facing, frozen dataclass of runtime-only + knobs sampled at IExecutionContext creation (cuda_graph_strategy, + dynamic_shapes_kernel_specialization_strategy, runtime_cache). +* :class:`TRTRuntimeConfig` -- the engine-internal shim that owns a + :class:`RuntimeSettings` and the live ``trt.IRuntimeConfig`` derived from + it. All ``ENABLED_FEATURES.tensorrt_rtx`` branching lives inside; callers in + ``_TRTEngine`` and ``_TorchTensorRTModule`` stay uniform. Mirrors the C++ + ``torch_tensorrt::core::runtime::TRTRuntimeConfig`` struct. +* :func:`runtime_config` -- the runtime-mode context manager that toggles + settings on every TRT submodule under a target for the duration of a + ``with`` block. -Yields the target (or tuple of targets) so users can write -``with runtime_config(model, ...) as m: m(*inputs)``. +Three ways to use ``RuntimeSettings``: + +1. **Compile-time hint** (recommended fast path) -- prime the engine with the + desired initial values so no CM enter/exit recreate is needed:: + + compiled = torchtrt.compile( + model, ..., + runtime_settings=RuntimeSettings(cuda_graph_strategy="whole_graph_capture"), + ) + +2. **Runtime context manager** -- toggle settings inside a ``with`` block. +3. **Programmatic** -- call ``module.set_runtime_settings(rs)`` directly. + +``RuntimeSettings`` is intentionally NOT part of ``CompilationSettings`` and is +NOT serialized into the engine tuple (per GitHub pytorch/TensorRT#4310). It's +purely an in-memory initialization parameter / runtime override state. """ from __future__ import annotations import dataclasses -from typing import Any, Dict, Sequence, Tuple, Union +import logging +from dataclasses import dataclass +from typing import TYPE_CHECKING, Any, Dict, Optional, Sequence, Tuple, Union import torch -from torch_tensorrt.runtime._runtime_settings import RuntimeSettings +from torch_tensorrt._features import ENABLED_FEATURES + +if TYPE_CHECKING: + from torch_tensorrt.runtime._runtime_cache import RuntimeCacheHandle + +logger = logging.getLogger(__name__) + +# Validation maps for the dataclass post-init. The TRT enum mappings live +# inside ``TRTRuntimeConfig._apply_settings`` so non-RTX imports don't trip +# over RTX-only enum symbols at module load time. +_DYNAMIC_SHAPES_KERNEL_STRATEGY_MAP: Dict[str, int] = { + "lazy": 0, + "eager": 1, + "none": 2, +} +_CUDA_GRAPH_STRATEGY_MAP: Dict[str, int] = { + "disabled": 0, + "whole_graph_capture": 1, +} + + +@dataclass(frozen=True) +class RuntimeSettings: + """Per-engine runtime-only knobs sampled at IExecutionContext creation. + + Fields: + dynamic_shapes_kernel_specialization_strategy: ``"lazy" | "eager" | "none"``. + TRT-RTX-only; no-op on standard TensorRT. + cuda_graph_strategy: ``"disabled" | "whole_graph_capture"``. TRT-RTX-only. + runtime_cache: ``None``, a disk path string, or a + :class:`RuntimeCacheHandle`. ``None`` ⇒ no cache attached. A + string is honored at engine construction time and primes a + per-engine disk-backed cache (engine owns the implicit handle and + it saves on ``__del__``). A handle is the shared-cache form, + typically obtained from :func:`torch_tensorrt.runtime.runtime_cache` + -- multiple engines attaching the same handle share one + ``IRuntimeCache``. + + Equality compares all fields; for ``runtime_cache``, handle equality is + by identity (same handle ⇒ same cache). + """ + + dynamic_shapes_kernel_specialization_strategy: str = "lazy" + cuda_graph_strategy: str = "disabled" + runtime_cache: Optional[Union[str, "RuntimeCacheHandle"]] = None # noqa: F821 + + def __post_init__(self) -> None: + if ( + self.dynamic_shapes_kernel_specialization_strategy + not in _DYNAMIC_SHAPES_KERNEL_STRATEGY_MAP + ): + raise ValueError( + "Invalid dynamic_shapes_kernel_specialization_strategy: " + f"{self.dynamic_shapes_kernel_specialization_strategy!r}. " + f"Expected one of {list(_DYNAMIC_SHAPES_KERNEL_STRATEGY_MAP)}." + ) + if self.cuda_graph_strategy not in _CUDA_GRAPH_STRATEGY_MAP: + raise ValueError( + f"Invalid cuda_graph_strategy: {self.cuda_graph_strategy!r}. " + f"Expected one of {list(_CUDA_GRAPH_STRATEGY_MAP)}." + ) + + def merge(self, **overrides: Any) -> "RuntimeSettings": + """Return a new ``RuntimeSettings`` with ``overrides`` applied on top of self.""" + unknown = set(overrides) - {f.name for f in dataclasses.fields(self)} + if unknown: + raise TypeError( + f"Unknown RuntimeSettings field(s): {sorted(unknown)}. " + f"Valid fields: {[f.name for f in dataclasses.fields(self)]}." + ) + return dataclasses.replace(self, **overrides) + + +class TRTRuntimeConfig: + """Owns ``RuntimeSettings`` + the live ``trt.IRuntimeConfig`` for one engine. + + All TRT-RTX feature-flag branching lives in this class -- callers in + ``_TRTEngine`` and ``_TorchTensorRTModule`` stay uniform. Mirrors the C++ + ``TRTRuntimeConfig`` struct. + + On non-RTX builds, ``self._live`` stays ``None`` and the strategy / + runtime-cache plumbing is short-circuited; ``create_execution_context`` + picks the legacy ``cuda_engine.create_execution_context(strategy)`` + overload. + """ + + def __init__(self, settings: Optional[RuntimeSettings] = None) -> None: + self._settings: RuntimeSettings = settings or RuntimeSettings() + # Live trt.IRuntimeConfig (RTX) or None (non-RTX / pre-init). + self._live: Any = None + # Engine-implicit RuntimeCacheHandle when settings.runtime_cache is a + # string path; None when external handle / no cache. + self._implicit_cache_handle: Any = None + + @property + def settings(self) -> RuntimeSettings: + """The current ``RuntimeSettings``. Mutate only via :meth:`set_settings`.""" + return self._settings + + @property + def implicit_cache_handle(self) -> Any: + """The engine-implicit ``RuntimeCacheHandle`` if any, else None. + + Set when ``settings.runtime_cache`` is a path string. The handle's + ``__del__`` persists kernels JIT-compiled during the engine's lifetime + when ``autosave_on_del`` is True (the default for implicit handles). + """ + return self._implicit_cache_handle + + def set_settings(self, new: RuntimeSettings) -> bool: + """Apply ``new`` settings. Returns True iff the value actually changed. + + On change, the prior implicit handle (if any) is saved before being + replaced, the live ``IRuntimeConfig`` is invalidated, and callers + should recreate the ``IExecutionContext``. + """ + if new == self._settings: + return False + prior = self._implicit_cache_handle + if prior is not None: + try: + prior.save() + except Exception as e: # never raise from setting swap + logger.warning(f"Failed to save implicit runtime cache on swap: {e}") + self._settings = new + self._live = None + self._implicit_cache_handle = None + return True + + def ensure_initialized(self, cuda_engine: Any) -> None: + """Lazy-create the live ``trt.IRuntimeConfig`` and apply settings. + + No-op on non-TRT-RTX builds, where there is no ``IRuntimeConfig`` to + configure. + """ + if not ENABLED_FEATURES.tensorrt_rtx: + return + if self._live is not None: + return + self._live = cuda_engine.create_runtime_config() + self._apply_settings() + + def reset(self) -> None: + """Drop the live ``IRuntimeConfig``; the next ``ensure_initialized`` rebuilds.""" + self._live = None + self._implicit_cache_handle = None + + def create_execution_context( + self, + cuda_engine: Any, + allocation_strategy: Any, + ) -> Any: + """Lazy-init + create a fresh ``IExecutionContext``. + + Picks the right ``cuda_engine.create_execution_context`` overload + (``IRuntimeConfig`` vs ``ExecutionContextAllocationStrategy``) so + callers stay free of any ``ENABLED_FEATURES.tensorrt_rtx`` branching. + """ + if ENABLED_FEATURES.tensorrt_rtx: + self.ensure_initialized(cuda_engine) + assert self._live is not None + self._live.set_execution_context_allocation_strategy(allocation_strategy) + return cuda_engine.create_execution_context(self._live) + return cuda_engine.create_execution_context(allocation_strategy) + + def uses_internal_capture(self, cudagraphs_enabled: bool) -> bool: + """Returns True if TRT-RTX owns capture/replay for the current settings. + + Caller should then bypass its own ``torch.cuda.CUDAGraph`` capture + around ``execute_async_v3``. Always False on non-RTX builds. + """ + if not ENABLED_FEATURES.tensorrt_rtx: + return False + return self._settings.cuda_graph_strategy != "disabled" or cudagraphs_enabled + + def is_monolithic_capturable( + self, + has_dynamic_inputs: bool, + context: Any, + stream: Any, + ) -> bool: + """Returns True iff this engine can be safely wrapped by an outer monolithic capture. + + Mirrors C++ ``TRTRuntimeConfig::is_monolithic_capturable``. Non-RTX + builds always return True. + """ + if not ENABLED_FEATURES.tensorrt_rtx: + return True + if not context.is_stream_capturable(stream.cuda_stream): + return False + # "lazy" kernel specialization swaps specialized kernels mid-run when an + # input has a dynamic dimension; for static-shape engines the kernels + # are fixed at setup and the captured graph stays valid. + return not ( + self._settings.dynamic_shapes_kernel_specialization_strategy == "lazy" + and has_dynamic_inputs + ) + + # ------------------------------------------------------------------ + # Internal: apply self._settings to self._live + # ------------------------------------------------------------------ + + def _apply_settings(self) -> None: + """Apply ``self._settings`` to the live ``trt.IRuntimeConfig``. + + Resolves ``runtime_cache``: + - ``None`` ⇒ no cache attached. + - ``str`` path ⇒ create an engine-implicit ``RuntimeCacheHandle`` here + and attach. Engine owns lifecycle; handle's ``__del__`` saves. + - ``RuntimeCacheHandle`` ⇒ external; caller owns lifecycle. + """ + # Deferred imports: trt is import-aliased to tensorrt_rtx on RTX builds, + # and _runtime_cache imports this module's RuntimeSettings. + import tensorrt as trt + from torch_tensorrt.runtime._runtime_cache import RuntimeCacheHandle + + self._live.dynamic_shapes_kernel_specialization_strategy = ( + self._to_trt_ds_strategy(trt) + ) + logger.info( + "Dynamic shapes kernel specialization strategy: " + f"{self._settings.dynamic_shapes_kernel_specialization_strategy}" + ) + self._live.cuda_graph_strategy = self._to_trt_cg_strategy(trt) + logger.info(f"CUDA graph strategy: {self._settings.cuda_graph_strategy}") + + rc = self._settings.runtime_cache + if rc is None: + self._implicit_cache_handle = None + logger.debug( + "Runtime cache disabled (no RuntimeCacheHandle / path provided)." + ) + elif isinstance(rc, str): + cache = self._live.create_runtime_cache() + self._implicit_cache_handle = RuntimeCacheHandle( + cache=cache, path=rc, autosave_on_del=True + ) + try: + self._implicit_cache_handle.load() + except Exception as e: + logger.warning(f"Failed to load implicit runtime cache: {e}") + self._live.set_runtime_cache(cache) + else: + # External RuntimeCacheHandle. Caller owns lifecycle; the handle + # holds the IRuntimeCache reference. + cache = rc.ensure_cache(self._live) + self._implicit_cache_handle = None + self._live.set_runtime_cache(cache) + logger.info("TensorRT-RTX runtime config configured") + + def _to_trt_ds_strategy(self, trt: Any) -> Any: + return { + "lazy": trt.DynamicShapesKernelSpecializationStrategy.LAZY, + "eager": trt.DynamicShapesKernelSpecializationStrategy.EAGER, + "none": trt.DynamicShapesKernelSpecializationStrategy.NONE, + }[self._settings.dynamic_shapes_kernel_specialization_strategy] + + def _to_trt_cg_strategy(self, trt: Any) -> Any: + return { + "disabled": trt.CudaGraphStrategy.DISABLED, + "whole_graph_capture": trt.CudaGraphStrategy.WHOLE_GRAPH_CAPTURE, + }[self._settings.cuda_graph_strategy] + + +# --------------------------------------------------------------------------- +# runtime_config(...) CM and factory +# --------------------------------------------------------------------------- class _RuntimeConfigContextManager: + """Pool CM that applies ``RuntimeSettings`` overrides to every TRT submodule. + + Walks ``named_modules()`` once on enter, snapshots prior settings per + engine, calls ``mod.set_runtime_settings(merged)`` per engine. Restores on + exit using the same snapshot dict. + + Yields the target (or tuple of targets) so users can write + ``with runtime_config(model, ...) as m: m(*inputs)``. + """ + def __init__( self, target_or_targets: Union["torch.nn.Module", Sequence["torch.nn.Module"]], @@ -48,7 +345,7 @@ def __init__( # Engine ↔ prior RuntimeSettings snapshot; populated on enter. self._saved: Dict[Any, RuntimeSettings] = {} - def __enter__(self) -> Union[torch.nn.Module, Tuple[torch.nn.Module, ...]]: + def __enter__(self) -> Union["torch.nn.Module", Tuple["torch.nn.Module", ...]]: # Deferred import to avoid a circular dependency at module-load time. from torch_tensorrt.dynamo.runtime._TorchTensorRTModule import ( TorchTensorRTModule, diff --git a/py/torch_tensorrt/runtime/_runtime_settings.py b/py/torch_tensorrt/runtime/_runtime_settings.py deleted file mode 100644 index 738315d555..0000000000 --- a/py/torch_tensorrt/runtime/_runtime_settings.py +++ /dev/null @@ -1,98 +0,0 @@ -"""User-facing runtime-only knobs for TRT-RTX engines. - -A knob belongs in :class:`RuntimeSettings` iff changing it requires recreating -the ``IExecutionContext``. Per-execute flags (``cudagraphs_mode``, -``multi_device_safe_mode``, ``pre_allocated_outputs``) stay as their existing -process-global setters. - -Three ways to use: - -1. **Compile-time hint** (recommended fast path) -- prime the engine with the - desired initial values so no CM enter/exit recreate is needed:: - - compiled = torchtrt.compile( - model, ..., - runtime_settings=RuntimeSettings(cuda_graph_strategy="whole_graph_capture"), - ) - -2. **Runtime context manager** -- toggle settings inside a ``with`` block. See - :func:`torch_tensorrt.runtime.runtime_config`. - -3. **Programmatic** -- call ``module.set_runtime_settings(rs)`` directly. - -``RuntimeSettings`` is intentionally NOT part of ``CompilationSettings`` and is -NOT serialized into the engine tuple (per GitHub pytorch/TensorRT#4310). It's -purely an in-memory initialization parameter / runtime override state. -""" - -from __future__ import annotations - -import dataclasses -from dataclasses import dataclass -from typing import TYPE_CHECKING, Any, Dict, Optional, Union - -if TYPE_CHECKING: - from torch_tensorrt.runtime._runtime_cache import RuntimeCacheHandle - -# Validation maps used by both the engine setup path and the dataclass post-init. -_DYNAMIC_SHAPES_KERNEL_STRATEGY_MAP: Dict[str, int] = { - "lazy": 0, - "eager": 1, - "none": 2, -} -_CUDA_GRAPH_STRATEGY_MAP: Dict[str, int] = { - "disabled": 0, - "whole_graph_capture": 1, -} - - -@dataclass(frozen=True) -class RuntimeSettings: - """Per-engine runtime-only knobs sampled at IExecutionContext creation. - - Fields: - dynamic_shapes_kernel_specialization_strategy: ``"lazy" | "eager" | "none"``. - TRT-RTX-only; no-op on standard TensorRT. - cuda_graph_strategy: ``"disabled" | "whole_graph_capture"``. TRT-RTX-only. - runtime_cache: ``None``, a disk path string, or a - :class:`RuntimeCacheHandle`. ``None`` ⇒ each engine has an in-memory - cache local to itself. A string is honored at engine construction - time and primes a per-engine disk-backed cache (matches today's - ``runtime_cache_path=`` behavior; saved on engine ``__del__``). - A handle is the shared-cache form, typically obtained from - :func:`torch_tensorrt.runtime.runtime_cache` -- multiple engines - attaching the same handle share one ``IRuntimeCache``. - - Equality compares all fields; for ``runtime_cache``, handle equality is - by identity (same handle ⇒ same cache). - """ - - dynamic_shapes_kernel_specialization_strategy: str = "lazy" - cuda_graph_strategy: str = "disabled" - runtime_cache: Optional[Union[str, "RuntimeCacheHandle"]] = None # noqa: F821 - - def __post_init__(self) -> None: - if ( - self.dynamic_shapes_kernel_specialization_strategy - not in _DYNAMIC_SHAPES_KERNEL_STRATEGY_MAP - ): - raise ValueError( - "Invalid dynamic_shapes_kernel_specialization_strategy: " - f"{self.dynamic_shapes_kernel_specialization_strategy!r}. " - f"Expected one of {list(_DYNAMIC_SHAPES_KERNEL_STRATEGY_MAP)}." - ) - if self.cuda_graph_strategy not in _CUDA_GRAPH_STRATEGY_MAP: - raise ValueError( - f"Invalid cuda_graph_strategy: {self.cuda_graph_strategy!r}. " - f"Expected one of {list(_CUDA_GRAPH_STRATEGY_MAP)}." - ) - - def merge(self, **overrides: Any) -> "RuntimeSettings": - """Return a new ``RuntimeSettings`` with ``overrides`` applied on top of self.""" - unknown = set(overrides) - {f.name for f in dataclasses.fields(self)} - if unknown: - raise TypeError( - f"Unknown RuntimeSettings field(s): {sorted(unknown)}. " - f"Valid fields: {[f.name for f in dataclasses.fields(self)]}." - ) - return dataclasses.replace(self, **overrides) diff --git a/tests/py/dynamo/runtime/test_000_runtime_cache.py b/tests/py/dynamo/runtime/test_000_runtime_cache.py index c39ccb1b79..57d2ec6a1d 100644 --- a/tests/py/dynamo/runtime/test_000_runtime_cache.py +++ b/tests/py/dynamo/runtime/test_000_runtime_cache.py @@ -202,6 +202,55 @@ def test_runtime_cache_on_empty_target_raises(self): with runtime_cache(empty, path): pass + def test_cm_does_not_double_save_on_rc_gc(self): + """CM yields handle with autosave_on_del=False; only one save happens. + + Regression: if the CM-yielded handle had autosave_on_del=True, the + handle's __del__ would re-save after the CM's __exit__ already wrote + the file. We disable autosave_on_del on CM-created handles to avoid + that double-write. + """ + compiled, inputs = _compile_simple() + save_calls = [] + with tempfile.TemporaryDirectory() as tmp: + path = os.path.join(tmp, "shared.bin") + with runtime_cache(compiled, path) as rc: + # CM-created handle must not autosave on del (CM saves explicitly). + self.assertFalse(rc.autosave_on_del) + original_save = rc.save + + def _tracking_save(p=None): + save_calls.append(p) + return original_save(p) + + rc.save = _tracking_save # type: ignore[method-assign] + _ = compiled(*inputs) + # CM.__exit__ saves exactly once; rc going out of scope triggers + # __del__ but autosave_on_del is False, so no second save. + del rc + gc.collect() + self.assertEqual(len(save_calls), 1, f"Expected one save, got {save_calls}") + self.assertTrue(os.path.exists(path)) + + +class TestRuntimeCacheHandleAutosave(TestCase): + """Whitebox tests for RuntimeCacheHandle.autosave_on_del semantics.""" + + def test_user_built_handle_no_autosave_by_default(self): + """Hand-built handle defaults to autosave_on_del=False; nothing on GC.""" + from torch_tensorrt.runtime._runtime_cache import RuntimeCacheHandle + + with tempfile.TemporaryDirectory() as tmp: + path = os.path.join(tmp, "rc.bin") + handle = RuntimeCacheHandle(path=path) + self.assertFalse(handle.autosave_on_del) + del handle + gc.collect() + self.assertFalse( + os.path.exists(path), + "User-built handle with autosave_on_del=False should not save on GC", + ) + if __name__ == "__main__": run_tests() From 4bfa982115b8688007a5893ef83077b49d2b8a41 Mon Sep 17 00:00:00 2001 From: tejaswinp Date: Wed, 3 Jun 2026 19:36:47 -0700 Subject: [PATCH 4/7] runtime: address PR review feedback Five follow-up changes responding to PR review comments: * **Fold strategy sugar into ``_runtime_config.py``.** Delete ``_dynamic_shapes_kernel_strategy.py`` and ``_cuda_graph_strategy.py``; ``set_dynamic_shapes_kernel_strategy`` / ``set_cuda_graph_strategy`` now live alongside the ``runtime_config`` CM they delegate to. ``torch_tensorrt/runtime/__init__.py`` re-exports them from the consolidated module. * **Hoist ``RuntimeSettings`` defaults into ``_defaults.py``.** Three new constants (``DYNAMIC_SHAPES_KERNEL_SPECIALIZATION_STRATEGY``, ``CUDA_GRAPH_STRATEGY``, ``RUNTIME_CACHE_PATH``) mirror the compilation-settings pattern. ``RUNTIME_CACHE_PATH`` defaults to a per-user temp file similar to ``ENGINE_CACHE_DIR``, so users get a disk-backed runtime cache without explicit opt-in; override via ``RuntimeSettings(runtime_cache="/path")`` or the ``runtime_cache`` CM. Test_000 and test_004 updated to reflect the new default. * **Warn on non-RTX ``RuntimeSettings`` construction.** ``__post_init__`` now emits a one-shot ``UserWarning`` on regular TRT builds (gated by ``ENABLED_FEATURES.tensorrt_rtx``) so users see that the settings have no effect. * **Drop ``TYPE_CHECKING`` string forward-refs for ``RuntimeSettings``.** Direct top-level imports across ``_compiler.py``, ``_conversion.py``, ``_TRTEngine.py`` and ``_TorchTensorRTModule.py``; bare ``Optional[RuntimeSettings]`` annotations everywhere. Deferred imports inside ``__init__`` / ``__setstate__`` removed. All 51 runtime tests pass (test_004 12/12, test_000 12/12, test_001 ds 14/14, test_001 cg 13/13). --- py/torch_tensorrt/dynamo/_compiler.py | 12 ++-- py/torch_tensorrt/dynamo/_defaults.py | 13 ++++ .../dynamo/conversion/_conversion.py | 8 +-- .../dynamo/runtime/_TRTEngine.py | 25 ++----- .../dynamo/runtime/_TorchTensorRTModule.py | 18 ++--- py/torch_tensorrt/runtime/__init__.py | 11 +-- .../runtime/_cuda_graph_strategy.py | 26 ------- .../_dynamic_shapes_kernel_strategy.py | 28 -------- py/torch_tensorrt/runtime/_runtime_config.py | 67 +++++++++++++++++-- .../dynamo/runtime/test_000_runtime_cache.py | 10 ++- .../runtime/test_004_runtime_settings.py | 5 +- 11 files changed, 111 insertions(+), 112 deletions(-) delete mode 100644 py/torch_tensorrt/runtime/_cuda_graph_strategy.py delete mode 100644 py/torch_tensorrt/runtime/_dynamic_shapes_kernel_strategy.py diff --git a/py/torch_tensorrt/dynamo/_compiler.py b/py/torch_tensorrt/dynamo/_compiler.py index 9b90aeffe0..0023a0a420 100644 --- a/py/torch_tensorrt/dynamo/_compiler.py +++ b/py/torch_tensorrt/dynamo/_compiler.py @@ -5,12 +5,9 @@ import os import platform import warnings -from typing import TYPE_CHECKING, Any, Collection, List, Optional, Sequence, Union +from typing import Any, Collection, List, Optional, Sequence, Union import torch - -if TYPE_CHECKING: - from torch_tensorrt.runtime._runtime_config import RuntimeSettings from torch.export import ExportedProgram from torch.fx.node import Target from torch_tensorrt._Device import Device @@ -56,6 +53,7 @@ to_torch_device, to_torch_tensorrt_device, ) +from torch_tensorrt.runtime._runtime_config import RuntimeSettings logger = logging.getLogger(__name__) @@ -112,7 +110,7 @@ def cross_compile_for_windows( dynamically_allocate_resources: bool = _defaults.DYNAMICALLY_ALLOCATE_RESOURCES, decompose_attention: bool = _defaults.DECOMPOSE_ATTENTION, attn_bias_is_causal: bool = _defaults.ATTN_BIAS_IS_CAUSAL, - runtime_settings: Optional["RuntimeSettings"] = None, + runtime_settings: Optional[RuntimeSettings] = None, **kwargs: Any, ) -> torch.fx.GraphModule: """Compile an ExportedProgram module using TensorRT in Linux for Inference in Windows @@ -456,7 +454,7 @@ def compile( dynamically_allocate_resources: bool = _defaults.DYNAMICALLY_ALLOCATE_RESOURCES, decompose_attention: bool = _defaults.DECOMPOSE_ATTENTION, attn_bias_is_causal: bool = _defaults.ATTN_BIAS_IS_CAUSAL, - runtime_settings: Optional["RuntimeSettings"] = None, + runtime_settings: Optional[RuntimeSettings] = None, **kwargs: Any, ) -> torch.fx.GraphModule: """Compile an ExportedProgram module for NVIDIA GPUs using TensorRT @@ -876,7 +874,7 @@ def compile_module( engine_cache: Optional[BaseEngineCache] = None, *, _debugger_config: Optional[DebuggerConfig] = None, - runtime_settings: Optional["RuntimeSettings"] = None, + runtime_settings: Optional[RuntimeSettings] = None, ) -> torch.fx.GraphModule: """Compile a traced FX module diff --git a/py/torch_tensorrt/dynamo/_defaults.py b/py/torch_tensorrt/dynamo/_defaults.py index 4a8078dd1d..e163e830a1 100644 --- a/py/torch_tensorrt/dynamo/_defaults.py +++ b/py/torch_tensorrt/dynamo/_defaults.py @@ -81,6 +81,19 @@ tempfile.gettempdir(), f"torch_tensorrt_{current_user}/debug_logs" ) +# --------------------------------------------------------------------------- +# Runtime-only knobs (see torch_tensorrt.runtime.RuntimeSettings). Defaults +# live here to mirror compilation-settings convention; the dataclass imports +# from this module. +# --------------------------------------------------------------------------- +DYNAMIC_SHAPES_KERNEL_SPECIALIZATION_STRATEGY = "lazy" +CUDA_GRAPH_STRATEGY = "disabled" +# Default to a per-user temp file (mirrors ENGINE_CACHE_DIR). Users can override +# via ``RuntimeSettings(runtime_cache="/different/path")`` or a runtime CM. +RUNTIME_CACHE_PATH = os.path.join( + tempfile.gettempdir(), f"torch_tensorrt_{current_user}/runtime_cache.bin" +) + def default_device() -> Device: return Device(gpu_id=torch.cuda.current_device()) diff --git a/py/torch_tensorrt/dynamo/conversion/_conversion.py b/py/torch_tensorrt/dynamo/conversion/_conversion.py index c6f257d0c9..fc886656a0 100644 --- a/py/torch_tensorrt/dynamo/conversion/_conversion.py +++ b/py/torch_tensorrt/dynamo/conversion/_conversion.py @@ -2,13 +2,10 @@ import io import logging -from typing import TYPE_CHECKING, Any, Dict, List, NamedTuple, Optional, Sequence +from typing import Any, Dict, List, NamedTuple, Optional, Sequence import tensorrt as trt import torch - -if TYPE_CHECKING: - from torch_tensorrt.runtime._runtime_config import RuntimeSettings from torch_tensorrt._enums import dtype from torch_tensorrt._features import ENABLED_FEATURES from torch_tensorrt._Input import Input @@ -28,6 +25,7 @@ release_host_and_device_memory, ) from torch_tensorrt.logging import TRT_LOGGER +from torch_tensorrt.runtime._runtime_config import RuntimeSettings logger = logging.getLogger(__name__) @@ -336,7 +334,7 @@ def convert_module( settings: CompilationSettings = CompilationSettings(), name: str = "", engine_cache: Optional[BaseEngineCache] = None, - runtime_settings: Optional["RuntimeSettings"] = None, + runtime_settings: Optional[RuntimeSettings] = None, ) -> TorchTensorRTModule: """Convert an FX module to a TRT module Args: diff --git a/py/torch_tensorrt/dynamo/runtime/_TRTEngine.py b/py/torch_tensorrt/dynamo/runtime/_TRTEngine.py index 3f0571ce22..1bfd3a21e9 100644 --- a/py/torch_tensorrt/dynamo/runtime/_TRTEngine.py +++ b/py/torch_tensorrt/dynamo/runtime/_TRTEngine.py @@ -16,7 +16,6 @@ from contextlib import nullcontext from types import SimpleNamespace from typing import ( - TYPE_CHECKING, Any, ContextManager, Dict, @@ -30,12 +29,6 @@ import torch import torch.distributed as dist import torch_tensorrt - -if TYPE_CHECKING: - from torch_tensorrt.runtime._runtime_config import ( - RuntimeSettings, - TRTRuntimeConfig, - ) from torch._library.opaque_object import register_opaque_type from torch._opaque_base import OpaqueBase from torch_tensorrt._enums import dtype @@ -62,6 +55,7 @@ ) from torch_tensorrt.dynamo.utils import DYNAMIC_DIM from torch_tensorrt.logging import TRT_LOGGER +from torch_tensorrt.runtime._runtime_config import RuntimeSettings, TRTRuntimeConfig from torch_tensorrt.runtime._utils import ( _is_switch_required, _select_rt_device, @@ -216,14 +210,8 @@ def __init__( serialized_info: SerializedTensorRTEngineFmt, *, profile_execution: bool = False, - runtime_settings: Optional["RuntimeSettings"] = None, + runtime_settings: Optional[RuntimeSettings] = None, ) -> None: - # Import here to avoid a circular dep at module-import time. - from torch_tensorrt.runtime._runtime_config import ( - RuntimeSettings, - TRTRuntimeConfig, - ) - self._profile_execution = profile_execution self.profile_path_prefix = tempfile.gettempdir() self.use_pre_allocated_outputs = False @@ -263,7 +251,7 @@ def __init__( # --- public property forwards --- @property - def runtime_settings(self) -> "RuntimeSettings": + def runtime_settings(self) -> RuntimeSettings: """The current ``RuntimeSettings`` for this engine. Backed by ``self._trt_runtime_config.settings``; mutations go through @@ -313,11 +301,6 @@ def __getstate__(self) -> Tuple[List[Any], str]: def __setstate__(self, state: Any) -> None: """Restore from C++-matching pickle state ``(serialized_info,)``.""" - from torch_tensorrt.runtime._runtime_config import ( - RuntimeSettings, - TRTRuntimeConfig, - ) - self._profile_execution = False self.profile_path_prefix = tempfile.gettempdir() self.use_pre_allocated_outputs = False @@ -539,7 +522,7 @@ def _setup_engine(self) -> None: # --- TensorRT-RTX runtime-config delegation --- - def update_runtime_settings(self, new_settings: "RuntimeSettings") -> None: + def update_runtime_settings(self, new_settings: RuntimeSettings) -> None: """Apply new ``RuntimeSettings`` to this engine. Fast-paths on equality via ``TRTRuntimeConfig.set_settings``. On diff --git a/py/torch_tensorrt/dynamo/runtime/_TorchTensorRTModule.py b/py/torch_tensorrt/dynamo/runtime/_TorchTensorRTModule.py index 6512d896ea..f53e123b31 100644 --- a/py/torch_tensorrt/dynamo/runtime/_TorchTensorRTModule.py +++ b/py/torch_tensorrt/dynamo/runtime/_TorchTensorRTModule.py @@ -4,7 +4,7 @@ import copy import logging import pickle -from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple +from typing import Any, Dict, List, Optional, Tuple import torch from torch_tensorrt._Device import Device @@ -30,9 +30,7 @@ serialize_binding_names, serialize_device_info, ) - -if TYPE_CHECKING: - from torch_tensorrt.runtime._runtime_config import RuntimeSettings +from torch_tensorrt.runtime._runtime_config import RuntimeSettings logger = logging.getLogger(__name__) @@ -68,7 +66,7 @@ def __init__( requires_output_allocator: bool = False, requires_native_multidevice: bool = False, symbolic_shape_expressions: Optional[Dict[str, List[Dict[str, Any]]]] = None, - runtime_settings: Optional["RuntimeSettings"] = None, + runtime_settings: Optional[RuntimeSettings] = None, ): """Takes a name, target device, serialized TensorRT engine, and binding names / order and constructs a PyTorch ``torch.nn.Module`` around it. Uses the Torch-TensorRT runtime extension to run the engines @@ -139,8 +137,6 @@ def __init__( # Per-engine runtime mode controls. Defaults to ``RuntimeSettings()`` if # not supplied; the dataclass validates at ``__post_init__``. - from torch_tensorrt.runtime._runtime_config import RuntimeSettings - self._runtime_settings: RuntimeSettings = runtime_settings or RuntimeSettings() self.symbolic_shape_expressions = symbolic_shape_expressions self.requires_native_multidevice = requires_native_multidevice @@ -277,7 +273,7 @@ def use_dynamically_allocated_resources( # --- runtime-settings dispatch ---------------------------------------- @property - def runtime_settings(self) -> "RuntimeSettings": + def runtime_settings(self) -> RuntimeSettings: """The current ``RuntimeSettings`` on this module (and its engine). This is the snapshot the ``runtime_config`` CM reads at ``__enter__`` @@ -285,7 +281,7 @@ def runtime_settings(self) -> "RuntimeSettings": """ return self._runtime_settings - def set_runtime_settings(self, rs: "RuntimeSettings") -> None: + def set_runtime_settings(self, rs: RuntimeSettings) -> None: """Apply ``RuntimeSettings`` to all TRT engines under this module. Walks ``named_modules()`` so calling on a wrapper / parent @@ -298,7 +294,7 @@ def set_runtime_settings(self, rs: "RuntimeSettings") -> None: mod._dispatch_runtime_settings_to_engine(rs) mod._runtime_settings = rs - def _dispatch_runtime_settings_to_engine(self, rs: "RuntimeSettings") -> None: + def _dispatch_runtime_settings_to_engine(self, rs: RuntimeSettings) -> None: """Backend-aware dispatch of ``update_runtime_settings(rs)`` to ``self.engine``.""" if self.engine is None: return @@ -476,8 +472,6 @@ def set_extra_state(self, state: SerializedTorchTensorRTModuleFmt) -> None: ) # RuntimeSettings are NOT serialized; restore defaults. Caller can # reapply via ``compiled.set_runtime_settings(...)`` or a CM after load. - from torch_tensorrt.runtime._runtime_config import RuntimeSettings - self._runtime_settings = RuntimeSettings() if self._use_python_runtime: from torch_tensorrt.dynamo.runtime._TRTEngine import TRTEngine diff --git a/py/torch_tensorrt/runtime/__init__.py b/py/torch_tensorrt/runtime/__init__.py index 5fd7dc7107..82457130f8 100644 --- a/py/torch_tensorrt/runtime/__init__.py +++ b/py/torch_tensorrt/runtime/__init__.py @@ -1,19 +1,20 @@ from torch_tensorrt.dynamo.runtime import ( # noqa: F401 TorchTensorRTModule, ) -from torch_tensorrt.runtime._cuda_graph_strategy import set_cuda_graph_strategy from torch_tensorrt.runtime._cudagraphs import ( enable_cudagraphs, get_cudagraphs_mode, get_whole_cudagraphs_mode, set_cudagraphs_mode, ) -from torch_tensorrt.runtime._dynamic_shapes_kernel_strategy import ( - set_dynamic_shapes_kernel_strategy, -) from torch_tensorrt.runtime._multi_device_safe_mode import set_multi_device_safe_mode from torch_tensorrt.runtime._output_allocator import enable_output_allocator from torch_tensorrt.runtime._pre_allocated_outputs import enable_pre_allocated_outputs from torch_tensorrt.runtime._runtime_cache import RuntimeCacheHandle, runtime_cache -from torch_tensorrt.runtime._runtime_config import RuntimeSettings, runtime_config +from torch_tensorrt.runtime._runtime_config import ( + RuntimeSettings, + runtime_config, + set_cuda_graph_strategy, + set_dynamic_shapes_kernel_strategy, +) from torch_tensorrt.runtime._weight_streaming import weight_streaming diff --git a/py/torch_tensorrt/runtime/_cuda_graph_strategy.py b/py/torch_tensorrt/runtime/_cuda_graph_strategy.py deleted file mode 100644 index 60c3bfa29a..0000000000 --- a/py/torch_tensorrt/runtime/_cuda_graph_strategy.py +++ /dev/null @@ -1,26 +0,0 @@ -"""Sugar over ``runtime_config`` for the cuda-graph strategy knob.""" - -from __future__ import annotations - -from typing import TYPE_CHECKING, Sequence, Union - -from torch_tensorrt.runtime._runtime_config import ( - _RuntimeConfigContextManager, - runtime_config, -) - -if TYPE_CHECKING: - import torch - - -def set_cuda_graph_strategy( - target_or_targets: Union["torch.nn.Module", Sequence["torch.nn.Module"]], - strategy: str, -) -> _RuntimeConfigContextManager: - """Context manager that sets the cuda-graph strategy on all TRT engines - under ``target_or_targets``. - - Accepts ``"disabled"`` or ``"whole_graph_capture"``. Delegates to - :func:`runtime_config`. - """ - return runtime_config(target_or_targets, cuda_graph_strategy=strategy) diff --git a/py/torch_tensorrt/runtime/_dynamic_shapes_kernel_strategy.py b/py/torch_tensorrt/runtime/_dynamic_shapes_kernel_strategy.py deleted file mode 100644 index 84f8180143..0000000000 --- a/py/torch_tensorrt/runtime/_dynamic_shapes_kernel_strategy.py +++ /dev/null @@ -1,28 +0,0 @@ -"""Sugar over ``runtime_config`` for the dynamic-shapes kernel strategy knob.""" - -from __future__ import annotations - -from typing import TYPE_CHECKING, Sequence, Union - -from torch_tensorrt.runtime._runtime_config import ( - _RuntimeConfigContextManager, - runtime_config, -) - -if TYPE_CHECKING: - import torch - - -def set_dynamic_shapes_kernel_strategy( - target_or_targets: Union["torch.nn.Module", Sequence["torch.nn.Module"]], - strategy: str, -) -> _RuntimeConfigContextManager: - """Context manager that sets the dynamic-shapes kernel specialization - strategy on all TRT engines under ``target_or_targets``. - - Accepts ``"lazy"``, ``"eager"``, or ``"none"``. Delegates to - :func:`runtime_config`. - """ - return runtime_config( - target_or_targets, dynamic_shapes_kernel_specialization_strategy=strategy - ) diff --git a/py/torch_tensorrt/runtime/_runtime_config.py b/py/torch_tensorrt/runtime/_runtime_config.py index 7cab20190b..d18e3e346f 100644 --- a/py/torch_tensorrt/runtime/_runtime_config.py +++ b/py/torch_tensorrt/runtime/_runtime_config.py @@ -36,11 +36,21 @@ import dataclasses import logging -from dataclasses import dataclass +import warnings +from dataclasses import dataclass, field from typing import TYPE_CHECKING, Any, Dict, Optional, Sequence, Tuple, Union import torch from torch_tensorrt._features import ENABLED_FEATURES +from torch_tensorrt.dynamo._defaults import ( + CUDA_GRAPH_STRATEGY, + DYNAMIC_SHAPES_KERNEL_SPECIALIZATION_STRATEGY, + RUNTIME_CACHE_PATH, +) + +# Single-shot guard for the non-RTX construction warning -- emit once per +# process, not once per RuntimeSettings instance. +_NON_RTX_WARNING_EMITTED = False if TYPE_CHECKING: from torch_tensorrt.runtime._runtime_cache import RuntimeCacheHandle @@ -82,9 +92,13 @@ class RuntimeSettings: by identity (same handle ⇒ same cache). """ - dynamic_shapes_kernel_specialization_strategy: str = "lazy" - cuda_graph_strategy: str = "disabled" - runtime_cache: Optional[Union[str, "RuntimeCacheHandle"]] = None # noqa: F821 + dynamic_shapes_kernel_specialization_strategy: str = ( + DYNAMIC_SHAPES_KERNEL_SPECIALIZATION_STRATEGY + ) + cuda_graph_strategy: str = CUDA_GRAPH_STRATEGY + runtime_cache: Optional[Union[str, "RuntimeCacheHandle"]] = field( # noqa: F821 + default_factory=lambda: RUNTIME_CACHE_PATH + ) def __post_init__(self) -> None: if ( @@ -101,6 +115,18 @@ def __post_init__(self) -> None: f"Invalid cuda_graph_strategy: {self.cuda_graph_strategy!r}. " f"Expected one of {list(_CUDA_GRAPH_STRATEGY_MAP)}." ) + # RuntimeSettings only takes effect on TRT-RTX builds. Warn once per + # process on regular TRT so users don't silently expect cache / + # strategy plumbing to do anything. + global _NON_RTX_WARNING_EMITTED + if not ENABLED_FEATURES.tensorrt_rtx and not _NON_RTX_WARNING_EMITTED: + warnings.warn( + "RuntimeSettings is only honored on TRT-RTX builds; " + "constructing it on regular TensorRT has no effect.", + UserWarning, + stacklevel=2, + ) + _NON_RTX_WARNING_EMITTED = True def merge(self, **overrides: Any) -> "RuntimeSettings": """Return a new ``RuntimeSettings`` with ``overrides`` applied on top of self.""" @@ -386,3 +412,36 @@ def runtime_config( by-reference -- same object the caller passed in. """ return _RuntimeConfigContextManager(target_or_targets, **overrides) + + +# --------------------------------------------------------------------------- +# Sugar wrappers for the two strategy knobs +# --------------------------------------------------------------------------- + + +def set_dynamic_shapes_kernel_strategy( + target_or_targets: Union["torch.nn.Module", Sequence["torch.nn.Module"]], + strategy: str, +) -> _RuntimeConfigContextManager: + """Context manager that sets the dynamic-shapes kernel specialization + strategy on all TRT engines under ``target_or_targets``. + + Accepts ``"lazy"``, ``"eager"``, or ``"none"``. Delegates to + :func:`runtime_config`. + """ + return runtime_config( + target_or_targets, dynamic_shapes_kernel_specialization_strategy=strategy + ) + + +def set_cuda_graph_strategy( + target_or_targets: Union["torch.nn.Module", Sequence["torch.nn.Module"]], + strategy: str, +) -> _RuntimeConfigContextManager: + """Context manager that sets the cuda-graph strategy on all TRT engines + under ``target_or_targets``. + + Accepts ``"disabled"`` or ``"whole_graph_capture"``. Delegates to + :func:`runtime_config`. + """ + return runtime_config(target_or_targets, cuda_graph_strategy=strategy) diff --git a/tests/py/dynamo/runtime/test_000_runtime_cache.py b/tests/py/dynamo/runtime/test_000_runtime_cache.py index 57d2ec6a1d..b3aff86f99 100644 --- a/tests/py/dynamo/runtime/test_000_runtime_cache.py +++ b/tests/py/dynamo/runtime/test_000_runtime_cache.py @@ -107,11 +107,15 @@ def test_context_created_successfully(self): engine = _find_python_trt_engine(compiled) self.assertIsNotNone(engine.context) - def test_no_implicit_cache_handle_by_default(self): - """Default RuntimeSettings has no disk-backing => no implicit handle.""" + def test_default_uses_temp_path_implicit_handle(self): + """Default RuntimeSettings points runtime_cache at the per-user temp file + (see _defaults.RUNTIME_CACHE_PATH); the engine creates an implicit handle.""" + from torch_tensorrt.dynamo._defaults import RUNTIME_CACHE_PATH + compiled, _ = _compile_simple() engine = _find_python_trt_engine(compiled) - self.assertIsNone(engine._implicit_cache_handle) + self.assertIsNotNone(engine._implicit_cache_handle) + self.assertEqual(engine._implicit_cache_handle.path, RUNTIME_CACHE_PATH) def test_implicit_cache_handle_for_path_hint(self): """Passing a path string in RuntimeSettings.runtime_cache creates an implicit handle.""" diff --git a/tests/py/dynamo/runtime/test_004_runtime_settings.py b/tests/py/dynamo/runtime/test_004_runtime_settings.py index b5ff008c27..a08460ca38 100644 --- a/tests/py/dynamo/runtime/test_004_runtime_settings.py +++ b/tests/py/dynamo/runtime/test_004_runtime_settings.py @@ -54,10 +54,13 @@ class TestRuntimeSettingsDataModel(TestCase): """Pure dataclass behavior; no engine compile required.""" def test_defaults_are_valid(self): + from torch_tensorrt.dynamo._defaults import RUNTIME_CACHE_PATH + rs = RuntimeSettings() self.assertEqual(rs.dynamic_shapes_kernel_specialization_strategy, "lazy") self.assertEqual(rs.cuda_graph_strategy, "disabled") - self.assertIsNone(rs.runtime_cache) + # Defaults to the per-user temp path from _defaults.py (mirrors ENGINE_CACHE_DIR). + self.assertEqual(rs.runtime_cache, RUNTIME_CACHE_PATH) def test_invalid_ds_strategy_raises_at_post_init(self): with self.assertRaises(ValueError): From 34fa610e49cf6d8bd0fd804fde37a6bfeb6fa402 Mon Sep 17 00:00:00 2001 From: tejaswinp Date: Wed, 3 Jun 2026 20:55:07 -0700 Subject: [PATCH 5/7] runtime: cpp implicit-handle swap on set_runtime_settings Mirror ``TRTRuntimeConfig.set_settings`` (Python runtime) on the cpp runtime path. Previously the cpp side dropped the C++ engine's intrusive_ptr on settings change but left ``self._implicit_cache_handle`` on the ``TorchTensorRTModule`` pointing at the *old* wrapper -- the new cache had no Python autosave companion and never wrote to disk. Factor the path-string-to-torchbind-handle materialization into ``TorchTensorRTModule._materialize_cpp_implicit_handle``. Called from ``setup_engine`` and ``_dispatch_runtime_settings_to_engine`` (cpp branch); synchronously saves the prior wrapper before swap, replaces ``self._implicit_cache_handle`` with the new one, then runs ``load()`` after the C++ engine has attached the IRuntimeCache. Test: ``test_set_runtime_settings_saves_prior_cache_on_swap`` (parametrized over both runtimes). Compiles with path A; swaps to path B; asserts A is written synchronously at swap time and B is written on ``del compiled``. The walk-to-inner-module is wrapped in a helper so the loop variable doesn't outlive the call and keep the inner TRT module alive past ``del compiled`` (which would suppress the post-del autosave). All 53 tests pass (test_004 12/12, test_000 14/14, test_001 ds 14/14, test_001 cg 13/13). --- .../dynamo/runtime/_TorchTensorRTModule.py | 98 +++++++++++++------ .../dynamo/runtime/test_000_runtime_cache.py | 59 +++++++++++ 2 files changed, 127 insertions(+), 30 deletions(-) diff --git a/py/torch_tensorrt/dynamo/runtime/_TorchTensorRTModule.py b/py/torch_tensorrt/dynamo/runtime/_TorchTensorRTModule.py index f53e123b31..1acbb7305e 100644 --- a/py/torch_tensorrt/dynamo/runtime/_TorchTensorRTModule.py +++ b/py/torch_tensorrt/dynamo/runtime/_TorchTensorRTModule.py @@ -305,16 +305,69 @@ def _dispatch_runtime_settings_to_engine(self, rs: RuntimeSettings) -> None: self.engine.update_runtime_settings(rs) return - # C++ torchbind engine: flatten the dataclass into positional args. The - # cache field is converted to a torchbind RuntimeCacheHandle (or None). + # C++ torchbind engine: re-materialize the Python-side implicit cache + # wrapper before dispatch (saving the prior wrapper synchronously, to + # mirror ``TRTRuntimeConfig.set_settings`` on the Python runtime). + rs_for_dispatch, needs_load = self._materialize_cpp_implicit_handle(rs) + from torch_tensorrt.runtime._runtime_cache import _to_torchbind_handle - cache_arg = _to_torchbind_handle(rs.runtime_cache) + cache_arg = _to_torchbind_handle(rs_for_dispatch.runtime_cache) self.engine.update_runtime_settings( - rs.dynamic_shapes_kernel_specialization_strategy, - rs.cuda_graph_strategy, + rs_for_dispatch.dynamic_shapes_kernel_specialization_strategy, + rs_for_dispatch.cuda_graph_strategy, cache_arg, ) + # The C++ engine's `update_runtime_settings` materializes the + # IRuntimeCache inside the torchbind handle. Load the on-disk bytes + # (filelocked) so the new cache starts warm. + if needs_load and self._implicit_cache_handle is not None: + try: + self._implicit_cache_handle.load() + except Exception as e: + logger.debug(f"Failed to load implicit runtime cache: {e}") + + def _materialize_cpp_implicit_handle( + self, rs: RuntimeSettings + ) -> Tuple[RuntimeSettings, bool]: + """Mirror of ``TRTRuntimeConfig._apply_settings`` for the cpp engine path. + + When ``rs.runtime_cache`` is a path string, builds a torchbind + ``RuntimeCacheHandle`` + Python wrapper and stashes the wrapper on + ``self._implicit_cache_handle`` so its ``__del__`` saves the cache + on engine destruction. Synchronously saves any prior implicit wrapper + before replacing it -- matches the explicit-save semantic in + :py:meth:`TRTRuntimeConfig.set_settings` so a settings swap doesn't + silently drop kernels JIT-compiled into the prior cache. + + Returns ``(rs_for_dispatch, needs_load)``: ``rs_for_dispatch`` has the + path string replaced with the torchbind handle (so dispatch passes the + same handle through to the C++ engine); ``needs_load`` indicates the + caller should call ``self._implicit_cache_handle.load()`` after the + engine has materialized its IRuntimeCache via the handle. + """ + from torch_tensorrt.runtime._runtime_cache import RuntimeCacheHandle + + old = self._implicit_cache_handle # type: ignore[has-type] + rc = rs.runtime_cache + if isinstance(rc, str) and rc: + tb = torch.classes.tensorrt.RuntimeCacheHandle(rc) + new = RuntimeCacheHandle(path=rc, autosave_on_del=True, torchbind_handle=tb) + self._implicit_cache_handle = new + rs_for_dispatch = rs.merge(runtime_cache=tb) + needs_load = True + else: + self._implicit_cache_handle = None + rs_for_dispatch = rs + needs_load = False + if old is not None and old is not self._implicit_cache_handle: + try: + old.save() + except Exception as e: + logger.warning( + f"Failed to save prior implicit runtime cache on swap: {e}" + ) + return rs_for_dispatch, needs_load def setup_engine(self) -> None: """ @@ -340,33 +393,18 @@ def setup_engine(self) -> None: else: self.engine = torch.classes.tensorrt.Engine(self._pack_engine_info()) self.execute_engine_op = torch.ops.tensorrt.execute_engine - # If the compile-time hint was a path string, materialize a - # torchbind RuntimeCacheHandle here and wrap it in a Python - # ``RuntimeCacheHandle`` so the handle's own ``__del__`` saves - # the on-disk cache when the module is collected. (The C++ engine - # has no Python ``__del__``; file I/O lives on the Python side.) - # Substitute the torchbind handle for the string in the settings - # so dispatch passes the same handle through to TorchBind. - rc = self._runtime_settings.runtime_cache - if isinstance(rc, str) and rc: - from torch_tensorrt.runtime._runtime_cache import RuntimeCacheHandle - - tb = torch.classes.tensorrt.RuntimeCacheHandle(rc) - self._implicit_cache_handle: Any = RuntimeCacheHandle( - path=rc, autosave_on_del=True, torchbind_handle=tb - ) - self._runtime_settings = self._runtime_settings.merge(runtime_cache=tb) - else: - self._implicit_cache_handle = None - # Apply runtime settings to the C++ engine (no-op if defaults). + # ``_dispatch_runtime_settings_to_engine`` (cpp branch) handles the + # str-path → torchbind handle + Python wrapper materialization, the + # dispatch to the C++ engine, AND the on-disk load. Initialize the + # attribute first so the helper's access is well-defined. + self._implicit_cache_handle: Any = None self._dispatch_runtime_settings_to_engine(self._runtime_settings) - # After dispatch the IRuntimeCache exists inside the handle; load - # the on-disk bytes (filelocked) so they're picked up on first run. + # Reflect the substituted handle back onto self._runtime_settings + # so subsequent reads (CM snapshot/restore) carry the same object. if self._implicit_cache_handle is not None: - try: - self._implicit_cache_handle.load() - except Exception as e: - logger.debug(f"Failed to load implicit runtime cache: {e}") + self._runtime_settings = self._runtime_settings.merge( + runtime_cache=self._implicit_cache_handle._torchbind + ) # requires_native_multidevice is set by the C++ constructor from the serialized REQUIRES_NATIVE_MULTIDEVICE_IDX field. if self.engine.requires_native_multidevice: diff --git a/tests/py/dynamo/runtime/test_000_runtime_cache.py b/tests/py/dynamo/runtime/test_000_runtime_cache.py index b3aff86f99..6dd98f630a 100644 --- a/tests/py/dynamo/runtime/test_000_runtime_cache.py +++ b/tests/py/dynamo/runtime/test_000_runtime_cache.py @@ -154,6 +154,65 @@ def test_cache_saved_on_del(self, _name, use_python_runtime): f"Implicit cache handle should have saved to {path} on engine __del__", ) + @parameterized.expand(_RUNTIMES) + def test_set_runtime_settings_saves_prior_cache_on_swap( + self, _name, use_python_runtime + ): + """Re-pointing ``runtime_cache`` via ``set_runtime_settings`` must save + the prior implicit cache before swapping. Mirrors the explicit-save + semantic in :py:meth:`TRTRuntimeConfig.set_settings` for the Python + runtime; the C++ runtime path replicates it via + :py:meth:`TorchTensorRTModule._materialize_cpp_implicit_handle`. + """ + _skip_if_cpp_unavailable(self, use_python_runtime) + from torch_tensorrt.dynamo.runtime._TorchTensorRTModule import ( + TorchTensorRTModule, + ) + + with tempfile.TemporaryDirectory() as tmp: + path_a = os.path.join(tmp, "cache_A.bin") + path_b = os.path.join(tmp, "cache_B.bin") + model, inputs = _fresh_conv_model_and_inputs(seed=42) + compiled = _compile( + model, + inputs, + use_python_runtime=use_python_runtime, + runtime_cache_path=path_a, + ) + _ = compiled(*inputs) + # Sanity: no file written yet (nothing has saved). + self.assertFalse(os.path.exists(path_a)) + + # Walk to the inner TorchTensorRTModule(s) and swap the cache path + # directly -- the outer GraphModule doesn't carry `set_runtime_settings`, + # and we want a *permanent* swap (the runtime_config CM restores on + # exit, which would mask the save-on-swap signal we're after). The + # walk is wrapped in a helper so the loop variable doesn't outlive + # the call and keep the inner module alive past ``del compiled``. + def _swap_all(target: torch.nn.Module, new_rs: RuntimeSettings) -> int: + count = 0 + for _, mod in target.named_modules(): + if isinstance(mod, TorchTensorRTModule): + mod.set_runtime_settings(new_rs) + count += 1 + return count + + swapped = _swap_all(compiled, RuntimeSettings(runtime_cache=path_b)) + self.assertGreater(swapped, 0, "Expected at least one TorchTensorRTModule") + self.assertTrue( + os.path.exists(path_a), + f"Prior implicit cache should have been saved to {path_a} " + "synchronously on set_runtime_settings swap", + ) + _ = compiled(*inputs) + del compiled + gc.collect() + self.assertTrue( + os.path.exists(path_b), + f"New implicit cache should have saved to {path_b} on engine " + "__del__ after the swap", + ) + @unittest.skipIf( not ENABLED_FEATURES.tensorrt_rtx, From 111cdb2ebdc162c3a56e686783cf4c93e553b80f Mon Sep 17 00:00:00 2001 From: tejaswinp Date: Wed, 3 Jun 2026 21:40:01 -0700 Subject: [PATCH 6/7] runtime: PR review feedback round 2 on RuntimeSettings + RuntimeCacheHandle C++-side cleanup spurred by review comments on tp5uiuc/TensorRT#3: - Convert ``RuntimeCacheHandle`` from a class with a private ``path_`` field + accessor methods (``path()`` / ``set_path()``) to a struct with a public ``path`` field. Re-register the torchbind binding via ``.def_readwrite("path", &RuntimeCacheHandle::path)``. - Move the bodies of ``serialize``, ``deserialize``, and ``has_cache`` out of the JIT-binding registration file (``register_jit_hooks.cpp``) and into member functions implemented in ``RuntimeSettings.cpp``. The ``#ifdef TRT_MAJOR_RTX`` guards live inside those impls; the registration file is preprocessor-free for these bindings. - Use ``std::tie`` in ``RuntimeSettings::operator==`` for cleaner field-wise comparison (raw ``intrusive_ptr::get()`` results hoisted to lvalues to satisfy ``std::tie``'s reference requirement). - Drop ``RuntimeSettings::merge``. C++ ``RuntimeSettings`` is now value-typed end-to-end; direct field assignment is the idiom. No callers used ``merge`` outside its own definition. No behavior change. Python-side ``RuntimeCacheHandle`` wrapper and the runtime test suite are unaffected. --- core/runtime/RuntimeSettings.cpp | 69 +++++++++++++++++++++++------ core/runtime/RuntimeSettings.h | 61 +++++++++++++------------ core/runtime/register_jit_hooks.cpp | 62 ++++---------------------- 3 files changed, 97 insertions(+), 95 deletions(-) diff --git a/core/runtime/RuntimeSettings.cpp b/core/runtime/RuntimeSettings.cpp index 437d74f12d..0d2aeda533 100644 --- a/core/runtime/RuntimeSettings.cpp +++ b/core/runtime/RuntimeSettings.cpp @@ -1,26 +1,69 @@ #include "core/runtime/RuntimeSettings.h" +#include #include +#include + +#include "core/util/prelude.h" namespace torch_tensorrt { namespace core { namespace runtime { -bool RuntimeSettings::operator==(RuntimeSettings const& other) const noexcept { - // Same handle pointer counts as identical cache; passing the same handle twice - // through update_runtime_settings is a no-op. - return dynamic_shapes_kernel_specialization_strategy == other.dynamic_shapes_kernel_specialization_strategy && - cuda_graph_strategy == other.cuda_graph_strategy && runtime_cache.get() == other.runtime_cache.get(); +// ---- RuntimeCacheHandle methods --------------------------------------------- +// +// The ``#ifdef TRT_MAJOR_RTX`` is intentionally confined to this translation +// unit: the public header advertises a uniform interface (always-callable +// methods that simply degrade to no-ops on non-RTX builds), and the JIT-binding +// registration file (``register_jit_hooks.cpp``) calls these as plain member +// references with zero conditional compilation. + +at::Tensor RuntimeCacheHandle::serialize() const { + auto opts = at::TensorOptions().dtype(at::kByte); +#ifdef TRT_MAJOR_RTX + if (!cache) { + return at::empty({0}, opts); + } + auto host_mem = make_trt(cache->serialize()); + if (!host_mem) { + return at::empty({0}, opts); + } + auto tensor = at::empty({static_cast(host_mem->size())}, opts); + std::memcpy(tensor.data_ptr(), host_mem->data(), host_mem->size()); + return tensor; +#else + return at::empty({0}, opts); +#endif } -RuntimeSettings RuntimeSettings::merge(RuntimeSettings const& override) const { - RuntimeSettings result = *this; - result.dynamic_shapes_kernel_specialization_strategy = override.dynamic_shapes_kernel_specialization_strategy; - result.cuda_graph_strategy = override.cuda_graph_strategy; - if (override.runtime_cache) { - result.runtime_cache = override.runtime_cache; +void RuntimeCacheHandle::deserialize(TORCHTRT_UNUSED at::Tensor data) { +#ifdef TRT_MAJOR_RTX + if (data.numel() == 0 || !cache) { + return; } - return result; + auto contig = data.contiguous().to(at::kCPU); + cache->deserialize(contig.data_ptr(), static_cast(contig.numel())); +#endif +} + +bool RuntimeCacheHandle::has_cache() const { +#ifdef TRT_MAJOR_RTX + return cache != nullptr; +#else + return false; +#endif +} + +// ---- RuntimeSettings methods ------------------------------------------------ + +bool RuntimeSettings::operator==(RuntimeSettings const& other) const noexcept { + // ``runtime_cache`` compares by pointer identity: passing the same handle + // twice through ``update_runtime_settings`` is a no-op. Hoisted into locals + // because ``std::tie`` requires lvalues. + auto* this_cache = runtime_cache.get(); + auto* other_cache = other.runtime_cache.get(); + return std::tie(dynamic_shapes_kernel_specialization_strategy, cuda_graph_strategy, this_cache) == + std::tie(other.dynamic_shapes_kernel_specialization_strategy, other.cuda_graph_strategy, other_cache); } std::string RuntimeSettings::to_str() const { @@ -28,7 +71,7 @@ std::string RuntimeSettings::to_str() const { os << "Dynamic Shapes Kernel Strategy: " << dynamic_shapes_kernel_specialization_strategy << std::endl; os << "CUDA Graph Strategy: " << cuda_graph_strategy << std::endl; if (runtime_cache) { - auto p = runtime_cache->path(); + auto const& p = runtime_cache->path; os << "Runtime Cache: " << (p.empty() ? "" : p) << std::endl; } else { os << "Runtime Cache: " << std::endl; diff --git a/core/runtime/RuntimeSettings.h b/core/runtime/RuntimeSettings.h index 544e4de63f..230626df92 100644 --- a/core/runtime/RuntimeSettings.h +++ b/core/runtime/RuntimeSettings.h @@ -4,6 +4,7 @@ #include #include +#include "ATen/core/Tensor.h" #include "ATen/core/ivalue.h" #include "NvInfer.h" #include "torch/custom_class.h" @@ -12,41 +13,50 @@ namespace torch_tensorrt { namespace core { namespace runtime { -// A passive wrapper around an `IRuntimeCache`. Registered as a torchbind class so -// it can be passed by `c10::intrusive_ptr` across the Python/C++ boundary; the -// same handle gives both runtimes the same underlying `IRuntimeCache*`. +// A passive wrapper around an ``IRuntimeCache``. Registered as a torchbind class +// so it can be passed by ``c10::intrusive_ptr`` across the Python/C++ boundary; +// the same handle gives both runtimes the same underlying ``IRuntimeCache*``. // -// File I/O lives exclusively on the Python side (filelock + serialize/deserialize -// via `trt.IRuntimeCache`). The C++ class is purely a holder; `path` is -// informational and is not consulted by the C++ runtime. -class RuntimeCacheHandle : public torch::CustomClassHolder { - public: - explicit RuntimeCacheHandle(std::string path = "") : path_(std::move(path)) {} - - [[nodiscard]] std::string path() const { - return path_; - } - void set_path(std::string p) { - path_ = std::move(p); - } +// File I/O lives on the Python side (filelock + on-disk persistence via +// the ``serialize`` / ``deserialize`` members below). The C++ struct is purely +// a holder; ``path`` is informational and is not consulted by the C++ runtime. +struct RuntimeCacheHandle : public torch::CustomClassHolder { + std::string path; #ifdef TRT_MAJOR_RTX // The actual TensorRT runtime cache. The first engine that attaches this handle - // materializes it via `IRuntimeConfig::createRuntimeCache()` and writes the + // materializes it via ``IRuntimeConfig::createRuntimeCache()`` and writes the // shared_ptr here; subsequent engines reuse the same pointer for true sharing. std::shared_ptr cache; #endif - private: - std::string path_; + explicit RuntimeCacheHandle(std::string p = "") : path(std::move(p)) {} + + // Expose the underlying ``IRuntimeCache`` bytes for the Python side to persist + // under filelock. Returns an empty uint8 tensor when no cache is attached, or + // on non-RTX builds. + // + // ``at::Tensor`` is used (rather than ``std::string``) because TorchBind + // forces ``std::string`` to round-trip through Python ``str`` (UTF-8), and + // serialized cache bytes are not valid UTF-8. + [[nodiscard]] at::Tensor serialize() const; + + // Inverse of ``serialize``. Expects a uint8 ``at::Tensor``. No-op for empty + // input, when the underlying ``IRuntimeCache`` has not been materialized yet, + // or on non-RTX builds. + void deserialize(at::Tensor data); + + // True iff an engine has populated the underlying ``IRuntimeCache``. + // Always false on non-RTX builds. + [[nodiscard]] bool has_cache() const; }; // Per-engine runtime-only knobs sampled at IExecutionContext creation. // -// `RuntimeSettings` is a plain struct (not a torchbind class) because we flatten -// it into positional args at the torchbind boundary -- TorchBind can't carry a -// dataclass natively. Equality is value-by-value; the cache field compares -// by pointer identity (same handle -> same cache). +// ``RuntimeSettings`` is a plain struct (not a torchbind class) because we +// flatten it into positional args at the torchbind boundary -- TorchBind can't +// carry a dataclass natively. Equality is value-by-value; the cache field +// compares by pointer identity (same handle -> same cache). struct RuntimeSettings { std::string dynamic_shapes_kernel_specialization_strategy = "lazy"; std::string cuda_graph_strategy = "disabled"; @@ -57,11 +67,6 @@ struct RuntimeSettings { return !(*this == other); } - // Apply `override`'s non-default fields on top of *this*, returning a new value. - // For non-default detection on the strategy strings we always overlay; the cache - // pointer is overlaid iff `override.runtime_cache` is non-null. - RuntimeSettings merge(RuntimeSettings const& override) const; - [[nodiscard]] std::string to_str() const; }; diff --git a/core/runtime/register_jit_hooks.cpp b/core/runtime/register_jit_hooks.cpp index 02bf244998..23ce228d77 100644 --- a/core/runtime/register_jit_hooks.cpp +++ b/core/runtime/register_jit_hooks.cpp @@ -17,63 +17,17 @@ namespace { // Register `RuntimeCacheHandle` as a torchbind class so Python can pass the same // underlying `IRuntimeCache` to both Python and C++ engine backends. File I/O on -// the handle is the Python side's responsibility; the C++ class only holds the -// shared_ptr and an informational path string. +// the handle is the Python side's responsibility; the C++ struct only holds the +// shared_ptr and an informational path string. The method bodies (and the +// `#ifdef TRT_MAJOR_RTX` they entail) live in RuntimeSettings.cpp -- this file +// is registration-only. static auto TORCHTRT_UNUSED RuntimeCacheHandleRegistration = torch::class_("tensorrt", "RuntimeCacheHandle") .def(torch::init()) - .def("path", &RuntimeCacheHandle::path) - .def("set_path", &RuntimeCacheHandle::set_path) - // Expose the underlying IRuntimeCache bytes to Python so the Python- - // side save/load logic can persist them under filelock. Returns an - // empty uint8 tensor if the cache hasn't been materialized yet. - // - // We return ``at::Tensor`` rather than ``std::string`` because TorchBind - // forces ``std::string`` to round-trip through Python ``str`` (UTF-8) - // and serialized cache bytes are not valid UTF-8. - .def( - "serialize", - [](const c10::intrusive_ptr& self) -> at::Tensor { -#ifdef TRT_MAJOR_RTX - auto opts = at::TensorOptions().dtype(at::kByte); - if (!self->cache) { - return at::empty({0}, opts); - } - auto host_mem = make_trt(self->cache->serialize()); - if (!host_mem) { - return at::empty({0}, opts); - } - auto tensor = at::empty({static_cast(host_mem->size())}, opts); - std::memcpy(tensor.data_ptr(), host_mem->data(), host_mem->size()); - return tensor; -#else - return at::empty({0}, at::TensorOptions().dtype(at::kByte)); -#endif - }) - // Deserialize bytes loaded from disk into the underlying IRuntimeCache. - // Expects a uint8 ``at::Tensor``. No-op for empty input or if the - // IRuntimeCache hasn't been materialized yet. - .def( - "deserialize", - [](const c10::intrusive_ptr& self, at::Tensor data) -> void { -#ifdef TRT_MAJOR_RTX - if (data.numel() == 0 || !self->cache) { - return; - } - auto contig = data.contiguous().to(at::kCPU); - self->cache->deserialize(contig.data_ptr(), static_cast(contig.numel())); -#else - (void)data; -#endif - }) - // True iff an engine has populated the underlying IRuntimeCache. - .def("has_cache", [](const c10::intrusive_ptr& self) -> bool { -#ifdef TRT_MAJOR_RTX - return self->cache != nullptr; -#else - return false; -#endif - }); + .def_readwrite("path", &RuntimeCacheHandle::path) + .def("serialize", &RuntimeCacheHandle::serialize) + .def("deserialize", &RuntimeCacheHandle::deserialize) + .def("has_cache", &RuntimeCacheHandle::has_cache); // TODO: Implement a call method // c10::List TRTEngine::Run(c10::List inputs) { From 363d20b1ede72593d1f47e4c5c2491c9c59e4133 Mon Sep 17 00:00:00 2001 From: tejaswinp Date: Thu, 4 Jun 2026 01:49:10 -0700 Subject: [PATCH 7/7] runtime: lazy IExecutionContext creation in cpp TRTEngine Defer the TRT ``createExecutionContext`` call -- the most expensive part of engine setup on TRT-RTX, since it JIT-compiles the specialized kernel set -- until first use. Collapses the historical "ctor create with defaults + post-construction recreate with user settings" pair on the ``setup_engine`` cpp branch into a single create. C++: - ``TRTEngine::ensure_execution_context()`` -- idempotent lazy build via ``runtime_cfg.create_execution_context``. Called from ``execute_engine``, ``infer_outputs``, ``enable_profiling``, ``bind_nccl_comm``. - ``TRTEngine::invalidate_execution_context()`` -- ``exec_ctx.reset()``. ``update_runtime_settings``, ``set_resource_allocation_strategy``, ``disable_profiling``, and ``set_device_memory_budget`` now invalidate without immediately recreating; the next user lazy-creates. - Ctor: drop the eager ``recreate_execution_context()`` call. The two conditional in-window users (``enable_profiling`` debug build and ``bind_nccl_comm`` distributed) ensure-first on their own. - ``to_str()`` guards on a null ``exec_ctx`` and reports ```` instead of dereferencing. - ``recreate_execution_context()`` bumps a ``num_execution_contexts_created_`` counter, exposed as a torchbind method for tests. Python: - Mirror the counter on the Python runtime ``TRTEngine`` (``num_execution_contexts_created()``) for cross-runtime test coverage. - ``TorchTensorRTModule._materialize_cpp_implicit_handle`` reuses the prior wrapper when the path string is unchanged, instead of always creating a fresh torchbind handle. Without this the cpp ``set_settings`` would see a different ``runtime_cache.get()`` pointer on every (otherwise identical) call and unnecessarily invalidate the context. Tests: - ``test_004_runtime_settings.py::TestLazyExecutionContextCreation`` (4 tests, parametrized python/cpp = 8 cases). Asserts: single create per engine setup on both runtimes regardless of default vs compile-time RuntimeSettings, lazy recreate semantics after a settings flip, and zero-recreate on no-op settings re-application. All 61 runtime tests pass. --- core/runtime/TRTEngine.cpp | 122 +++++++++++----- core/runtime/TRTEngine.h | 24 ++++ core/runtime/execute_engine.cpp | 5 + core/runtime/register_jit_hooks.cpp | 1 + .../dynamo/runtime/_TRTEngine.py | 13 ++ .../dynamo/runtime/_TorchTensorRTModule.py | 12 ++ .../runtime/test_004_runtime_settings.py | 134 ++++++++++++++++++ 7 files changed, 272 insertions(+), 39 deletions(-) diff --git a/core/runtime/TRTEngine.cpp b/core/runtime/TRTEngine.cpp index bf0fb81897..e5d64ac8f7 100644 --- a/core/runtime/TRTEngine.cpp +++ b/core/runtime/TRTEngine.cpp @@ -185,7 +185,12 @@ TRTEngine::TRTEngine( LOG_DEBUG( "Resource allocation strategy: " << (this->resource_allocation_strategy == ResourceAllocationStrategy::kDynamic ? "Dynamic" : "Static")); - recreate_execution_context(); + // ``exec_ctx`` is created lazily on first use (``execute_engine``, + // ``enable_profiling``, ``bind_nccl_comm``, ``infer_outputs``, ``to_str``). + // Deferring here lets the Python ``setup_engine`` cpp branch dispatch user + // ``RuntimeSettings`` before the (expensive, kernel-JIT-compiling) TRT + // ``createExecutionContext`` call -- collapses the historical + // "create-with-defaults then recreate-with-settings" pair into a single create. // Pre-allocate placeholder for empty tensors (TensorRT requires non-null addresses) cudaMalloc(&empty_tensor_placeholder, 1); @@ -276,17 +281,18 @@ TRTEngine::TRTEngine( has_dynamic_inputs = engine_has_dynamic_inputs(cuda_engine.get(), in_binding_names); #ifndef NDEBUG + // Debug builds want profiling on from the start; that requires a live ctx. this->enable_profiling(); #endif LOG_DEBUG(*this); #ifdef ENABLE_TRT_NCCL_COLLECTIVES - // Attempt to bind the NCCL communicator immediately after exec_ctx is ready. - // This handles the common case where dist.init_process_group() and an initial - // collective have already been called before the engine is constructed. - // If the communicator isn't available yet (e.g. engine constructed before the - // first collective), bind_nccl_comm returns false and execute_engine() will - // retry on its first invocation. + // Distributed engines must have a bound communicator on the IExecutionContext + // before the first collective; bind here. ``bind_nccl_comm`` lazily creates + // ``exec_ctx`` via ``ensure_execution_context`` if it hasn't been built yet. + // For non-distributed engines we leave ``exec_ctx`` null so the first + // ``execute_engine`` (typically right after the Python settings dispatch) + // is the single TRT context-create site. if (this->requires_native_multidevice) { bind_nccl_comm(); } @@ -310,7 +316,10 @@ void TRTEngine::disable_profiling() { torch::cuda::synchronize(device_info.id); profile_execution = false; trt_engine_profiler.reset(); - recreate_execution_context(); + // Drop the profiler-attached context; next execute lazily creates a fresh + // one with no profiler. (TRT has no detach-profiler API -- recreate is the + // canonical way.) + invalidate_execution_context(); } void TRTEngine::dump_engine_layer_info_to_file(const std::string& path) { @@ -331,6 +340,10 @@ void TRTEngine::dump_engine_layer_info() { void TRTEngine::enable_profiling() { profile_execution = true; trt_engine_profiler = std::make_unique(name); + // ``setProfiler`` requires a live ``IExecutionContext``; under the lazy-create + // policy the ctx may be null when the user toggles profiling before the first + // execute. + ensure_execution_context(); exec_ctx->setProfiler(trt_engine_profiler.get()); } @@ -362,6 +375,8 @@ std::string TRTEngine::get_serialized_metadata() { } std::vector TRTEngine::infer_outputs(std::vector> input_shapes) { + // Lazy-create: callers can hit this before the first execute_engine. + ensure_execution_context(); std::vector outputs; TORCHTRT_CHECK( (in_binding_names.size() == input_shapes.size()), @@ -399,21 +414,21 @@ int64_t TRTEngine::get_device_memory_budget() { } bool TRTEngine::set_device_memory_budget(int64_t budget) { - // Recreating the context because weight streaming budget cannot be modified while there are active context. - if (exec_ctx.get() != nullptr) { - exec_ctx.reset(); - } + // Weight-streaming budget cannot be modified while a context is live; drop it. + invalidate_execution_context(); if (profile_execution) { trt_engine_profiler.reset(); } bool result = cuda_engine->setWeightStreamingBudgetV2(budget); - recreate_execution_context(); + // Eagerly rebuild if the user had profiling on (so the profiler is attached + // before they query it); otherwise leave lazy. if (profile_execution) { enable_profiling(); } #ifdef ENABLE_TRT_NCCL_COLLECTIVES - // exec_ctx was recreated — re-bind the NCCL communicator if this is a - // distributed engine that has already been set up. + // exec_ctx was invalidated — re-bind the NCCL communicator if this is a + // distributed engine that has already been set up. ``bind_nccl_comm`` ensures + // the context before binding. if (nccl_initialized) { bind_nccl_comm(); } @@ -438,27 +453,35 @@ std::string TRTEngine::to_str() const { std::stringstream ss; ss << "Torch-TensorRT TensorRT Engine:" << std::endl; ss << " Name: " << name << std::endl; - ss << " Inputs: [" << std::endl; - for (uint64_t i = 0; i < num_io.first; i++) { - ss << " id: " << i << std::endl; - ss << " name: " << in_binding_names[i].c_str() << std::endl; - ss << " shape: " << exec_ctx->getTensorShape(in_binding_names[i].c_str()) << std::endl; - ss << " dtype: " - << util::TRTDataTypeToScalarType(exec_ctx->getEngine().getTensorDataType(in_binding_names[i].c_str())) - << std::endl; - } - ss << " ]" << std::endl; - ss << " Outputs: [" << std::endl; - for (uint64_t o = 0; o < num_io.second; o++) { - ss << " id: " << o << std::endl; - ss << " name: " << out_binding_names[o].c_str() << std::endl; - ss << " shape: " << exec_ctx->getTensorShape(out_binding_names[o].c_str()) << std::endl; - ss << " dtype: " - << util::TRTDataTypeToScalarType( - exec_ctx->getEngine().getTensorDataType(out_binding_names[o].c_str())) - << std::endl; + // Shape/dtype queries require a live IExecutionContext. Under the lazy-create + // policy ``exec_ctx`` may be null at debug-log time (e.g. ctor-time LOG_DEBUG + // before the first execute_engine). Fall back to a marker. + if (exec_ctx == nullptr) { + ss << " Inputs: " << std::endl; + ss << " Outputs: " << std::endl; + } else { + ss << " Inputs: [" << std::endl; + for (uint64_t i = 0; i < num_io.first; i++) { + ss << " id: " << i << std::endl; + ss << " name: " << in_binding_names[i].c_str() << std::endl; + ss << " shape: " << exec_ctx->getTensorShape(in_binding_names[i].c_str()) << std::endl; + ss << " dtype: " + << util::TRTDataTypeToScalarType(exec_ctx->getEngine().getTensorDataType(in_binding_names[i].c_str())) + << std::endl; + } + ss << " ]" << std::endl; + ss << " Outputs: [" << std::endl; + for (uint64_t o = 0; o < num_io.second; o++) { + ss << " id: " << o << std::endl; + ss << " name: " << out_binding_names[o].c_str() << std::endl; + ss << " shape: " << exec_ctx->getTensorShape(out_binding_names[o].c_str()) << std::endl; + ss << " dtype: " + << util::TRTDataTypeToScalarType( + exec_ctx->getEngine().getTensorDataType(out_binding_names[o].c_str())) + << std::endl; + } + ss << " ]" << std::endl; } - ss << " ]" << std::endl; ss << " Device: " << device_info << std::endl; ss << " Hardware Compatibility: " << (hardware_compatible ? "Enabled" : "Disabled") << std::endl; ss << " Target Platform: " << target_platform << std::endl; @@ -554,7 +577,7 @@ void TRTEngine::set_resource_allocation_strategy(TRTEngine::ResourceAllocationSt "Setting resource allocation strategy to " << (this->resource_allocation_strategy == TRTEngine::ResourceAllocationStrategy::kDynamic ? "dynamic" : "static")); - recreate_execution_context(); + invalidate_execution_context(); } } @@ -637,6 +660,9 @@ bool TRTEngine::bind_nccl_comm() { return false; } + // Distributed engines must hold a live IExecutionContext at bind time. + // Under the lazy-create policy this is the first call site that needs it. + ensure_execution_context(); TORCHTRT_CHECK(exec_ctx.get() != nullptr, "Cannot bind NCCL communicator: execution context is null"); exec_ctx->setCommunicator(reinterpret_cast(comm_ptr)); this->nccl_initialized = true; @@ -650,8 +676,10 @@ void TRTEngine::release_nccl_comm() { } LOG_INFO("Releasing NCCL communicator from engine '" << this->name << "'"); torch::cuda::synchronize(device_info.id); - this->exec_ctx.reset(); - recreate_execution_context(); + invalidate_execution_context(); + // Eagerly rebuild so the engine returns to a "context-live, no NCCL" state + // (callers may immediately query exec_ctx for shape/dtype info post-release). + ensure_execution_context(); this->nccl_initialized = false; LOG_INFO("NCCL communicator released from engine '" << this->name << "'"); } @@ -679,19 +707,35 @@ void TRTEngine::update_runtime_settings(RuntimeSettings new_settings) { if (!runtime_cfg.set_settings(std::move(new_settings))) { return; } - recreate_execution_context(); + // Lazy: drop the live context, but do NOT eagerly recreate. The next user + // (typically the next ``execute_engine`` call) will lazy-create with the + // new settings via ``ensure_execution_context``. This collapses the + // historical "ctor-create-with-defaults + dispatch-recreate-with-settings" + // pair on the Python ``setup_engine`` cpp branch into a single create. + invalidate_execution_context(); // Existing recreate sites set runtime_states.context_changed for cudagraph // re-record; do the same here so a settings flip inside an active CM forces // the next enqueue to re-record any captured graph. runtime_states.context_changed = true; } +void TRTEngine::ensure_execution_context() { + if (exec_ctx == nullptr) { + recreate_execution_context(); + } +} + +void TRTEngine::invalidate_execution_context() noexcept { + exec_ctx.reset(); +} + void TRTEngine::recreate_execution_context() { const auto allocation_strategy = resource_allocation_strategy == ResourceAllocationStrategy::kDynamic ? nvinfer1::ExecutionContextAllocationStrategy::kUSER_MANAGED : nvinfer1::ExecutionContextAllocationStrategy::kSTATIC; exec_ctx = runtime_cfg.create_execution_context(cuda_engine.get(), allocation_strategy); TORCHTRT_CHECK(exec_ctx.get() != nullptr, "Unable to (re)create TensorRT execution context"); + ++num_execution_contexts_created_; } } // namespace runtime diff --git a/core/runtime/TRTEngine.h b/core/runtime/TRTEngine.h index 9151bec374..1f1f876826 100644 --- a/core/runtime/TRTEngine.h +++ b/core/runtime/TRTEngine.h @@ -307,9 +307,33 @@ struct TRTEngine : torch::CustomClassHolder { // already disabled. void disable_rtx_native_cudagraphs(); + // Materialize ``exec_ctx`` if it is currently null, using the current settings + // from ``runtime_cfg``. Idempotent: a non-null ``exec_ctx`` is left untouched. + // Called from every site that needs the live context (``execute_engine``, + // ``enable_profiling``, ``bind_nccl_comm``, ``infer_outputs``, etc.). + void ensure_execution_context(); + + // Drop the live ``exec_ctx`` without recreating. The next ``ensure_execution_context`` + // (typically inside the next ``execute_engine`` call) will rebuild from the + // current ``runtime_cfg`` settings. + void invalidate_execution_context() noexcept; + + // Test/observability hook: increments once every time ``runtime_cfg.create_execution_context`` + // is invoked (i.e. an actual TRT createExecutionContext call, which on RTX + // also JIT-compiles the specialized kernel set). Bound on the torchbind class. + // ``noexcept`` is intentionally omitted -- torchbind's ``def`` template is + // not specialized for ``const noexcept`` member functions. + [[nodiscard]] int64_t num_execution_contexts_created() const { + return num_execution_contexts_created_; + } + private: // Single entry point that (re)creates exec_ctx via runtime_cfg.create_execution_context. + // Bumps ``num_execution_contexts_created_``. Callers should normally go through + // ``ensure_execution_context`` for the lazy semantics. void recreate_execution_context(); + + int64_t num_execution_contexts_created_ = 0; }; } // namespace runtime diff --git a/core/runtime/execute_engine.cpp b/core/runtime/execute_engine.cpp index 80936951ef..83667a45af 100644 --- a/core/runtime/execute_engine.cpp +++ b/core/runtime/execute_engine.cpp @@ -212,6 +212,11 @@ void create_output_allocator(c10::intrusive_ptr compiled_engine) { } std::vector execute_engine(std::vector inputs, c10::intrusive_ptr compiled_engine) { + // Materialize the IExecutionContext on the first execute (under the lazy-create + // policy the ctor and ``update_runtime_settings`` no longer eagerly build one). + // Idempotent: a non-null exec_ctx is left untouched. + compiled_engine->ensure_execution_context(); + // All inputs are expected to be on CUDA. Warn and move any that are not. for (auto& inp : inputs) { if (inp.defined() && !inp.is_cuda()) { diff --git a/core/runtime/register_jit_hooks.cpp b/core/runtime/register_jit_hooks.cpp index 23ce228d77..7afada2c21 100644 --- a/core/runtime/register_jit_hooks.cpp +++ b/core/runtime/register_jit_hooks.cpp @@ -56,6 +56,7 @@ static auto TORCHTRT_UNUSED TRTEngineTSRegistrtion = .def("reset_captured_graph", &TRTEngine::reset_captured_graph) .def("set_output_tensors_as_unowned", &TRTEngine::set_output_tensors_as_unowned) .def("are_output_tensors_unowned", &TRTEngine::are_output_tensors_unowned) + .def("num_execution_contexts_created", &TRTEngine::num_execution_contexts_created) .def( "use_dynamically_allocated_resources", [](const c10::intrusive_ptr& self, bool dynamic) -> void { diff --git a/py/torch_tensorrt/dynamo/runtime/_TRTEngine.py b/py/torch_tensorrt/dynamo/runtime/_TRTEngine.py index 1bfd3a21e9..62d787d47a 100644 --- a/py/torch_tensorrt/dynamo/runtime/_TRTEngine.py +++ b/py/torch_tensorrt/dynamo/runtime/_TRTEngine.py @@ -436,8 +436,21 @@ def _create_execution_context(self) -> trt.IExecutionContext: self.cuda_engine, alloc_strategy ) assert context is not None, "Failed to create execution context" + # Mirrors ``TRTEngine::num_execution_contexts_created`` on the C++ side; + # used by tests to assert single createExecutionContext per engine setup. + self._num_execution_contexts_created = ( + getattr(self, "_num_execution_contexts_created", 0) + 1 + ) return context + def num_execution_contexts_created(self) -> int: + """Number of TRT ``createExecutionContext`` invocations on this engine. + + Each call (re)JITs the specialized kernel set on RTX, so this is the + canonical counter for the setup-cost regression test. + """ + return getattr(self, "_num_execution_contexts_created", 0) + def _setup_engine(self) -> None: multi_gpu_device_check() self.runtime = trt.Runtime(TRT_LOGGER) diff --git a/py/torch_tensorrt/dynamo/runtime/_TorchTensorRTModule.py b/py/torch_tensorrt/dynamo/runtime/_TorchTensorRTModule.py index 1acbb7305e..ebb4b833e6 100644 --- a/py/torch_tensorrt/dynamo/runtime/_TorchTensorRTModule.py +++ b/py/torch_tensorrt/dynamo/runtime/_TorchTensorRTModule.py @@ -351,6 +351,18 @@ def _materialize_cpp_implicit_handle( old = self._implicit_cache_handle # type: ignore[has-type] rc = rs.runtime_cache if isinstance(rc, str) and rc: + # No-op fast path: if the prior wrapper is already pointed at the + # same disk path and still holds its torchbind sibling, reuse it. + # Without this the cpp ``set_settings`` sees a *different* + # ``runtime_cache.get()`` pointer every call and invalidates the + # execution context even when the user passed identical settings. + if ( + old is not None + and getattr(old, "path", None) == rc + and old._torchbind is not None + ): + rs_for_dispatch = rs.merge(runtime_cache=old._torchbind) + return rs_for_dispatch, False tb = torch.classes.tensorrt.RuntimeCacheHandle(rc) new = RuntimeCacheHandle(path=rc, autosave_on_del=True, torchbind_handle=tb) self._implicit_cache_handle = new diff --git a/tests/py/dynamo/runtime/test_004_runtime_settings.py b/tests/py/dynamo/runtime/test_004_runtime_settings.py index a08460ca38..75f3e1023f 100644 --- a/tests/py/dynamo/runtime/test_004_runtime_settings.py +++ b/tests/py/dynamo/runtime/test_004_runtime_settings.py @@ -175,5 +175,139 @@ def test_unknown_kwarg_raises(self): runtime_config(target, not_a_real_field=True) +@unittest.skipIf( + not ENABLED_FEATURES.tensorrt_rtx, + "Lazy IExecutionContext count is meaningful only on TRT-RTX", +) +class TestLazyExecutionContextCreation(TestCase): + """Regression guard: each setup creates exactly one IExecutionContext. + + On RTX, ``createExecutionContext`` JIT-compiles the specialized kernel set, + so a redundant create doubles a non-trivial chunk of setup latency. The + historical cpp-runtime path did two creates per engine setup -- one in the + torchbind ctor with defaults, one in the post-construction + ``update_runtime_settings`` dispatch. The lazy-create policy collapses these + into a single create at first execute. + """ + + def _walk_engines(self, compiled): + from torch_tensorrt.dynamo.runtime._TorchTensorRTModule import ( + TorchTensorRTModule, + ) + + for _, mod in compiled.named_modules(): + if isinstance(mod, TorchTensorRTModule): + yield mod + + def _skip_if_cpp_unavailable(self, use_python_runtime): + if not use_python_runtime and not ENABLED_FEATURES.torch_tensorrt_runtime: + self.skipTest("C++ runtime is not available") + + @parameterized.expand(_RUNTIMES) + def test_one_context_create_with_default_settings(self, _name, use_python_runtime): + self._skip_if_cpp_unavailable(use_python_runtime) + compiled = _compile_simple(use_python_runtime=use_python_runtime) + ttrt_modules = list(self._walk_engines(compiled)) + self.assertTrue(ttrt_modules, "Expected at least one TorchTensorRTModule") + # Setup itself must not have created the context yet on the cpp path + # (Python runtime engine constructs the context inside its own __init__ + # *after* settings are applied, so 1 is correct there too). + for mod in ttrt_modules: + n = mod.engine.num_execution_contexts_created() + if use_python_runtime: + # Python runtime threads runtime_settings into the engine ctor + # directly, so the single create lives there. + self.assertEqual(n, 1, f"Python runtime expected 1 create, got {n}") + else: + # Cpp path defers until first execute. + self.assertEqual( + n, 0, f"Cpp runtime expected 0 creates at setup, got {n}" + ) + + inputs = [torch.randn(2, 3).cuda()] + _ = compiled(*inputs) + for mod in ttrt_modules: + n = mod.engine.num_execution_contexts_created() + self.assertEqual( + n, 1, f"Expected exactly 1 create after first execute, got {n}" + ) + + @parameterized.expand(_RUNTIMES) + def test_one_context_create_with_compile_time_settings( + self, _name, use_python_runtime + ): + """User-passed RuntimeSettings at compile time must not double-create. + + This is the regression case the lazy-create refactor addresses on cpp: + old behaviour did ctor-create-with-defaults + dispatch-recreate. The + observable count after first execute must be 1. + """ + self._skip_if_cpp_unavailable(use_python_runtime) + rs = RuntimeSettings(cuda_graph_strategy="whole_graph_capture") + compiled = _compile_simple( + runtime_settings=rs, use_python_runtime=use_python_runtime + ) + ttrt_modules = list(self._walk_engines(compiled)) + self.assertTrue(ttrt_modules) + inputs = [torch.randn(2, 3).cuda()] + _ = compiled(*inputs) + for mod in ttrt_modules: + n = mod.engine.num_execution_contexts_created() + self.assertEqual( + n, + 1, + f"Setup + first execute must perform exactly 1 createExecutionContext; " + f"got {n} on {'python' if use_python_runtime else 'cpp'} runtime", + ) + + @parameterized.expand(_RUNTIMES) + def test_set_runtime_settings_lazy_recreate(self, _name, use_python_runtime): + """Changing settings invalidates the context but the recreate is lazy: + the count bumps on the next execute, not on the set call.""" + self._skip_if_cpp_unavailable(use_python_runtime) + compiled = _compile_simple(use_python_runtime=use_python_runtime) + ttrt_modules = list(self._walk_engines(compiled)) + inputs = [torch.randn(2, 3).cuda()] + _ = compiled(*inputs) + for mod in ttrt_modules: + self.assertEqual(mod.engine.num_execution_contexts_created(), 1) + + new_rs = RuntimeSettings(cuda_graph_strategy="whole_graph_capture") + for mod in ttrt_modules: + mod.set_runtime_settings(new_rs) + # set itself does not eagerly recreate on the cpp path. + if not use_python_runtime: + self.assertEqual(mod.engine.num_execution_contexts_created(), 1) + + _ = compiled(*inputs) + for mod in ttrt_modules: + n = mod.engine.num_execution_contexts_created() + self.assertEqual( + n, + 2, + f"Expected exactly 2 creates after settings flip + execute, got {n}", + ) + + @parameterized.expand(_RUNTIMES) + def test_no_op_settings_change_does_not_recreate(self, _name, use_python_runtime): + """Re-applying the same RuntimeSettings is a no-op: no invalidate, no + recreate, count is stable across follow-up executes.""" + self._skip_if_cpp_unavailable(use_python_runtime) + rs = RuntimeSettings(cuda_graph_strategy="whole_graph_capture") + compiled = _compile_simple( + runtime_settings=rs, use_python_runtime=use_python_runtime + ) + ttrt_modules = list(self._walk_engines(compiled)) + inputs = [torch.randn(2, 3).cuda()] + _ = compiled(*inputs) + baseline = [mod.engine.num_execution_contexts_created() for mod in ttrt_modules] + + for mod in ttrt_modules: + mod.set_runtime_settings(rs) # identical to existing + _ = compiled(*inputs) + for mod, prior in zip(ttrt_modules, baseline): + self.assertEqual(mod.engine.num_execution_contexts_created(), prior) + + if __name__ == "__main__": run_tests()