diff --git a/core/runtime/BUILD b/core/runtime/BUILD index bb9f779929..5ada005328 100644 --- a/core/runtime/BUILD +++ b/core/runtime/BUILD @@ -86,6 +86,7 @@ cc_library( "DeviceList.cpp", "Platform.cpp", "RTDevice.cpp", + "RuntimeSettings.cpp", "TRTEngine.cpp", "TRTEngineProfiler.cpp", "TRTRuntimeConfig.cpp", @@ -96,6 +97,7 @@ cc_library( hdrs = [ "Platform.h", "RTDevice.h", + "RuntimeSettings.h", "TRTEngine.h", "TRTEngineProfiler.h", "TRTRuntimeConfig.h", @@ -158,6 +160,7 @@ cc_library( hdrs = [ "Platform.h", "RTDevice.h", + "RuntimeSettings.h", "TRTEngine.h", "TRTEngineProfiler.h", "TensorRTBindingNames.h", @@ -174,6 +177,7 @@ filegroup( srcs = [ "Platform.h", "RTDevice.h", + "RuntimeSettings.h", "TRTEngine.h", "TRTEngineProfiler.h", "TRTRuntimeConfig.h", diff --git a/core/runtime/RuntimeSettings.cpp b/core/runtime/RuntimeSettings.cpp new file mode 100644 index 0000000000..0d2aeda533 --- /dev/null +++ b/core/runtime/RuntimeSettings.cpp @@ -0,0 +1,89 @@ +#include "core/runtime/RuntimeSettings.h" + +#include +#include +#include + +#include "core/util/prelude.h" + +namespace torch_tensorrt { +namespace core { +namespace runtime { + +// ---- RuntimeCacheHandle methods --------------------------------------------- +// +// The ``#ifdef TRT_MAJOR_RTX`` is intentionally confined to this translation +// unit: the public header advertises a uniform interface (always-callable +// methods that simply degrade to no-ops on non-RTX builds), and the JIT-binding +// registration file (``register_jit_hooks.cpp``) calls these as plain member +// references with zero conditional compilation. 
+ +at::Tensor RuntimeCacheHandle::serialize() const { + auto opts = at::TensorOptions().dtype(at::kByte); +#ifdef TRT_MAJOR_RTX + if (!cache) { + return at::empty({0}, opts); + } + auto host_mem = make_trt(cache->serialize()); + if (!host_mem) { + return at::empty({0}, opts); + } + auto tensor = at::empty({static_cast(host_mem->size())}, opts); + std::memcpy(tensor.data_ptr(), host_mem->data(), host_mem->size()); + return tensor; +#else + return at::empty({0}, opts); +#endif +} + +void RuntimeCacheHandle::deserialize(TORCHTRT_UNUSED at::Tensor data) { +#ifdef TRT_MAJOR_RTX + if (data.numel() == 0 || !cache) { + return; + } + auto contig = data.contiguous().to(at::kCPU); + cache->deserialize(contig.data_ptr(), static_cast(contig.numel())); +#endif +} + +bool RuntimeCacheHandle::has_cache() const { +#ifdef TRT_MAJOR_RTX + return cache != nullptr; +#else + return false; +#endif +} + +// ---- RuntimeSettings methods ------------------------------------------------ + +bool RuntimeSettings::operator==(RuntimeSettings const& other) const noexcept { + // ``runtime_cache`` compares by pointer identity: passing the same handle + // twice through ``update_runtime_settings`` is a no-op. Hoisted into locals + // because ``std::tie`` requires lvalues. + auto* this_cache = runtime_cache.get(); + auto* other_cache = other.runtime_cache.get(); + return std::tie(dynamic_shapes_kernel_specialization_strategy, cuda_graph_strategy, this_cache) == + std::tie(other.dynamic_shapes_kernel_specialization_strategy, other.cuda_graph_strategy, other_cache); +} + +std::string RuntimeSettings::to_str() const { + std::ostringstream os; + os << "Dynamic Shapes Kernel Strategy: " << dynamic_shapes_kernel_specialization_strategy << std::endl; + os << "CUDA Graph Strategy: " << cuda_graph_strategy << std::endl; + if (runtime_cache) { + auto const& p = runtime_cache->path; + os << "Runtime Cache: " << (p.empty() ? 
"" : p) << std::endl; + } else { + os << "Runtime Cache: " << std::endl; + } + return os.str(); +} + +std::ostream& operator<<(std::ostream& os, RuntimeSettings const& rs) { + os << rs.to_str(); + return os; +} + +} // namespace runtime +} // namespace core +} // namespace torch_tensorrt diff --git a/core/runtime/RuntimeSettings.h b/core/runtime/RuntimeSettings.h new file mode 100644 index 0000000000..230626df92 --- /dev/null +++ b/core/runtime/RuntimeSettings.h @@ -0,0 +1,77 @@ +#pragma once + +#include +#include +#include + +#include "ATen/core/Tensor.h" +#include "ATen/core/ivalue.h" +#include "NvInfer.h" +#include "torch/custom_class.h" + +namespace torch_tensorrt { +namespace core { +namespace runtime { + +// A passive wrapper around an ``IRuntimeCache``. Registered as a torchbind class +// so it can be passed by ``c10::intrusive_ptr`` across the Python/C++ boundary; +// the same handle gives both runtimes the same underlying ``IRuntimeCache*``. +// +// File I/O lives on the Python side (filelock + on-disk persistence via +// the ``serialize`` / ``deserialize`` members below). The C++ struct is purely +// a holder; ``path`` is informational and is not consulted by the C++ runtime. +struct RuntimeCacheHandle : public torch::CustomClassHolder { + std::string path; + +#ifdef TRT_MAJOR_RTX + // The actual TensorRT runtime cache. The first engine that attaches this handle + // materializes it via ``IRuntimeConfig::createRuntimeCache()`` and writes the + // shared_ptr here; subsequent engines reuse the same pointer for true sharing. + std::shared_ptr cache; +#endif + + explicit RuntimeCacheHandle(std::string p = "") : path(std::move(p)) {} + + // Expose the underlying ``IRuntimeCache`` bytes for the Python side to persist + // under filelock. Returns an empty uint8 tensor when no cache is attached, or + // on non-RTX builds. 
+ // + // ``at::Tensor`` is used (rather than ``std::string``) because TorchBind + // forces ``std::string`` to round-trip through Python ``str`` (UTF-8), and + // serialized cache bytes are not valid UTF-8. + [[nodiscard]] at::Tensor serialize() const; + + // Inverse of ``serialize``. Expects a uint8 ``at::Tensor``. No-op for empty + // input, when the underlying ``IRuntimeCache`` has not been materialized yet, + // or on non-RTX builds. + void deserialize(at::Tensor data); + + // True iff an engine has populated the underlying ``IRuntimeCache``. + // Always false on non-RTX builds. + [[nodiscard]] bool has_cache() const; +}; + +// Per-engine runtime-only knobs sampled at IExecutionContext creation. +// +// ``RuntimeSettings`` is a plain struct (not a torchbind class) because we +// flatten it into positional args at the torchbind boundary -- TorchBind can't +// carry a dataclass natively. Equality is value-by-value; the cache field +// compares by pointer identity (same handle -> same cache). 
+struct RuntimeSettings { + std::string dynamic_shapes_kernel_specialization_strategy = "lazy"; + std::string cuda_graph_strategy = "disabled"; + c10::intrusive_ptr runtime_cache = nullptr; + + bool operator==(RuntimeSettings const& other) const noexcept; + bool operator!=(RuntimeSettings const& other) const noexcept { + return !(*this == other); + } + + [[nodiscard]] std::string to_str() const; +}; + +std::ostream& operator<<(std::ostream& os, RuntimeSettings const& rs); + +} // namespace runtime +} // namespace core +} // namespace torch_tensorrt diff --git a/core/runtime/TRTEngine.cpp b/core/runtime/TRTEngine.cpp index 873c758854..e5d64ac8f7 100644 --- a/core/runtime/TRTEngine.cpp +++ b/core/runtime/TRTEngine.cpp @@ -97,7 +97,7 @@ TRTEngine::TRTEngine( bool requires_output_allocator, const std::string& serialized_metadata, const ResourceAllocationStrategy resource_allocation_strategy, - TRTRuntimeConfig runtime_cfg) + RuntimeSettings runtime_settings) : TRTEngine( "deserialized_trt", serialized_engine, @@ -109,7 +109,7 @@ TRTEngine::TRTEngine( requires_output_allocator, serialized_metadata, resource_allocation_strategy, - std::move(runtime_cfg)) {} + std::move(runtime_settings)) {} TRTEngine::TRTEngine(std::vector serialized_info) : TRTEngine( @@ -125,7 +125,7 @@ TRTEngine::TRTEngine(std::vector serialized_info) (static_cast(std::stoi(serialized_info[RESOURCE_ALLOCATION_STRATEGY_IDX])) ? ResourceAllocationStrategy::kDynamic : ResourceAllocationStrategy::kStatic), - make_runtime_config_from_serialized(serialized_info)) { + RuntimeSettings{}) { // Single visible marker that this engine was instantiated through the C++ runtime // entry point (i.e. torch.classes.tensorrt.Engine), distinguishing it from the Python // TRTEngine path. 
Tests look for this string in captured stderr to verify the @@ -148,8 +148,8 @@ TRTEngine::TRTEngine( bool requires_output_allocator, const std::string& serialized_metadata, const ResourceAllocationStrategy resource_allocation_strategy, - TRTRuntimeConfig runtime_cfg) { - this->runtime_cfg = std::move(runtime_cfg); + RuntimeSettings runtime_settings) { + this->runtime_cfg = TRTRuntimeConfig(std::move(runtime_settings)); TORCHTRT_CHECK( is_supported_on_current_platform(target_platform), "This engine was not built to run on this platform (built for: " << target_platform << ", current platform: " @@ -185,7 +185,12 @@ TRTEngine::TRTEngine( LOG_DEBUG( "Resource allocation strategy: " << (this->resource_allocation_strategy == ResourceAllocationStrategy::kDynamic ? "Dynamic" : "Static")); - recreate_execution_context(); + // ``exec_ctx`` is created lazily on first use (``execute_engine``, + // ``enable_profiling``, ``bind_nccl_comm``, ``infer_outputs``); ``to_str`` only reads it. + // Deferring here lets the Python ``setup_engine`` cpp branch dispatch user + // ``RuntimeSettings`` before the (expensive, kernel-JIT-compiling) TRT + // ``createExecutionContext`` call -- collapses the historical + // "create-with-defaults then recreate-with-settings" pair into a single create. // Pre-allocate placeholder for empty tensors (TensorRT requires non-null addresses) cudaMalloc(&empty_tensor_placeholder, 1); @@ -273,20 +278,21 @@ TRTEngine::TRTEngine( num_io = std::make_pair(inputs_size, outputs); } - runtime_cfg.has_dynamic_inputs = engine_has_dynamic_inputs(cuda_engine.get(), in_binding_names); + has_dynamic_inputs = engine_has_dynamic_inputs(cuda_engine.get(), in_binding_names); #ifndef NDEBUG + // Debug builds want profiling on from the start; that requires a live ctx. this->enable_profiling(); #endif LOG_DEBUG(*this); #ifdef ENABLE_TRT_NCCL_COLLECTIVES - // Attempt to bind the NCCL communicator immediately after exec_ctx is ready. 
- // This handles the common case where dist.init_process_group() and an initial - // collective have already been called before the engine is constructed. - // If the communicator isn't available yet (e.g. engine constructed before the - // first collective), bind_nccl_comm returns false and execute_engine() will - // retry on its first invocation. + // Distributed engines must have a bound communicator on the IExecutionContext + // before the first collective; bind here. ``bind_nccl_comm`` lazily creates + // ``exec_ctx`` via ``ensure_execution_context`` if it hasn't been built yet. + // For non-distributed engines we leave ``exec_ctx`` null so the first + // ``execute_engine`` (typically right after the Python settings dispatch) + // is the single TRT context-create site. if (this->requires_native_multidevice) { bind_nccl_comm(); } @@ -294,9 +300,9 @@ TRTEngine::TRTEngine( } TRTEngine::~TRTEngine() { - // Marked noexcept so safe to invoke from a destructor without - // explicit try/catch; any I/O error is logged internally. - runtime_cfg.save_runtime_cache(); + // Disk persistence for runtime caches is owned by the Python side + // (`RuntimeCacheHandle.save()` invoked from the runtime_cache CM or the engine + // wrapper). The C++ side just lets refcounts drop. trt_engine_profiler.reset(); exec_ctx.reset(); cuda_engine.reset(); @@ -310,7 +316,10 @@ void TRTEngine::disable_profiling() { torch::cuda::synchronize(device_info.id); profile_execution = false; trt_engine_profiler.reset(); - recreate_execution_context(); + // Drop the profiler-attached context; next execute lazily creates a fresh + // one with no profiler. (TRT has no detach-profiler API -- recreate is the + // canonical way.) 
+ invalidate_execution_context(); } void TRTEngine::dump_engine_layer_info_to_file(const std::string& path) { @@ -331,6 +340,10 @@ void TRTEngine::dump_engine_layer_info() { void TRTEngine::enable_profiling() { profile_execution = true; trt_engine_profiler = std::make_unique(name); + // ``setProfiler`` requires a live ``IExecutionContext``; under the lazy-create + // policy the ctx may be null when the user toggles profiling before the first + // execute. + ensure_execution_context(); exec_ctx->setProfiler(trt_engine_profiler.get()); } @@ -362,6 +375,8 @@ std::string TRTEngine::get_serialized_metadata() { } std::vector TRTEngine::infer_outputs(std::vector> input_shapes) { + // Lazy-create: callers can hit this before the first execute_engine. + ensure_execution_context(); std::vector outputs; TORCHTRT_CHECK( (in_binding_names.size() == input_shapes.size()), @@ -399,21 +414,21 @@ int64_t TRTEngine::get_device_memory_budget() { } bool TRTEngine::set_device_memory_budget(int64_t budget) { - // Recreating the context because weight streaming budget cannot be modified while there are active context. - if (exec_ctx.get() != nullptr) { - exec_ctx.reset(); - } + // Weight-streaming budget cannot be modified while a context is live; drop it. + invalidate_execution_context(); if (profile_execution) { trt_engine_profiler.reset(); } bool result = cuda_engine->setWeightStreamingBudgetV2(budget); - recreate_execution_context(); + // Eagerly rebuild if the user had profiling on (so the profiler is attached + // before they query it); otherwise leave lazy. if (profile_execution) { enable_profiling(); } #ifdef ENABLE_TRT_NCCL_COLLECTIVES - // exec_ctx was recreated — re-bind the NCCL communicator if this is a - // distributed engine that has already been set up. + // exec_ctx was invalidated — re-bind the NCCL communicator if this is a + // distributed engine that has already been set up. ``bind_nccl_comm`` ensures + // the context before binding. 
if (nccl_initialized) { bind_nccl_comm(); } @@ -438,33 +453,41 @@ std::string TRTEngine::to_str() const { std::stringstream ss; ss << "Torch-TensorRT TensorRT Engine:" << std::endl; ss << " Name: " << name << std::endl; - ss << " Inputs: [" << std::endl; - for (uint64_t i = 0; i < num_io.first; i++) { - ss << " id: " << i << std::endl; - ss << " name: " << in_binding_names[i].c_str() << std::endl; - ss << " shape: " << exec_ctx->getTensorShape(in_binding_names[i].c_str()) << std::endl; - ss << " dtype: " - << util::TRTDataTypeToScalarType(exec_ctx->getEngine().getTensorDataType(in_binding_names[i].c_str())) - << std::endl; - } - ss << " ]" << std::endl; - ss << " Outputs: [" << std::endl; - for (uint64_t o = 0; o < num_io.second; o++) { - ss << " id: " << o << std::endl; - ss << " name: " << out_binding_names[o].c_str() << std::endl; - ss << " shape: " << exec_ctx->getTensorShape(out_binding_names[o].c_str()) << std::endl; - ss << " dtype: " - << util::TRTDataTypeToScalarType( - exec_ctx->getEngine().getTensorDataType(out_binding_names[o].c_str())) - << std::endl; + // Shape/dtype queries require a live IExecutionContext. Under the lazy-create + // policy ``exec_ctx`` may be null at debug-log time (e.g. ctor-time LOG_DEBUG + // before the first execute_engine). Fall back to a marker. 
+ if (exec_ctx == nullptr) { + ss << " Inputs: " << std::endl; + ss << " Outputs: " << std::endl; + } else { + ss << " Inputs: [" << std::endl; + for (uint64_t i = 0; i < num_io.first; i++) { + ss << " id: " << i << std::endl; + ss << " name: " << in_binding_names[i].c_str() << std::endl; + ss << " shape: " << exec_ctx->getTensorShape(in_binding_names[i].c_str()) << std::endl; + ss << " dtype: " + << util::TRTDataTypeToScalarType(exec_ctx->getEngine().getTensorDataType(in_binding_names[i].c_str())) + << std::endl; + } + ss << " ]" << std::endl; + ss << " Outputs: [" << std::endl; + for (uint64_t o = 0; o < num_io.second; o++) { + ss << " id: " << o << std::endl; + ss << " name: " << out_binding_names[o].c_str() << std::endl; + ss << " shape: " << exec_ctx->getTensorShape(out_binding_names[o].c_str()) << std::endl; + ss << " dtype: " + << util::TRTDataTypeToScalarType( + exec_ctx->getEngine().getTensorDataType(out_binding_names[o].c_str())) + << std::endl; + } + ss << " ]" << std::endl; } - ss << " ]" << std::endl; ss << " Device: " << device_info << std::endl; ss << " Hardware Compatibility: " << (hardware_compatible ? "Enabled" : "Disabled") << std::endl; ss << " Target Platform: " << target_platform << std::endl; ss << " Resource Allocation Strategy: " << (resource_allocation_strategy == ResourceAllocationStrategy::kDynamic ? 
"Dynamic" : "Static") << std::endl; ss << " Multi-Device Engine: " << (requires_native_multidevice) << std::endl; - ss << runtime_cfg.to_str(); + ss << runtime_cfg.settings().to_str(); // clang-format on return ss.str(); } @@ -511,11 +534,7 @@ FlattenedState TRTEngine::__obj_flatten__() { std::tuple("requires_output_allocator", serialized_info[REQUIRES_OUTPUT_ALLOCATOR_IDX]), std::tuple("target_platform", serialized_info[TARGET_PLATFORM_IDX]), std::tuple("resource_allocation_strategy", serialized_info[RESOURCE_ALLOCATION_STRATEGY_IDX]), - std::tuple("requires_native_multidevice", serialized_info[REQUIRES_NATIVE_MULTIDEVICE_IDX]), - std::tuple("has_runtime_cfg", serialized_info[HAS_RUNTIME_CFG_IDX]), - std::tuple("runtime_cache_path", serialized_info[RUNTIME_CACHE_PATH_IDX]), - std::tuple("dynamic_shapes_kernel_strategy", serialized_info[DYNAMIC_SHAPES_KERNEL_STRATEGY_IDX]), - std::tuple("cuda_graph_strategy", serialized_info[CUDA_GRAPH_STRATEGY_IDX])); + std::tuple("requires_native_multidevice", serialized_info[REQUIRES_NATIVE_MULTIDEVICE_IDX])); } std::vector TRTEngine::serialize() { @@ -541,17 +560,8 @@ std::vector TRTEngine::serialize() { serialized_info[RESOURCE_ALLOCATION_STRATEGY_IDX] = this->resource_allocation_strategy == ResourceAllocationStrategy::kDynamic ? "1" : "0"; serialized_info[REQUIRES_NATIVE_MULTIDEVICE_IDX] = this->requires_native_multidevice ? "1" : "0"; - // rank/world_size are runtime facts (may differ at load time); not serialized. 
-#ifdef TRT_MAJOR_RTX - serialized_info[HAS_RUNTIME_CFG_IDX] = "1"; -#else - serialized_info[HAS_RUNTIME_CFG_IDX] = "0"; -#endif - serialized_info[RUNTIME_CACHE_PATH_IDX] = runtime_cfg.runtime_cache_path; - serialized_info[DYNAMIC_SHAPES_KERNEL_STRATEGY_IDX] = std::to_string( - static_cast>(runtime_cfg.dynamic_shapes_kernel_strategy)); - serialized_info[CUDA_GRAPH_STRATEGY_IDX] = - std::to_string(static_cast>(runtime_cfg.cuda_graph_strategy)); + // RuntimeSettings are intentionally NOT serialized: they're per-engine, in-memory + // initialization values, not part of the engine's identity. See pytorch/TensorRT#4310. return serialized_info; } @@ -567,7 +577,7 @@ void TRTEngine::set_resource_allocation_strategy(TRTEngine::ResourceAllocationSt "Setting resource allocation strategy to " << (this->resource_allocation_strategy == TRTEngine::ResourceAllocationStrategy::kDynamic ? "dynamic" : "static")); - recreate_execution_context(); + invalidate_execution_context(); } } @@ -650,6 +660,9 @@ bool TRTEngine::bind_nccl_comm() { return false; } + // Distributed engines must hold a live IExecutionContext at bind time. + // Under the lazy-create policy this is the first call site that needs it. + ensure_execution_context(); TORCHTRT_CHECK(exec_ctx.get() != nullptr, "Cannot bind NCCL communicator: execution context is null"); exec_ctx->setCommunicator(reinterpret_cast(comm_ptr)); this->nccl_initialized = true; @@ -663,39 +676,66 @@ void TRTEngine::release_nccl_comm() { } LOG_INFO("Releasing NCCL communicator from engine '" << this->name << "'"); torch::cuda::synchronize(device_info.id); - this->exec_ctx.reset(); - recreate_execution_context(); + invalidate_execution_context(); + // Eagerly rebuild so the engine returns to a "context-live, no NCCL" state + // (callers may immediately query exec_ctx for shape/dtype info post-release). 
+ ensure_execution_context(); this->nccl_initialized = false; LOG_INFO("NCCL communicator released from engine '" << this->name << "'"); } #endif // ENABLE_TRT_NCCL_COLLECTIVES bool TRTEngine::is_monolithic_capturable(cudaStream_t stream) const { - return runtime_cfg.is_monolithic_capturable(exec_ctx.get(), stream); + return runtime_cfg.is_monolithic_capturable(has_dynamic_inputs, exec_ctx.get(), stream); } void TRTEngine::disable_rtx_native_cudagraphs() { - bool was_disabled = runtime_cfg.rtx_native_cudagraphs_disabled; - runtime_cfg.disable_rtx_native_cudagraphs(name); - if (!was_disabled && runtime_cfg.rtx_native_cudagraphs_disabled) { - // The CUDA graph strategy on the IRuntimeConfig has been flipped; rebuild exec_ctx - // so the new strategy takes effect for subsequent enqueueV3 calls. +#ifdef TRT_MAJOR_RTX + if (runtime_cfg.settings().cuda_graph_strategy == "disabled") { + return; + } + LOG_WARNING( + "Outer CUDA stream capture detected; disabling TensorRT-RTX native CUDA graph strategy on engine " + << name << " for the remainder of its lifetime."); + RuntimeSettings new_settings = runtime_cfg.settings(); + new_settings.cuda_graph_strategy = "disabled"; + update_runtime_settings(std::move(new_settings)); +#endif +} + +void TRTEngine::update_runtime_settings(RuntimeSettings new_settings) { + if (!runtime_cfg.set_settings(std::move(new_settings))) { + return; + } + // Lazy: drop the live context, but do NOT eagerly recreate. The next user + // (typically the next ``execute_engine`` call) will lazy-create with the + // new settings via ``ensure_execution_context``. This collapses the + // historical "ctor-create-with-defaults + dispatch-recreate-with-settings" + // pair on the Python ``setup_engine`` cpp branch into a single create. 
+ invalidate_execution_context(); + // Existing recreate sites set runtime_states.context_changed for cudagraph + // re-record; do the same here so a settings flip inside an active CM forces + // the next enqueue to re-record any captured graph. + runtime_states.context_changed = true; +} + +void TRTEngine::ensure_execution_context() { + if (exec_ctx == nullptr) { recreate_execution_context(); } } +void TRTEngine::invalidate_execution_context() noexcept { + exec_ctx.reset(); +} + void TRTEngine::recreate_execution_context() { - // Flush any kernels the previous execution context may have compiled into the - // runtime cache before creating the replacement. The destructor also saves, but - // doing it here guards against losing compiled kernels across profiling toggles, - // allocator changes, or process kills that happen between allocator changes and - // teardown. No-op on standard TensorRT or when no cache path is configured. - runtime_cfg.save_runtime_cache(); const auto allocation_strategy = resource_allocation_strategy == ResourceAllocationStrategy::kDynamic ? 
nvinfer1::ExecutionContextAllocationStrategy::kUSER_MANAGED : nvinfer1::ExecutionContextAllocationStrategy::kSTATIC; exec_ctx = runtime_cfg.create_execution_context(cuda_engine.get(), allocation_strategy); TORCHTRT_CHECK(exec_ctx.get() != nullptr, "Unable to (re)create TensorRT execution context"); + ++num_execution_contexts_created_; } } // namespace runtime diff --git a/core/runtime/TRTEngine.h b/core/runtime/TRTEngine.h index 47917e9c37..1f1f876826 100644 --- a/core/runtime/TRTEngine.h +++ b/core/runtime/TRTEngine.h @@ -14,6 +14,7 @@ #include "c10/cuda/CUDAStream.h" #include "torch/custom_class.h" +#include "core/runtime/RuntimeSettings.h" #include "core/runtime/TRTEngineProfiler.h" #include "core/runtime/TRTRuntimeConfig.h" #include "core/runtime/TensorRTBindingNames.h" @@ -48,11 +49,7 @@ using FlattenedState = std::tuple< std::tuple, // serialized metadata std::tuple, // Platform std::tuple, // Resource Allocation Strategy - std::tuple, // requires_native_multidevice - std::tuple, // has_runtime_cfg (gates next three) - std::tuple, // Runtime Cache Path (TRT-RTX) - std::tuple, // Dynamic Shapes Kernel Strategy (TRT-RTX) - std::tuple // CUDA Graph Strategy (TRT-RTX) + std::tuple // requires_native_multidevice >; struct TorchTRTRuntimeStates { @@ -158,7 +155,7 @@ struct TRTEngine : torch::CustomClassHolder { const std::string& serialized_metadata = "", const TRTEngine::ResourceAllocationStrategy resource_allocation_strategy = TRTEngine::ResourceAllocationStrategy::kStatic, - TRTRuntimeConfig runtime_cfg = TRTRuntimeConfig{}); + RuntimeSettings runtime_settings = RuntimeSettings{}); TRTEngine(std::vector serialized_info); @@ -174,7 +171,7 @@ struct TRTEngine : torch::CustomClassHolder { const std::string& serialized_metadata = "", const TRTEngine::ResourceAllocationStrategy resource_allocation_strategy = TRTEngine::ResourceAllocationStrategy::kStatic, - TRTRuntimeConfig runtime_cfg = TRTRuntimeConfig{}); + RuntimeSettings runtime_settings = RuntimeSettings{}); 
std::string to_str() const; static void verify_serialization_fmt(const std::vector& serialized_info); @@ -282,22 +279,61 @@ struct TRTEngine : torch::CustomClassHolder { void set_resource_allocation_strategy(ResourceAllocationStrategy new_strategy); ResourceAllocationStrategy get_resource_allocation_strategy(); - // Owns the IRuntimeConfig (where supported) and TRT-RTX runtime state. On older TRT - // without IRuntimeConfig (e.g. Jetpack) this just carries strategy values that get - // passed to the legacy createExecutionContext overload. + // Owns the canonical `RuntimeSettings` plus the live `IRuntimeConfig` derived + // from them. The engine forwards `runtime_settings()` and + // `update_runtime_settings()` here -- there is no separate settings field on + // the engine. TRTRuntimeConfig runtime_cfg; + [[nodiscard]] RuntimeSettings const& runtime_settings() const noexcept { + return runtime_cfg.settings(); + } + + // Apply new runtime settings. Fast-paths on equality (via + // `TRTRuntimeConfig::set_settings`). On change, rebuilds the + // `IRuntimeConfig` from the new settings and invalidates the execution context (recreated lazily on next use). + void update_runtime_settings(RuntimeSettings new_settings); + + // Whether the engine has any input binding with a dynamic dimension. Computed + // once during construction; used by `is_monolithic_capturable`. + bool has_dynamic_inputs = true; + // Monolithic-capturability check used when this engine is wrapped by an outer whole-graph // capture (e.g. CudaGraphsTorchTensorRTModule). Non-RTX builds always return true. bool is_monolithic_capturable(cudaStream_t stream) const; // Disable TensorRT-RTX native CUDA graph capture on this engine (one-shot, invoked when - // an outer stream capture is detected around execute_engine). No-op on non-RTX. + // an outer stream capture is detected around execute_engine). No-op on non-RTX or when + // already disabled. 
void disable_rtx_native_cudagraphs(); + // Materialize ``exec_ctx`` if it is currently null, using the current settings + // from ``runtime_cfg``. Idempotent: a non-null ``exec_ctx`` is left untouched. + // Called from every site that needs the live context (``execute_engine``, + // ``enable_profiling``, ``bind_nccl_comm``, ``infer_outputs``, etc.). + void ensure_execution_context(); + + // Drop the live ``exec_ctx`` without recreating. The next ``ensure_execution_context`` + // (typically inside the next ``execute_engine`` call) will rebuild from the + // current ``runtime_cfg`` settings. + void invalidate_execution_context() noexcept; + + // Test/observability hook: increments once every time ``runtime_cfg.create_execution_context`` + // is invoked (i.e. an actual TRT createExecutionContext call, which on RTX + // also JIT-compiles the specialized kernel set). Bound on the torchbind class. + // ``noexcept`` is intentionally omitted -- torchbind's ``def`` template is + // not specialized for ``const noexcept`` member functions. + [[nodiscard]] int64_t num_execution_contexts_created() const { + return num_execution_contexts_created_; + } + private: // Single entry point that (re)creates exec_ctx via runtime_cfg.create_execution_context. + // Bumps ``num_execution_contexts_created_``. Callers should normally go through + // ``ensure_execution_context`` for the lazy semantics. 
void recreate_execution_context(); + + int64_t num_execution_contexts_created_ = 0; }; } // namespace runtime diff --git a/core/runtime/TRTRuntimeConfig.cpp b/core/runtime/TRTRuntimeConfig.cpp index 6f64a95cbd..dafdcb2a2b 100644 --- a/core/runtime/TRTRuntimeConfig.cpp +++ b/core/runtime/TRTRuntimeConfig.cpp @@ -1,155 +1,105 @@ #include "core/runtime/TRTRuntimeConfig.h" -#include -#include #include #include -#include +#include -#include "core/runtime/runtime.h" +#include "core/runtime/RuntimeSettings.h" #include "core/util/prelude.h" namespace torch_tensorrt { namespace core { namespace runtime { -// File-local helpers. Kept out of the header because they are only used by this -// translation unit -- TRTEngine now consumes a TRTRuntimeConfig directly and does not -// need the enum conversion helpers. namespace { -[[nodiscard]] std::string to_string(DynamicShapesKernelStrategy s) { - switch (s) { - case DynamicShapesKernelStrategy::kLazy: - return "lazy"; - case DynamicShapesKernelStrategy::kEager: - return "eager"; - case DynamicShapesKernelStrategy::kNone: - return "none"; +#ifdef TRT_MAJOR_RTX +[[nodiscard]] nvinfer1::DynamicShapesKernelSpecializationStrategy to_trt_ds_strategy(std::string const& s) { + if (s == "lazy") { + return nvinfer1::DynamicShapesKernelSpecializationStrategy::kLAZY; } - TORCHTRT_CHECK( - false, - "Unexpected DynamicShapesKernelStrategy value: " - << static_cast>(s)); -} - -[[nodiscard]] std::string to_string(CudaGraphStrategyOption s) { - switch (s) { - case CudaGraphStrategyOption::kDisabled: - return "disabled"; - case CudaGraphStrategyOption::kWholeGraphCapture: - return "whole_graph_capture"; + if (s == "eager") { + return nvinfer1::DynamicShapesKernelSpecializationStrategy::kEAGER; + } + if (s == "none") { + return nvinfer1::DynamicShapesKernelSpecializationStrategy::kNONE; } TORCHTRT_CHECK( - false, - "Unexpected CudaGraphStrategyOption value: " << static_cast>(s)); -} - -[[nodiscard]] DynamicShapesKernelStrategy 
to_dynamic_shapes_kernel_strategy( - std::underlying_type_t v) { - TORCHTRT_CHECK( - v >= 0 && v <= 2, - "Invalid dynamic shapes kernel strategy value: " << v << ". Expected 0 (lazy), 1 (eager), or 2 (none)."); - return static_cast(v); -} - -[[nodiscard]] CudaGraphStrategyOption to_cuda_graph_strategy_option(std::underlying_type_t v) { - TORCHTRT_CHECK( - v >= 0 && v <= 1, - "Invalid CUDA graph strategy value: " << v << ". Expected 0 (disabled) or 1 (whole_graph_capture)."); - return static_cast(v); + false, "Invalid dynamic_shapes_kernel_specialization_strategy: \"" << s << "\" (expected lazy | eager | none)"); } -#ifdef TRT_MAJOR_RTX -// Raw cache I/O helpers. Exception-propagating; the caller wraps in try/catch at the -// TRTRuntimeConfig member level. Kept file-local because the IRuntimeCache type is -// itself TensorRT-RTX-only and tests reach this path through the member wrappers. -void load_runtime_cache(const std::string& path, nvinfer1::IRuntimeCache* cache) { - TORCHTRT_CHECK(cache != nullptr, "load_runtime_cache requires a non-null IRuntimeCache"); - if (!std::filesystem::exists(path)) { - LOG_DEBUG("No existing runtime cache at " << path); - return; +[[nodiscard]] nvinfer1::CudaGraphStrategy to_trt_cg_strategy(std::string const& s) { + if (s == "disabled") { + return nvinfer1::CudaGraphStrategy::kDISABLED; } - std::ifstream f(path, std::ios::binary); - std::vector buf((std::istreambuf_iterator(f)), std::istreambuf_iterator()); - if (buf.empty()) { - return; + if (s == "whole_graph_capture") { + return nvinfer1::CudaGraphStrategy::kWHOLE_GRAPH_CAPTURE; } - TORCHTRT_CHECK(cache->deserialize(buf.data(), buf.size()), "IRuntimeCache::deserialize returned false for " << path); - LOG_INFO("Loaded runtime cache from " << path << " (" << buf.size() << " bytes)"); + TORCHTRT_CHECK(false, "Invalid cuda_graph_strategy: \"" << s << "\" (expected disabled | whole_graph_capture)"); } +#endif -void save_runtime_cache_impl(const std::string& path, 
nvinfer1::IRuntimeCache* cache) { - TORCHTRT_CHECK(cache != nullptr, "save_runtime_cache requires a non-null IRuntimeCache"); - auto host_mem = make_trt(cache->serialize()); - if (!host_mem || host_mem->size() == 0) { - return; - } - std::filesystem::path fs_path(path); - if (fs_path.has_parent_path()) { - std::filesystem::create_directories(fs_path.parent_path()); - } - std::filesystem::path tmp_path = fs_path; - tmp_path += ".tmp"; - { - std::ofstream out(tmp_path, std::ios::binary); - out.write(reinterpret_cast(host_mem->data()), host_mem->size()); +} // namespace + +bool TRTRuntimeConfig::set_settings(RuntimeSettings new_settings) { + if (new_settings == settings_) { + return false; } - std::filesystem::rename(tmp_path, fs_path); - LOG_INFO("Saved runtime cache to " << path << " (" << host_mem->size() << " bytes)"); + settings_ = std::move(new_settings); + // Invalidate the live IRuntimeConfig so the next `ensure_initialized` rebuilds + // with the new strategy values + cache attachment. + reset(); + return true; } -#endif // TRT_MAJOR_RTX - -} // namespace void TRTRuntimeConfig::ensure_initialized(TORCHTRT_UNUSED nvinfer1::ICudaEngine* cuda_engine) { #ifdef TRT_HAS_IRUNTIME_CONFIG - if (config) { - return; + if (!config) { + TORCHTRT_CHECK(cuda_engine != nullptr, "Cannot initialize TRTRuntimeConfig without a live ICudaEngine"); + config = make_trt(cuda_engine->createRuntimeConfig()); + TORCHTRT_CHECK(config.get() != nullptr, "Unable to create TensorRT IRuntimeConfig"); } - TORCHTRT_CHECK(cuda_engine != nullptr, "Cannot initialize TRTRuntimeConfig without a live ICudaEngine"); - config = make_trt(cuda_engine->createRuntimeConfig()); - TORCHTRT_CHECK(config.get() != nullptr, "Unable to create TensorRT IRuntimeConfig"); #ifdef TRT_MAJOR_RTX - // Runtime cache -- TRT-RTX only. 
- if (!runtime_cache_path.empty()) { - runtime_cache = make_trt(config->createRuntimeCache()); - if (runtime_cache.get() == nullptr) { - LOG_WARNING("Failed to create TensorRT IRuntimeCache; runtime cache will be skipped."); + // Runtime cache: ONLY attach when the caller provided an external + // RuntimeCacheHandle. The Python TRTEngine side creates an implicit + // handle from a path string and passes it in via the handle; without + // an explicit user opt-in we leave the IRuntimeConfig cache-less. + if (settings_.runtime_cache) { + if (!settings_.runtime_cache->cache) { + settings_.runtime_cache->cache = make_trt(config->createRuntimeCache()); + TORCHTRT_CHECK( + settings_.runtime_cache->cache.get() != nullptr, + "Failed to create IRuntimeCache for shared RuntimeCacheHandle"); + } + if (config->setRuntimeCache(*settings_.runtime_cache->cache)) { + LOG_DEBUG("Attached external IRuntimeCache to IRuntimeConfig."); } else { - try { - load_runtime_cache(runtime_cache_path, runtime_cache.get()); - } catch (const std::exception& e) { - LOG_WARNING("Failed to load runtime cache from " << runtime_cache_path << ": " << e.what()); - } - if (config->setRuntimeCache(*runtime_cache)) { - LOG_DEBUG("TensorRT-RTX runtime cache configured at " << runtime_cache_path); - } else { - LOG_WARNING("Failed to attach runtime cache to IRuntimeConfig; cache will be unused."); - runtime_cache.reset(); - } + LOG_WARNING("Failed to attach IRuntimeCache to IRuntimeConfig; cache will be unused."); } } else { - LOG_DEBUG("Runtime cache disabled (no path configured)."); + LOG_DEBUG("Runtime cache disabled (no RuntimeCacheHandle provided)."); } - // Dynamic shapes kernel specialization strategy -- TRT-RTX only. 
config->setDynamicShapesKernelSpecializationStrategy( - static_cast(dynamic_shapes_kernel_strategy)); - LOG_DEBUG("Dynamic shapes kernel specialization strategy set to " << to_string(dynamic_shapes_kernel_strategy)); + to_trt_ds_strategy(settings_.dynamic_shapes_kernel_specialization_strategy)); + LOG_DEBUG( + "Dynamic shapes kernel specialization strategy set to " + << settings_.dynamic_shapes_kernel_specialization_strategy); - // CUDA graph strategy -- TRT-RTX only. - if (!config->setCudaGraphStrategy( - cuda_graph_strategy == CudaGraphStrategyOption::kWholeGraphCapture - ? nvinfer1::CudaGraphStrategy::kWHOLE_GRAPH_CAPTURE - : nvinfer1::CudaGraphStrategy::kDISABLED)) { + if (!config->setCudaGraphStrategy(to_trt_cg_strategy(settings_.cuda_graph_strategy))) { LOG_WARNING("Failed to set CUDA graph strategy; continuing with default."); } #endif #endif // TRT_HAS_IRUNTIME_CONFIG } +void TRTRuntimeConfig::reset() { +#ifdef TRT_HAS_IRUNTIME_CONFIG + config.reset(); +#endif +} + std::shared_ptr TRTRuntimeConfig::create_execution_context( nvinfer1::ICudaEngine* cuda_engine, nvinfer1::ExecutionContextAllocationStrategy allocation_strategy) { @@ -163,92 +113,41 @@ std::shared_ptr TRTRuntimeConfig::create_execution_ #endif } -bool TRTRuntimeConfig::uses_internal_capture(TORCHTRT_UNUSED bool cudagraphs_enabled) const { +bool TRTRuntimeConfig::uses_internal_capture(TORCHTRT_UNUSED bool cudagraphs_enabled) const noexcept { #ifdef TRT_MAJOR_RTX // On TRT-RTX the internal runtime handles capture/replay whenever a non-disabled - // strategy is set, or when subgraph cudagraphs are enabled globally. In both cases the - // caller should skip its manual at::cuda::CUDAGraph wrapper because TRT-RTX's internal - // capture would collide with it. - return cuda_graph_strategy != CudaGraphStrategyOption::kDisabled || cudagraphs_enabled; + // strategy is set, or when subgraph cudagraphs are enabled globally. In both + // cases the caller should skip its manual at::cuda::CUDAGraph wrapper. 
+ return settings_.cuda_graph_strategy != "disabled" || cudagraphs_enabled; #else return false; #endif } -void TRTRuntimeConfig::disable_rtx_native_cudagraphs(TORCHTRT_UNUSED const std::string& engine_name) noexcept { -#ifdef TRT_MAJOR_RTX - if (rtx_native_cudagraphs_disabled || cuda_graph_strategy == CudaGraphStrategyOption::kDisabled) { - return; - } - LOG_WARNING( - "Outer CUDA stream capture detected; disabling TensorRT-RTX native CUDA graph strategy on engine " - << engine_name << " for the remainder of its lifetime."); - // Persist any kernels the engine-internal capture has compiled so far; the outer - // capture will run without them otherwise, and we want future reloads to reuse them. - save_runtime_cache(); - cuda_graph_strategy = CudaGraphStrategyOption::kDisabled; - if (config && !config->setCudaGraphStrategy(nvinfer1::CudaGraphStrategy::kDISABLED)) { - LOG_WARNING("Failed to update CUDA graph strategy on IRuntimeConfig after disable."); - } - rtx_native_cudagraphs_disabled = true; -#endif -} - bool TRTRuntimeConfig::is_monolithic_capturable( + TORCHTRT_UNUSED bool has_dynamic_inputs, TORCHTRT_UNUSED nvinfer1::IExecutionContext* exec_ctx, - TORCHTRT_UNUSED cudaStream_t stream) const { + TORCHTRT_UNUSED cudaStream_t stream) const noexcept { #ifdef TRT_MAJOR_RTX TORCHTRT_ASSERT(exec_ctx != nullptr, "is_monolithic_capturable requires a live IExecutionContext"); if (!exec_ctx->isStreamCapturable(stream)) { return false; } - // "lazy" kernel specialization only swaps specialized kernels mid-run when an input - // has a dynamic dimension; for static-shape engines the kernels are fixed at setup and - // the captured graph stays valid. Mirrors the Python `_is_monolithic_capturable` check. 
- return !(dynamic_shapes_kernel_strategy == DynamicShapesKernelStrategy::kLazy && has_dynamic_inputs); + // "lazy" kernel specialization only swaps specialized kernels mid-run when an + // input has a dynamic dimension; for static-shape engines the kernels are fixed + // at setup and the captured graph stays valid. Mirrors the Python check. + return !(settings_.dynamic_shapes_kernel_specialization_strategy == "lazy" && has_dynamic_inputs); #else return true; #endif } -void TRTRuntimeConfig::save_runtime_cache() noexcept { -#ifdef TRT_MAJOR_RTX - if (!runtime_cache || runtime_cache_path.empty()) { - return; - } - try { - save_runtime_cache_impl(runtime_cache_path, runtime_cache.get()); - } catch (const std::exception& e) { - LOG_WARNING("Failed to save runtime cache to " << runtime_cache_path << ": " << e.what()); - } catch (...) { - LOG_WARNING("Failed to save runtime cache (unknown exception)."); - } -#endif -} - -std::string TRTRuntimeConfig::to_str() const { - std::ostringstream os; - os << "Runtime Cache Path: " << (runtime_cache_path.empty() ? 
"" : runtime_cache_path) << std::endl; - os << "Dynamic Shapes Kernel Strategy: " << to_string(dynamic_shapes_kernel_strategy) << std::endl; - os << "CUDA Graph Strategy: " << to_string(cuda_graph_strategy) << std::endl; - return os.str(); -} - -TRTRuntimeConfig make_runtime_config_from_serialized(const std::vector& info) { - TRTRuntimeConfig cfg; - if (info[HAS_RUNTIME_CFG_IDX] == "1") { - cfg.runtime_cache_path = info[RUNTIME_CACHE_PATH_IDX]; - cfg.dynamic_shapes_kernel_strategy = - to_dynamic_shapes_kernel_strategy(std::stoi(info[DYNAMIC_SHAPES_KERNEL_STRATEGY_IDX])); - cfg.cuda_graph_strategy = to_cuda_graph_strategy_option(std::stoi(info[CUDA_GRAPH_STRATEGY_IDX])); - } - return cfg; -} - std::ostream& operator<<(std::ostream& os, const TRTRuntimeConfig& cfg) { - os << "Runtime cfg {" << std::endl; - os << cfg.to_str(); - os << "}" << std::endl; + os << "TRTRuntimeConfig{settings=" << cfg.settings().to_str(); +#ifdef TRT_HAS_IRUNTIME_CONFIG + os << ", config=" << (cfg.config ? "live" : "null"); +#endif + os << "}"; return os; } diff --git a/core/runtime/TRTRuntimeConfig.h b/core/runtime/TRTRuntimeConfig.h index 6e7b8bc6ab..5c9c3eaa44 100644 --- a/core/runtime/TRTRuntimeConfig.h +++ b/core/runtime/TRTRuntimeConfig.h @@ -1,99 +1,77 @@ #pragma once #include -#include #include #include #include -#include -#include +#include #include "NvInfer.h" +#include "core/runtime/RuntimeSettings.h" namespace torch_tensorrt { namespace core { namespace runtime { -// TensorRT-RTX-only configuration for how shape-specialized kernels are compiled. -enum class DynamicShapesKernelStrategy : int32_t { - kLazy = 0, - kEager = 1, - kNone = 2, -}; - -// TensorRT-RTX-only configuration for how CUDA graph capture/replay is handled. -enum class CudaGraphStrategyOption : int32_t { - kDisabled = 0, - kWholeGraphCapture = 1, -}; - -// Encapsulates the IRuntimeConfig and TRT-RTX runtime state for a TRTEngine. 
-// IRuntimeConfig and runtime-cache `#ifdef`s are confined to this TU; serialization- -// index plumbing keeps its own RTX gates elsewhere. +// Owns the canonical `RuntimeSettings` for an engine, the live `IRuntimeConfig` +// derived from those settings (where supported), and translates strategy +// strings into TRT calls. All `TRT_HAS_IRUNTIME_CONFIG` / `TRT_MAJOR_RTX` +// branching is confined to this TU. +// +// `TRTEngine` holds a `TRTRuntimeConfig` member; the engine itself does not +// store a separate `RuntimeSettings`. `engine.runtime_settings()` forwards +// here. struct TRTRuntimeConfig { - // Settings - typically populated from engine deserialization before `ensure_initialized`. - std::string runtime_cache_path = ""; - DynamicShapesKernelStrategy dynamic_shapes_kernel_strategy = DynamicShapesKernelStrategy::kLazy; - CudaGraphStrategyOption cuda_graph_strategy = CudaGraphStrategyOption::kDisabled; - - // One-shot: set to true once an outer stream capture has been detected and the - // engine-internal CUDA graph strategy has been disabled for the remainder of the - // owning engine's lifetime. - bool rtx_native_cudagraphs_disabled = false; - - bool has_dynamic_inputs = true; - - // Live resources. The IRuntimeConfig is lazy-constructed on first `ensure_initialized` - // and is unavailable on TensorRT versions older than 10.11 (e.g. Jetpack). -#ifdef TRT_HAS_IRUNTIME_CONFIG - std::shared_ptr config; -#endif -#ifdef TRT_MAJOR_RTX - std::shared_ptr runtime_cache; -#endif - - // Lazily construct the IRuntimeConfig and apply RTX-specific settings. Idempotent. - // No-op on builds without IRuntimeConfig (e.g. Jetpack). + TRTRuntimeConfig() = default; + explicit TRTRuntimeConfig(RuntimeSettings settings) : settings_(std::move(settings)) {} + + // Canonical user-facing runtime settings for this engine. Mutated only via + // `set_settings` so the live `IRuntimeConfig` stays in sync. 
+ [[nodiscard]] RuntimeSettings const& settings() const noexcept { + return settings_; + } + + // Returns true iff `new_settings` differs from the current settings (i.e. + // the caller should recreate the `IExecutionContext`). On change the live + // `IRuntimeConfig` is invalidated; the next `ensure_initialized` rebuilds. + bool set_settings(RuntimeSettings new_settings); + + // (Re)build the `IRuntimeConfig` from `settings_`. Idempotent if the previous + // build was against identical settings. void ensure_initialized(nvinfer1::ICudaEngine* cuda_engine); - // Lazy-initialize the IRuntimeConfig if needed and create an IExecutionContext that - // honors `allocation_strategy`. Selects the right `createExecutionContext` overload - // (IRuntimeConfig* vs ExecutionContextAllocationStrategy) so callers stay free of - // any TRT_HAS_IRUNTIME_CONFIG branching. + // Force the next `ensure_initialized` to rebuild from scratch. + void reset(); + + // Lazy-init + create a fresh `IExecutionContext` honoring `allocation_strategy`. + // Picks the right `createExecutionContext` overload (IRuntimeConfig* vs + // ExecutionContextAllocationStrategy) so callers stay free of any + // `TRT_HAS_IRUNTIME_CONFIG` branching. [[nodiscard]] std::shared_ptr create_execution_context( nvinfer1::ICudaEngine* cuda_engine, nvinfer1::ExecutionContextAllocationStrategy allocation_strategy); - // Returns true if the TensorRT-RTX runtime owns capture/replay for this engine so the - // caller should bypass its own at::cuda::CUDAGraph capture around enqueueV3. Always - // false on non-RTX builds. - [[nodiscard]] bool uses_internal_capture(bool cudagraphs_enabled) const; - - // One-shot: disable engine-internal CUDA graph capture. Invoked when an outer stream - // capture is detected around execute_engine, so the outer capture can contain the - // kernel launches directly. Saves the runtime cache before recreating the context so - // compiled kernels from the present run are preserved for future reloads. 
- void disable_rtx_native_cudagraphs(const std::string& engine_name) noexcept; + // Returns true if TRT-RTX owns capture/replay for the current settings -- + // caller should then bypass its own `at::cuda::CUDAGraph` capture around + // enqueueV3. Always false on non-RTX builds. + [[nodiscard]] bool uses_internal_capture(bool cudagraphs_enabled) const noexcept; - // Whether the execution context is safe to include in an outer monolithic capture. - // Non-RTX builds always return true. - [[nodiscard]] bool is_monolithic_capturable(nvinfer1::IExecutionContext* exec_ctx, cudaStream_t stream) const; + // Returns true iff the execution context can be safely included in an outer + // monolithic capture. Non-RTX builds always return true. + [[nodiscard]] bool is_monolithic_capturable( + bool has_dynamic_inputs, + nvinfer1::IExecutionContext* exec_ctx, + cudaStream_t stream) const noexcept; - // Save the runtime cache to disk. Signature is `noexcept` so this is safe from a - // destructor. The underlying file I/O is performed by free functions declared below - // (non-noexcept, exception-leaky for easier testing); this member wraps them and - // swallows any exceptions. - void save_runtime_cache() noexcept; +#ifdef TRT_HAS_IRUNTIME_CONFIG + // Lazy-constructed live config. `nullptr` until first `ensure_initialized`. + std::shared_ptr config; +#endif - // Returns a human-readable summary of the runtime config. - [[nodiscard]] std::string to_str() const; + private: + RuntimeSettings settings_; }; -// Construct a TRTRuntimeConfig from a flattened serialization vector. Reads the -// RTX-only indices only on RTX builds; standard TRT builds return a default-initialized -// struct. 
-[[nodiscard]] TRTRuntimeConfig make_runtime_config_from_serialized(const std::vector& info); - std::ostream& operator<<(std::ostream& os, const TRTRuntimeConfig& cfg); } // namespace runtime diff --git a/core/runtime/execute_engine.cpp b/core/runtime/execute_engine.cpp index 80936951ef..83667a45af 100644 --- a/core/runtime/execute_engine.cpp +++ b/core/runtime/execute_engine.cpp @@ -212,6 +212,11 @@ void create_output_allocator(c10::intrusive_ptr compiled_engine) { } std::vector execute_engine(std::vector inputs, c10::intrusive_ptr compiled_engine) { + // Materialize the IExecutionContext on the first execute (under the lazy-create + // policy the ctor and ``update_runtime_settings`` no longer eagerly build one). + // Idempotent: a non-null exec_ctx is left untouched. + compiled_engine->ensure_execution_context(); + // All inputs are expected to be on CUDA. Warn and move any that are not. for (auto& inp : inputs) { if (inp.defined() && !inp.is_cuda()) { diff --git a/core/runtime/register_jit_hooks.cpp b/core/runtime/register_jit_hooks.cpp index 44d1b314ca..7afada2c21 100644 --- a/core/runtime/register_jit_hooks.cpp +++ b/core/runtime/register_jit_hooks.cpp @@ -1,6 +1,7 @@ #include #include "core/runtime/Platform.h" +#include "core/runtime/RuntimeSettings.h" #include "core/runtime/runtime.h" #include "core/util/macros.h" @@ -13,6 +14,21 @@ namespace core { namespace runtime { namespace { + +// Register `RuntimeCacheHandle` as a torchbind class so Python can pass the same +// underlying `IRuntimeCache` to both Python and C++ engine backends. File I/O on +// the handle is the Python side's responsibility; the C++ struct only holds the +// shared_ptr and an informational path string. The method bodies (and the +// `#ifdef TRT_MAJOR_RTX` they entail) live in RuntimeSettings.cpp -- this file +// is registration-only. 
+static auto TORCHTRT_UNUSED RuntimeCacheHandleRegistration = + torch::class_("tensorrt", "RuntimeCacheHandle") + .def(torch::init()) + .def_readwrite("path", &RuntimeCacheHandle::path) + .def("serialize", &RuntimeCacheHandle::serialize) + .def("deserialize", &RuntimeCacheHandle::deserialize) + .def("has_cache", &RuntimeCacheHandle::has_cache); + // TODO: Implement a call method // c10::List TRTEngine::Run(c10::List inputs) { // auto input_vec = inputs.vec(); @@ -40,6 +56,7 @@ static auto TORCHTRT_UNUSED TRTEngineTSRegistrtion = .def("reset_captured_graph", &TRTEngine::reset_captured_graph) .def("set_output_tensors_as_unowned", &TRTEngine::set_output_tensors_as_unowned) .def("are_output_tensors_unowned", &TRTEngine::are_output_tensors_unowned) + .def("num_execution_contexts_created", &TRTEngine::num_execution_contexts_created) .def( "use_dynamically_allocated_resources", [](const c10::intrusive_ptr& self, bool dynamic) -> void { @@ -47,6 +64,20 @@ static auto TORCHTRT_UNUSED TRTEngineTSRegistrtion = dynamic ? TRTEngine::ResourceAllocationStrategy::kDynamic : TRTEngine::ResourceAllocationStrategy::kStatic); }) + .def( + "update_runtime_settings", + [](const c10::intrusive_ptr& self, + std::string const& dynamic_shapes_kernel_specialization_strategy, + std::string const& cuda_graph_strategy, + c10::optional> runtime_cache) -> void { + // `c10::optional` lets TorchBind accept Python `None` here. We + // translate to a (possibly null) intrusive_ptr inside the struct. + RuntimeSettings rs; + rs.dynamic_shapes_kernel_specialization_strategy = dynamic_shapes_kernel_specialization_strategy; + rs.cuda_graph_strategy = cuda_graph_strategy; + rs.runtime_cache = runtime_cache.has_value() ? 
std::move(*runtime_cache) : nullptr; + self->update_runtime_settings(std::move(rs)); + }) .def_readwrite("use_pre_allocated_outputs", &TRTEngine::use_pre_allocated_outputs) .def_readwrite("pre_allocated_outputs", &TRTEngine::pre_allocated_outputs) .def_readwrite("use_output_allocator_outputs", &TRTEngine::use_output_allocator_outputs) @@ -147,10 +178,6 @@ TORCH_LIBRARY(tensorrt, m) { return false; #endif }); - m.def("HAS_RUNTIME_CFG_IDX", []() -> int64_t { return HAS_RUNTIME_CFG_IDX; }); - m.def("RUNTIME_CACHE_PATH_IDX", []() -> int64_t { return RUNTIME_CACHE_PATH_IDX; }); - m.def("DYNAMIC_SHAPES_KERNEL_STRATEGY_IDX", []() -> int64_t { return DYNAMIC_SHAPES_KERNEL_STRATEGY_IDX; }); - m.def("CUDA_GRAPH_STRATEGY_IDX", []() -> int64_t { return CUDA_GRAPH_STRATEGY_IDX; }); m.def("_platform_linux_x86_64", []() -> std::string { auto it = get_platform_name_map().find(Platform::PlatformEnum::kLINUX_X86_64); return it->second; diff --git a/core/runtime/runtime.h b/core/runtime/runtime.h index 2cbe73d6da..a87bd2ca2a 100644 --- a/core/runtime/runtime.h +++ b/core/runtime/runtime.h @@ -41,11 +41,6 @@ typedef enum { REQUIRES_OUTPUT_ALLOCATOR_IDX, RESOURCE_ALLOCATION_STRATEGY_IDX, REQUIRES_NATIVE_MULTIDEVICE_IDX, - // HAS_RUNTIME_CFG_IDX gates the next three slots. When "0", their values are ignored. 
- HAS_RUNTIME_CFG_IDX, - RUNTIME_CACHE_PATH_IDX, - DYNAMIC_SHAPES_KERNEL_STRATEGY_IDX, - CUDA_GRAPH_STRATEGY_IDX, SERIALIZATION_LEN, // NEVER USED FOR DATA, USED TO DETERMINE LENGTH OF SERIALIZED INFO } SerializedInfoIndex; @@ -62,10 +57,6 @@ inline constexpr std::array kSerializedInfoIndex "REQUIRES_OUTPUT_ALLOCATOR_IDX", "RESOURCE_ALLOCATION_STRATEGY_IDX", "REQUIRES_NATIVE_MULTIDEVICE_IDX", - "HAS_RUNTIME_CFG_IDX", - "RUNTIME_CACHE_PATH_IDX", - "DYNAMIC_SHAPES_KERNEL_STRATEGY_IDX", - "CUDA_GRAPH_STRATEGY_IDX", }}; // For adding new serialized info indices, update above and update /dynamo/runtime/_serialized_engine_layout.py diff --git a/py/torch_tensorrt/dynamo/_compiler.py b/py/torch_tensorrt/dynamo/_compiler.py index 8671dd5860..0023a0a420 100644 --- a/py/torch_tensorrt/dynamo/_compiler.py +++ b/py/torch_tensorrt/dynamo/_compiler.py @@ -53,6 +53,7 @@ to_torch_device, to_torch_tensorrt_device, ) +from torch_tensorrt.runtime._runtime_config import RuntimeSettings logger = logging.getLogger(__name__) @@ -89,9 +90,6 @@ def cross_compile_for_windows( dryrun: bool = _defaults.DRYRUN, hardware_compatible: bool = _defaults.HARDWARE_COMPATIBLE, timing_cache_path: str = _defaults.TIMING_CACHE_PATH, - runtime_cache_path: str = _defaults.RUNTIME_CACHE_PATH, - dynamic_shapes_kernel_specialization_strategy: str = _defaults.DYNAMIC_SHAPES_KERNEL_SPECIALIZATION_STRATEGY, - cuda_graph_strategy: str = _defaults.CUDA_GRAPH_STRATEGY, lazy_engine_init: bool = _defaults.LAZY_ENGINE_INIT, cache_built_engines: bool = _defaults.CACHE_BUILT_ENGINES, reuse_cached_engines: bool = _defaults.REUSE_CACHED_ENGINES, @@ -112,6 +110,7 @@ def cross_compile_for_windows( dynamically_allocate_resources: bool = _defaults.DYNAMICALLY_ALLOCATE_RESOURCES, decompose_attention: bool = _defaults.DECOMPOSE_ATTENTION, attn_bias_is_causal: bool = _defaults.ATTN_BIAS_IS_CAUSAL, + runtime_settings: Optional[RuntimeSettings] = None, **kwargs: Any, ) -> torch.fx.GraphModule: """Compile an ExportedProgram module 
using TensorRT in Linux for Inference in Windows @@ -170,9 +169,6 @@ def cross_compile_for_windows( dryrun (bool): Toggle for "Dryrun" mode, running everything except conversion to TRT and logging outputs hardware_compatible (bool): Build the TensorRT engines compatible with GPU architectures other than that of the GPU on which the engine was built (currently works for NVIDIA Ampere and newer) timing_cache_path (str): Path to the timing cache if it exists (or) where it will be saved after compilation. Not used for TensorRT-RTX. - runtime_cache_path (str): Path to the runtime cache for TensorRT-RTX JIT compilation results. Not used for standard TensorRT. - dynamic_shapes_kernel_specialization_strategy (str): Strategy for dynamic shape kernel specialization at runtime (TensorRT-RTX only). Options: "lazy", "eager", "none". Default: "lazy". - cuda_graph_strategy (str): Strategy for CUDA graph capture/replay (TensorRT-RTX only). Options: "disabled" (manual capture), "whole_graph_capture" (TRT-RTX handles internally). Default: "disabled". lazy_engine_init (bool): Defer setting up engines until the compilation of all engines is complete. Can allow larger models with multiple graph breaks to compile but can lead to oversubscription of GPU memory at runtime. 
cache_built_engines (bool): Whether to save the compiled TRT engines to storage reuse_cached_engines (bool): Whether to load the compiled TRT engines from storage @@ -318,9 +314,6 @@ def cross_compile_for_windows( "dryrun": dryrun, "hardware_compatible": hardware_compatible, "timing_cache_path": timing_cache_path, - "runtime_cache_path": runtime_cache_path, - "dynamic_shapes_kernel_specialization_strategy": dynamic_shapes_kernel_specialization_strategy, - "cuda_graph_strategy": cuda_graph_strategy, "lazy_engine_init": lazy_engine_init, "cache_built_engines": cache_built_engines, "reuse_cached_engines": reuse_cached_engines, @@ -354,12 +347,6 @@ def cross_compile_for_windows( f"arg: {key} is not supported for cross compilation for windows feature, hence it is disabled." ) - if "runtime_cache_path" in compilation_options: - compilation_options.pop("runtime_cache_path") - logger.warning( - "runtime_cache_path is a JIT-time API and is not applicable to cross compilation for windows. Ignoring." 
- ) - settings = CompilationSettings(**compilation_options) logger.info("Compilation Settings: %s\n", settings) exported_program = pre_export_lowering(exported_program, settings) @@ -398,6 +385,7 @@ def cross_compile_for_windows( trt_arg_inputs, trt_kwarg_inputs, settings, + runtime_settings=runtime_settings, ) return trt_gm @@ -433,9 +421,6 @@ def compile( dryrun: bool = _defaults.DRYRUN, hardware_compatible: bool = _defaults.HARDWARE_COMPATIBLE, timing_cache_path: str = _defaults.TIMING_CACHE_PATH, - runtime_cache_path: str = _defaults.RUNTIME_CACHE_PATH, - dynamic_shapes_kernel_specialization_strategy: str = _defaults.DYNAMIC_SHAPES_KERNEL_SPECIALIZATION_STRATEGY, - cuda_graph_strategy: str = _defaults.CUDA_GRAPH_STRATEGY, lazy_engine_init: bool = _defaults.LAZY_ENGINE_INIT, cache_built_engines: bool = _defaults.CACHE_BUILT_ENGINES, reuse_cached_engines: bool = _defaults.REUSE_CACHED_ENGINES, @@ -469,6 +454,7 @@ def compile( dynamically_allocate_resources: bool = _defaults.DYNAMICALLY_ALLOCATE_RESOURCES, decompose_attention: bool = _defaults.DECOMPOSE_ATTENTION, attn_bias_is_causal: bool = _defaults.ATTN_BIAS_IS_CAUSAL, + runtime_settings: Optional[RuntimeSettings] = None, **kwargs: Any, ) -> torch.fx.GraphModule: """Compile an ExportedProgram module for NVIDIA GPUs using TensorRT @@ -529,9 +515,6 @@ def compile( dryrun (bool): Toggle for "Dryrun" mode, running everything except conversion to TRT and logging outputs hardware_compatible (bool): Build the TensorRT engines compatible with GPU architectures other than that of the GPU on which the engine was built (currently works for NVIDIA Ampere and newer) timing_cache_path (str): Path to the timing cache if it exists (or) where it will be saved after compilation. Not used for TensorRT-RTX. - runtime_cache_path (str): Path to the runtime cache for TensorRT-RTX JIT compilation results. Not used for standard TensorRT. 
- dynamic_shapes_kernel_specialization_strategy (str): Strategy for dynamic shape kernel specialization at runtime (TensorRT-RTX only). Options: "lazy", "eager", "none". Default: "lazy". - cuda_graph_strategy (str): Strategy for CUDA graph capture/replay (TensorRT-RTX only). Options: "disabled" (manual capture), "whole_graph_capture" (TRT-RTX handles internally). Default: "disabled". lazy_engine_init (bool): Defer setting up engines until the compilation of all engines is complete. Can allow larger models with multiple graph breaks to compile but can lead to oversubscription of GPU memory at runtime. cache_built_engines (bool): Whether to save the compiled TRT engines to storage reuse_cached_engines (bool): Whether to load the compiled TRT engines from storage @@ -709,9 +692,6 @@ def compile( "dryrun": dryrun, "hardware_compatible": hardware_compatible, "timing_cache_path": timing_cache_path, - "runtime_cache_path": runtime_cache_path, - "dynamic_shapes_kernel_specialization_strategy": dynamic_shapes_kernel_specialization_strategy, - "cuda_graph_strategy": cuda_graph_strategy, "lazy_engine_init": lazy_engine_init, "cache_built_engines": cache_built_engines, "reuse_cached_engines": reuse_cached_engines, @@ -775,7 +755,12 @@ def compile( "Remaining GPU memory may not be enough to compile the TensorRT engine for this model resulting in an OOM error, Consider setting offload_module_to_cpu=True" ) trt_gm = compile_module( - gm, trt_arg_inputs, trt_kwarg_inputs, settings, engine_cache + gm, + trt_arg_inputs, + trt_kwarg_inputs, + settings, + engine_cache, + runtime_settings=runtime_settings, ) return trt_gm @@ -889,6 +874,7 @@ def compile_module( engine_cache: Optional[BaseEngineCache] = None, *, _debugger_config: Optional[DebuggerConfig] = None, + runtime_settings: Optional[RuntimeSettings] = None, ) -> torch.fx.GraphModule: """Compile a traced FX module @@ -1126,6 +1112,7 @@ def preserve_module_specs( settings=settings, name=name, engine_cache=engine_cache, + 
runtime_settings=runtime_settings, ) trt_modules[name] = trt_module @@ -1230,9 +1217,6 @@ def convert_exported_program_to_serialized_trt_engine( dryrun: bool = _defaults.DRYRUN, hardware_compatible: bool = _defaults.HARDWARE_COMPATIBLE, timing_cache_path: str = _defaults.TIMING_CACHE_PATH, - runtime_cache_path: str = _defaults.RUNTIME_CACHE_PATH, - dynamic_shapes_kernel_specialization_strategy: str = _defaults.DYNAMIC_SHAPES_KERNEL_SPECIALIZATION_STRATEGY, - cuda_graph_strategy: str = _defaults.CUDA_GRAPH_STRATEGY, lazy_engine_init: bool = _defaults.LAZY_ENGINE_INIT, cache_built_engines: bool = _defaults.CACHE_BUILT_ENGINES, reuse_cached_engines: bool = _defaults.REUSE_CACHED_ENGINES, @@ -1307,9 +1291,6 @@ def convert_exported_program_to_serialized_trt_engine( dryrun (bool): Toggle for "Dryrun" mode, running everything except conversion to TRT and logging outputs hardware_compatible (bool): Build the TensorRT engines compatible with GPU architectures other than that of the GPU on which the engine was built (currently works for NVIDIA Ampere and newer) timing_cache_path (str): Path to the timing cache if it exists (or) where it will be saved after compilation. Not used for TensorRT-RTX. - runtime_cache_path (str): Path to the runtime cache for TensorRT-RTX JIT compilation results. Not used for standard TensorRT. - dynamic_shapes_kernel_specialization_strategy (str): Strategy for dynamic shape kernel specialization at runtime (TensorRT-RTX only). Options: "lazy", "eager", "none". Default: "lazy". - cuda_graph_strategy (str): Strategy for CUDA graph capture/replay (TensorRT-RTX only). Options: "disabled" (manual capture), "whole_graph_capture" (TRT-RTX handles internally). Default: "disabled". lazy_engine_init (bool): Defer setting up engines until the compilation of all engines is complete. Can allow larger models with multiple graph breaks to compile but can lead to oversubscription of GPU memory at runtime. 
cache_built_engines (bool): Whether to save the compiled TRT engines to storage reuse_cached_engines (bool): Whether to load the compiled TRT engines from storage @@ -1464,9 +1445,6 @@ def convert_exported_program_to_serialized_trt_engine( "dryrun": dryrun, "hardware_compatible": hardware_compatible, "timing_cache_path": timing_cache_path, - "runtime_cache_path": runtime_cache_path, - "dynamic_shapes_kernel_specialization_strategy": dynamic_shapes_kernel_specialization_strategy, - "cuda_graph_strategy": cuda_graph_strategy, "lazy_engine_init": lazy_engine_init, "cache_built_engines": cache_built_engines, "reuse_cached_engines": reuse_cached_engines, @@ -1484,11 +1462,6 @@ def convert_exported_program_to_serialized_trt_engine( "attn_bias_is_causal": attn_bias_is_causal, "use_python_runtime": use_python_runtime, } - if "runtime_cache_path" in compilation_options: - compilation_options.pop("runtime_cache_path") - logger.warning( - "runtime_cache_path is a JIT-time API and is not applicable to serialized engine export. Ignoring." 
- ) settings = CompilationSettings(**compilation_options) logger.info("Compilation Settings: %s\n", settings) diff --git a/py/torch_tensorrt/dynamo/_defaults.py b/py/torch_tensorrt/dynamo/_defaults.py index 00bc39bce5..e163e830a1 100644 --- a/py/torch_tensorrt/dynamo/_defaults.py +++ b/py/torch_tensorrt/dynamo/_defaults.py @@ -38,9 +38,6 @@ TIMING_CACHE_PATH = os.path.join( tempfile.gettempdir(), "torch_tensorrt_engine_cache", "timing_cache.bin" ) -RUNTIME_CACHE_PATH = os.path.join( - tempfile.gettempdir(), "torch_tensorrt_engine_cache", "runtime_cache.bin" -) LAZY_ENGINE_INIT = False CACHE_BUILT_ENGINES = False REUSE_CACHED_ENGINES = False @@ -69,8 +66,6 @@ DYNAMICALLY_ALLOCATE_RESOURCES = False DECOMPOSE_ATTENTION = False ATTN_BIAS_IS_CAUSAL = True -DYNAMIC_SHAPES_KERNEL_SPECIALIZATION_STRATEGY = "lazy" -CUDA_GRAPH_STRATEGY = "disabled" USE_PYTHON_RUNTIME = False if platform.system() == "Linux": @@ -86,6 +81,19 @@ tempfile.gettempdir(), f"torch_tensorrt_{current_user}/debug_logs" ) +# --------------------------------------------------------------------------- +# Runtime-only knobs (see torch_tensorrt.runtime.RuntimeSettings). Defaults +# live here to mirror compilation-settings convention; the dataclass imports +# from this module. +# --------------------------------------------------------------------------- +DYNAMIC_SHAPES_KERNEL_SPECIALIZATION_STRATEGY = "lazy" +CUDA_GRAPH_STRATEGY = "disabled" +# Default to a per-user temp file (mirrors ENGINE_CACHE_DIR). Users can override +# via ``RuntimeSettings(runtime_cache="/different/path")`` or a runtime CM. 
+RUNTIME_CACHE_PATH = os.path.join( + tempfile.gettempdir(), f"torch_tensorrt_{current_user}/runtime_cache.bin" +) + def default_device() -> Device: return Device(gpu_id=torch.cuda.current_device()) diff --git a/py/torch_tensorrt/dynamo/_settings.py b/py/torch_tensorrt/dynamo/_settings.py index 694b1a7000..9227ad0b81 100644 --- a/py/torch_tensorrt/dynamo/_settings.py +++ b/py/torch_tensorrt/dynamo/_settings.py @@ -17,14 +17,12 @@ AUTOCAST_MAX_OUTPUT_THRESHOLD, CACHE_BUILT_ENGINES, CPU_MEMORY_BUDGET, - CUDA_GRAPH_STRATEGY, DECOMPOSE_ATTENTION, DISABLE_TF32, DLA_GLOBAL_DRAM_SIZE, DLA_LOCAL_DRAM_SIZE, DLA_SRAM_SIZE, DRYRUN, - DYNAMIC_SHAPES_KERNEL_SPECIALIZATION_STRATEGY, DYNAMICALLY_ALLOCATE_RESOURCES, ENABLE_AUTOCAST, ENABLE_CROSS_COMPILE_FOR_WINDOWS, @@ -45,7 +43,6 @@ REFIT_IDENTICAL_ENGINE_WEIGHTS, REQUIRE_FULL_COMPILATION, REUSE_CACHED_ENGINES, - RUNTIME_CACHE_PATH, SPARSE_WEIGHTS, STRIP_ENGINE_WEIGHTS, TILING_OPTIMIZATION_LEVEL, @@ -94,9 +91,6 @@ class CompilationSettings: output to a file if a string path is specified hardware_compatible (bool): Build the TensorRT engines compatible with GPU architectures other than that of the GPU on which the engine was built (currently works for NVIDIA Ampere and newer) timing_cache_path (str): Path to the timing cache if it exists (or) where it will be saved after compilation. Not used for TensorRT-RTX (no autotuning). - runtime_cache_path (str): Path to the runtime cache for TensorRT-RTX JIT compilation results. The cache is loaded on engine setup and saved on module cleanup. Uses file locking for concurrent access safety. Not used for standard TensorRT. - dynamic_shapes_kernel_specialization_strategy (str): Strategy for compiling shape-specialized kernels at runtime for dynamic shapes (TensorRT-RTX only). Options: "lazy" (compile in background, use fallback until ready), "eager" (compile immediately, blocking), "none" (always use fallback kernels). Default: "lazy". 
- cuda_graph_strategy (str): Strategy for CUDA graph capture/replay (TensorRT-RTX only). Options: "disabled" (no native CUDA graphs, uses manual capture if cudagraphs mode is enabled), "whole_graph_capture" (TRT-RTX handles CUDA graph capture internally). When set to "whole_graph_capture", the manual torch CUDA graph capture/replay in forward() is bypassed. Default: "disabled". cache_built_engines (bool): Whether to save the compiled TRT engines to storage reuse_cached_engines (bool): Whether to load the compiled TRT engines from storage use_fp32_acc (bool): This option inserts cast to FP32 nodes around matmul layers and TensorRT ensures the accumulation of matmul happens in FP32. @@ -148,11 +142,6 @@ class CompilationSettings: dryrun: Union[bool, str] = DRYRUN hardware_compatible: bool = HARDWARE_COMPATIBLE timing_cache_path: str = TIMING_CACHE_PATH - runtime_cache_path: str = RUNTIME_CACHE_PATH - dynamic_shapes_kernel_specialization_strategy: str = ( - DYNAMIC_SHAPES_KERNEL_SPECIALIZATION_STRATEGY - ) - cuda_graph_strategy: str = CUDA_GRAPH_STRATEGY lazy_engine_init: bool = LAZY_ENGINE_INIT cache_built_engines: bool = CACHE_BUILT_ENGINES reuse_cached_engines: bool = REUSE_CACHED_ENGINES diff --git a/py/torch_tensorrt/dynamo/conversion/_conversion.py b/py/torch_tensorrt/dynamo/conversion/_conversion.py index d712d7f150..fc886656a0 100644 --- a/py/torch_tensorrt/dynamo/conversion/_conversion.py +++ b/py/torch_tensorrt/dynamo/conversion/_conversion.py @@ -25,6 +25,7 @@ release_host_and_device_memory, ) from torch_tensorrt.logging import TRT_LOGGER +from torch_tensorrt.runtime._runtime_config import RuntimeSettings logger = logging.getLogger(__name__) @@ -333,6 +334,7 @@ def convert_module( settings: CompilationSettings = CompilationSettings(), name: str = "", engine_cache: Optional[BaseEngineCache] = None, + runtime_settings: Optional[RuntimeSettings] = None, ) -> TorchTensorRTModule: """Convert an FX module to a TRT module Args: @@ -341,6 +343,8 @@ def 
convert_module( settings: Compilation settings name: TRT engine name engine_cache: Engine cache instance + runtime_settings: Optional runtime-mode-control overrides threaded to the + built ``TRTEngine``. Not part of ``CompilationSettings``; not serialized. Returns: TorchTensorRTModule """ @@ -379,4 +383,5 @@ def convert_module( requires_output_allocator=serialized_interpreter_result.requires_output_allocator, requires_native_multidevice=serialized_interpreter_result.requires_native_multidevice, symbolic_shape_expressions=serialized_interpreter_result.symbolic_shape_expressions, + runtime_settings=runtime_settings, ) diff --git a/py/torch_tensorrt/dynamo/runtime/_CudaGraphsTorchTensorRTModule.py b/py/torch_tensorrt/dynamo/runtime/_CudaGraphsTorchTensorRTModule.py index 537444a0b9..7db59aaea6 100644 --- a/py/torch_tensorrt/dynamo/runtime/_CudaGraphsTorchTensorRTModule.py +++ b/py/torch_tensorrt/dynamo/runtime/_CudaGraphsTorchTensorRTModule.py @@ -145,10 +145,7 @@ def _check_monolithic_capturability(self, stream: torch.cuda.Stream) -> None: from torch_tensorrt.dynamo.runtime._TorchTensorRTModule import ( TorchTensorRTModule, ) - from torch_tensorrt.dynamo.runtime._TRTEngine import ( - TRTEngine, - _get_cuda_graph_strategy, - ) + from torch_tensorrt.dynamo.runtime._TRTEngine import TRTEngine for name, mod in self.compiled_module.named_modules(): if not ( @@ -170,10 +167,8 @@ def _check_monolithic_capturability(self, stream: torch.cuda.Stream) -> None: # Disable RTX-native CUDA graphs on this engine so they don't # interfere with the outer monolithic capture. 
if engine._rtx_native_cudagraphs: - engine.runtime_config.cuda_graph_strategy = _get_cuda_graph_strategy( - "disabled" - ) - engine.context = engine._create_execution_context() + disabled = engine.runtime_settings.merge(cuda_graph_strategy="disabled") + engine.update_runtime_settings(disabled) engine._rtx_native_cudagraphs = False logger.info( f"Disabled RTX-native CUDA graphs for '{name}' " diff --git a/py/torch_tensorrt/dynamo/runtime/_TRTEngine.py b/py/torch_tensorrt/dynamo/runtime/_TRTEngine.py index e8496f2d3d..62d787d47a 100644 --- a/py/torch_tensorrt/dynamo/runtime/_TRTEngine.py +++ b/py/torch_tensorrt/dynamo/runtime/_TRTEngine.py @@ -11,12 +11,20 @@ import base64 import copy import logging -import os import pickle import tempfile from contextlib import nullcontext from types import SimpleNamespace -from typing import Any, ContextManager, Dict, List, Optional, Sequence, Tuple, cast +from typing import ( + Any, + ContextManager, + Dict, + List, + Optional, + Sequence, + Tuple, + cast, +) import torch import torch.distributed as dist @@ -47,6 +55,7 @@ ) from torch_tensorrt.dynamo.utils import DYNAMIC_DIM from torch_tensorrt.logging import TRT_LOGGER +from torch_tensorrt.runtime._runtime_config import RuntimeSettings, TRTRuntimeConfig from torch_tensorrt.runtime._utils import ( _is_switch_required, _select_rt_device, @@ -59,23 +68,6 @@ logger = logging.getLogger(__name__) -def _get_dynamic_shapes_kernel_strategy(strategy_str: str) -> Any: - """Map strategy string to TRT enum. Only meaningful on TensorRT-RTX builds.""" - return { - "lazy": trt.DynamicShapesKernelSpecializationStrategy.LAZY, - "eager": trt.DynamicShapesKernelSpecializationStrategy.EAGER, - "none": trt.DynamicShapesKernelSpecializationStrategy.NONE, - }.get(strategy_str, trt.DynamicShapesKernelSpecializationStrategy.LAZY) - - -def _get_cuda_graph_strategy(strategy_str: str) -> Any: - """Map strategy string to TRT CudaGraphStrategy enum. 
Only meaningful on RTX.""" - return { - "disabled": trt.CudaGraphStrategy.DISABLED, - "whole_graph_capture": trt.CudaGraphStrategy.WHOLE_GRAPH_CAPTURE, - }.get(strategy_str, trt.CudaGraphStrategy.DISABLED) - - # --------------------------------------------------------------------------- # TRT I/O helpers # --------------------------------------------------------------------------- @@ -218,6 +210,7 @@ def __init__( serialized_info: SerializedTensorRTEngineFmt, *, profile_execution: bool = False, + runtime_settings: Optional[RuntimeSettings] = None, ) -> None: self._profile_execution = profile_execution self.profile_path_prefix = tempfile.gettempdir() @@ -239,10 +232,6 @@ def __init__( torch_tensorrt.runtime.get_cudagraphs_mode() ) self.resource_allocation_strategy = 0 - # Initialized to ``None`` here so the destructor can safely save the - # cache even if ``_setup_engine`` never runs. - self.runtime_config: Any = None - self.runtime_cache: Any = None # When true, ``_execute_standard`` must skip manual torch.cuda.CUDAGraph # capture because TRT-RTX handles it internally. self._rtx_native_cudagraphs: bool = False @@ -250,9 +239,36 @@ def __init__( # engines compiled with native multi-device collective layers. self._nccl_comm: Optional[Any] = None + # Owns RuntimeSettings + the live trt.IRuntimeConfig + the + # engine-implicit RuntimeCacheHandle. Hides all RTX feature gates. + self._trt_runtime_config: TRTRuntimeConfig = TRTRuntimeConfig( + runtime_settings or RuntimeSettings() + ) + self._load_serialized_info(serialized_info) self._setup_engine() + # --- public property forwards --- + + @property + def runtime_settings(self) -> RuntimeSettings: + """The current ``RuntimeSettings`` for this engine. + + Backed by ``self._trt_runtime_config.settings``; mutations go through + :meth:`update_runtime_settings`. 
+ """ + return self._trt_runtime_config.settings + + @property + def runtime_config(self) -> Any: + """The live ``trt.IRuntimeConfig`` (or ``None`` on non-RTX builds).""" + return self._trt_runtime_config._live + + @property + def _implicit_cache_handle(self) -> Any: + """The engine-implicit ``RuntimeCacheHandle`` if a path-string compile-time hint was given.""" + return self._trt_runtime_config.implicit_cache_handle + def __del__(self) -> None: self.close() @@ -305,14 +321,14 @@ def __setstate__(self, state: Any) -> None: torch_tensorrt.runtime.get_cudagraphs_mode() ) self.resource_allocation_strategy = 0 - # See ``__init__`` for the rationale: pre-init these so a destructor - # firing on a partially-loaded engine never trips an ``AttributeError``. - self.runtime_config = None - self.runtime_cache = None self._rtx_native_cudagraphs = False # NCCL communicators cannot be pickled; rebind lazily on the next # forward pass via setup_nccl_comm(). self._nccl_comm = None + # RuntimeSettings are NOT serialized -- restore defaults. Callers + # who want runtime-mode overrides must reapply them post-load via + # ``compiled.set_runtime_settings(...)`` or a runtime CM. + self._trt_runtime_config = TRTRuntimeConfig(RuntimeSettings()) serialized_info = list(state[0]) engine_field = serialized_info[ENGINE_IDX] @@ -401,24 +417,40 @@ def get_serialized_metadata(self) -> str: return self.serialized_metadata def close(self) -> None: - """Persist the runtime cache and release CUDA graph resources.""" - self._save_runtime_cache() + """Release CUDA graph resources. + + Implicit runtime cache persistence is now driven by the + :class:`~torch_tensorrt.runtime._runtime_cache.RuntimeCacheHandle`'s + own ``__del__`` (with ``autosave_on_del=True``), so no explicit save + is needed here. 
+ """ self.reset_captured_graph() def _create_execution_context(self) -> trt.IExecutionContext: - if ENABLED_FEATURES.tensorrt_rtx: - assert self.runtime_config is not None - context = self.cuda_engine.create_execution_context(self.runtime_config) - else: - strategy = ( - trt.ExecutionContextAllocationStrategy.USER_MANAGED - if self.resource_allocation_strategy - else trt.ExecutionContextAllocationStrategy.STATIC - ) - context = self.cuda_engine.create_execution_context(strategy) + alloc_strategy = ( + trt.ExecutionContextAllocationStrategy.USER_MANAGED + if self.resource_allocation_strategy + else trt.ExecutionContextAllocationStrategy.STATIC + ) + context = self._trt_runtime_config.create_execution_context( + self.cuda_engine, alloc_strategy + ) assert context is not None, "Failed to create execution context" + # Mirrors ``TRTEngine::num_execution_contexts_created`` on the C++ side; + # used by tests to assert single createExecutionContext per engine setup. + self._num_execution_contexts_created = ( + getattr(self, "_num_execution_contexts_created", 0) + 1 + ) return context + def num_execution_contexts_created(self) -> int: + """Number of TRT ``createExecutionContext`` invocations on this engine. + + Each call (re)JITs the specialized kernel set on RTX, so this is the + canonical counter for the setup-cost regression test. + """ + return getattr(self, "_num_execution_contexts_created", 0) + def _setup_engine(self) -> None: multi_gpu_device_check() self.runtime = trt.Runtime(TRT_LOGGER) @@ -430,16 +462,14 @@ def _setup_engine(self) -> None: logger.debug(f"Weight streaming budget set to {budget_bytes}B") self.cuda_engine.weight_streaming_budget_v2 = budget_bytes - # On TensorRT-RTX, build the IRuntimeConfig (runtime cache, - # dynamic-shape kernel specialization strategy, and CUDA graph - # strategy) up front so the one-and-only execution context picks it up. 
- if ENABLED_FEATURES.tensorrt_rtx: - self._setup_runtime_config() - self._rtx_native_cudagraphs = ( - self.settings.cuda_graph_strategy != "disabled" - ) - + # The TRTRuntimeConfig shim builds the live IRuntimeConfig (with cache + # + strategies) inside ``create_execution_context`` on RTX; no-op + # otherwise. Track the cudagraph-disabled-or-not state for the + # ``_execute_standard`` path to consult. self.context = self._create_execution_context() + self._rtx_native_cudagraphs = ENABLED_FEATURES.tensorrt_rtx and ( + self.runtime_settings.cuda_graph_strategy != "disabled" + ) if self._has_nccl_ops: from torch_tensorrt.distributed._nccl_utils import ( @@ -503,116 +533,43 @@ def _setup_engine(self) -> None: if self.requires_output_allocator: self.create_output_allocator() - # --- TensorRT-RTX --- + # --- TensorRT-RTX runtime-config delegation --- - def _setup_runtime_config(self) -> None: - """Build an ``IRuntimeConfig`` with runtime cache and dynamic-shape strategy. + def update_runtime_settings(self, new_settings: RuntimeSettings) -> None: + """Apply new ``RuntimeSettings`` to this engine. - The runtime cache stores JIT compilation results so kernel/graph - compilation is not repeated across inference runs. The dynamic-shape - kernel specialization strategy controls how the JIT compiles - shape-specialized kernels (``lazy``, ``eager``, or ``none``). + Fast-paths on equality via ``TRTRuntimeConfig.set_settings``. On + change, the prior implicit cache (if any) is saved, the live + ``IRuntimeConfig`` is invalidated, and a fresh ``IExecutionContext`` + is created. 
""" - self.runtime_config = self.cuda_engine.create_runtime_config() - alloc_strategy = ( - trt.ExecutionContextAllocationStrategy.USER_MANAGED - if self.resource_allocation_strategy - else trt.ExecutionContextAllocationStrategy.STATIC - ) - self.runtime_config.set_execution_context_allocation_strategy(alloc_strategy) - self.runtime_config.dynamic_shapes_kernel_specialization_strategy = ( - _get_dynamic_shapes_kernel_strategy( - self.settings.dynamic_shapes_kernel_specialization_strategy - ) - ) - logger.info( - "Dynamic shapes kernel specialization strategy: " - f"{self.settings.dynamic_shapes_kernel_specialization_strategy}" - ) - self.runtime_config.cuda_graph_strategy = _get_cuda_graph_strategy( - self.settings.cuda_graph_strategy - ) - logger.info(f"CUDA graph strategy: {self.settings.cuda_graph_strategy}") - self.runtime_cache = self.runtime_config.create_runtime_cache() - self._load_runtime_cache() - self.runtime_config.set_runtime_cache(self.runtime_cache) - logger.info("TensorRT-RTX runtime cache configured") - - def _load_runtime_cache(self) -> None: - """Load runtime cache from disk if it exists (with a shared file lock).""" - if self.runtime_cache is None: - return - cache_path = self.settings.runtime_cache_path - if not os.path.isfile(cache_path): - logger.debug(f"No existing runtime cache at {cache_path}") - return - try: - from filelock import FileLock - - lock = FileLock(cache_path + ".lock") - with lock.acquire(timeout=10): - with open(cache_path, "rb") as f: - data = f.read() - if data: - self.runtime_cache.deserialize(data) - logger.info(f"Loaded runtime cache from {cache_path}") - except Exception as e: - logger.warning(f"Failed to load runtime cache: {e}") - - def _save_runtime_cache(self) -> None: - """Save runtime cache to disk (with an exclusive file lock).""" - if self.runtime_cache is None: + if not self._trt_runtime_config.set_settings(new_settings): return - try: - host_mem = self.runtime_cache.serialize() - if host_mem is None: - return 
- cache_path = self.settings.runtime_cache_path - os.makedirs(os.path.dirname(cache_path), exist_ok=True) - - from filelock import FileLock - - lock = FileLock(cache_path + ".lock") - with lock.acquire(timeout=10): - with open(cache_path, "wb") as f: - f.write(memoryview(host_mem)) - logger.info(f"Saved runtime cache to {cache_path}") - except Exception as e: - logger.warning(f"Failed to save runtime cache: {e}") + self.context = self._create_execution_context() + self._rtx_native_cudagraphs = ENABLED_FEATURES.tensorrt_rtx and ( + self.runtime_settings.cuda_graph_strategy != "disabled" + ) + self.runtime_states.context_changed = True def _is_monolithic_capturable(self, stream: torch.cuda.Stream) -> bool: - """Return True iff manual ``torch.cuda.CUDAGraph`` capture is safe. - - On RTX, unsafe when the TRT-RTX context is not stream-capturable, or - when ``"lazy"`` kernel specialization can still fire (dynamic inputs). - """ - if not ENABLED_FEATURES.tensorrt_rtx: - return True + """Return True iff manual ``torch.cuda.CUDAGraph`` capture is safe.""" has_dynamic_input = any(DYNAMIC_DIM in shape for shape in self.input_shapes) - not_capturable = ( - not self.context.is_stream_capturable(stream.cuda_stream), - ( - self.settings.dynamic_shapes_kernel_specialization_strategy == "lazy" - and has_dynamic_input - ), + return self._trt_runtime_config.is_monolithic_capturable( + has_dynamic_input, self.context, stream ) - return not any(not_capturable) def _enable_rtx_native_cudagraphs(self) -> None: """Switch this engine to TRT-RTX native CUDA graphs. - Sets the runtime config's ``cuda_graph_strategy`` to - ``WHOLE_GRAPH_CAPTURE`` and rebuilds the execution context so it - picks up the new strategy. No-op on non-RTX or when the runtime - config is not present. + Mutates settings via ``update_runtime_settings`` so the prior cache is + saved + a fresh context is created uniformly. No-op on non-RTX builds. 
""" - if self.runtime_config is None: + if not ENABLED_FEATURES.tensorrt_rtx: return - self.runtime_config.cuda_graph_strategy = _get_cuda_graph_strategy( - "whole_graph_capture" + new_settings = self.runtime_settings.merge( + cuda_graph_strategy="whole_graph_capture" ) - self.context = self._create_execution_context() - self._rtx_native_cudagraphs = True + self.update_runtime_settings(new_settings) logger.info("Switched to TRT-RTX native CUDA graphs") # --- distributed / NCCL --- diff --git a/py/torch_tensorrt/dynamo/runtime/_TorchTensorRTModule.py b/py/torch_tensorrt/dynamo/runtime/_TorchTensorRTModule.py index 77253f3deb..ebb4b833e6 100644 --- a/py/torch_tensorrt/dynamo/runtime/_TorchTensorRTModule.py +++ b/py/torch_tensorrt/dynamo/runtime/_TorchTensorRTModule.py @@ -14,11 +14,8 @@ from torch_tensorrt.dynamo.runtime._serialized_engine_layout import ( ABI_TARGET_IDX, ABI_VERSION, - CUDA_GRAPH_STRATEGY_IDX, DEVICE_IDX, - DYNAMIC_SHAPES_KERNEL_STRATEGY_IDX, ENGINE_IDX, - HAS_RUNTIME_CFG_IDX, HW_COMPATIBLE_IDX, INPUT_BINDING_NAMES_IDX, NAME_IDX, @@ -26,7 +23,6 @@ REQUIRES_NATIVE_MULTIDEVICE_IDX, REQUIRES_OUTPUT_ALLOCATOR_IDX, RESOURCE_ALLOCATION_STRATEGY_IDX, - RUNTIME_CACHE_PATH_IDX, SERIALIZATION_LEN, SERIALIZED_METADATA_IDX, TARGET_PLATFORM_IDX, @@ -34,6 +30,7 @@ serialize_binding_names, serialize_device_info, ) +from torch_tensorrt.runtime._runtime_config import RuntimeSettings logger = logging.getLogger(__name__) @@ -44,16 +41,6 @@ List[str], ] -_DYNAMIC_SHAPES_KERNEL_STRATEGY_MAP: Dict[str, int] = { - "lazy": 0, - "eager": 1, - "none": 2, -} -_CUDA_GRAPH_STRATEGY_MAP: Dict[str, int] = { - "disabled": 0, - "whole_graph_capture": 1, -} - class TorchTensorRTModule(torch.nn.Module): # type: ignore[misc] """``nn.Module`` that runs a TensorRT engine inside PyTorch. 
@@ -79,6 +66,7 @@ def __init__( requires_output_allocator: bool = False, requires_native_multidevice: bool = False, symbolic_shape_expressions: Optional[Dict[str, List[Dict[str, Any]]]] = None, + runtime_settings: Optional[RuntimeSettings] = None, ): """Takes a name, target device, serialized TensorRT engine, and binding names / order and constructs a PyTorch ``torch.nn.Module`` around it. Uses the Torch-TensorRT runtime extension to run the engines @@ -146,25 +134,10 @@ def __init__( self.execute_engine_op: Any = None self.requires_output_allocator = requires_output_allocator self.dynamically_allocate_resources = settings.dynamically_allocate_resources - self.runtime_cache_path = settings.runtime_cache_path - self.dynamic_shapes_kernel_specialization_strategy = ( - settings.dynamic_shapes_kernel_specialization_strategy - ) - if ( - self.dynamic_shapes_kernel_specialization_strategy - not in _DYNAMIC_SHAPES_KERNEL_STRATEGY_MAP - ): - raise ValueError( - f"Invalid dynamic_shapes_kernel_specialization_strategy " - f"{self.dynamic_shapes_kernel_specialization_strategy!r}; expected one of " - f"{list(_DYNAMIC_SHAPES_KERNEL_STRATEGY_MAP.keys())}" - ) - self.cuda_graph_strategy = settings.cuda_graph_strategy - if self.cuda_graph_strategy not in _CUDA_GRAPH_STRATEGY_MAP: - raise ValueError( - f"Invalid cuda_graph_strategy {self.cuda_graph_strategy!r}; expected one of " - f"{list(_CUDA_GRAPH_STRATEGY_MAP.keys())}" - ) + + # Per-engine runtime mode controls. Defaults to ``RuntimeSettings()`` if + # not supplied; the dataclass validates at ``__post_init__``. 
+ self._runtime_settings: RuntimeSettings = runtime_settings or RuntimeSettings() self.symbolic_shape_expressions = symbolic_shape_expressions self.requires_native_multidevice = requires_native_multidevice self.target_platform = ( @@ -261,18 +234,9 @@ def _pack_engine_info(self) -> List[str | bytes]: engine_info[REQUIRES_NATIVE_MULTIDEVICE_IDX] = str( int(self.requires_native_multidevice) ) - # rank/world_size are runtime facts; queried from ProcessGroup at execution time - engine_info[HAS_RUNTIME_CFG_IDX] = "1" if ENABLED_FEATURES.tensorrt_rtx else "0" - engine_info[RUNTIME_CACHE_PATH_IDX] = self.runtime_cache_path or "" - engine_info[DYNAMIC_SHAPES_KERNEL_STRATEGY_IDX] = str( - _DYNAMIC_SHAPES_KERNEL_STRATEGY_MAP[ - self.dynamic_shapes_kernel_specialization_strategy - ] - ) - engine_info[CUDA_GRAPH_STRATEGY_IDX] = str( - _CUDA_GRAPH_STRATEGY_MAP[self.cuda_graph_strategy] - ) - + # rank/world_size are runtime facts; queried from ProcessGroup at execution time. + # RuntimeSettings are intentionally NOT serialized: they're per-engine, in-memory + # init values, not part of the engine's identity (see pytorch/TensorRT#4310). return engine_info def get_streamable_device_memory_budget(self) -> Any: @@ -306,6 +270,117 @@ def use_dynamically_allocated_resources( self.dynamically_allocate_resources ) + # --- runtime-settings dispatch ---------------------------------------- + + @property + def runtime_settings(self) -> RuntimeSettings: + """The current ``RuntimeSettings`` on this module (and its engine). + + This is the snapshot the ``runtime_config`` CM reads at ``__enter__`` + and restores at ``__exit__``. + """ + return self._runtime_settings + + def set_runtime_settings(self, rs: RuntimeSettings) -> None: + """Apply ``RuntimeSettings`` to all TRT engines under this module. + + Walks ``named_modules()`` so calling on a wrapper / parent + ``nn.Module`` propagates to every contained + ``TorchTensorRTModule``. 
Dispatches to the Python ``TRTEngine`` or + the C++ ``torch.classes.tensorrt.Engine`` per submodule's backend. + """ + for _, mod in self.named_modules(): + if isinstance(mod, TorchTensorRTModule) and mod.engine is not None: + mod._dispatch_runtime_settings_to_engine(rs) + mod._runtime_settings = rs + + def _dispatch_runtime_settings_to_engine(self, rs: RuntimeSettings) -> None: + """Backend-aware dispatch of ``update_runtime_settings(rs)`` to ``self.engine``.""" + if self.engine is None: + return + from torch_tensorrt.dynamo.runtime._TRTEngine import TRTEngine + + if isinstance(self.engine, TRTEngine): + # Python runtime: dataclass passes straight through. + self.engine.update_runtime_settings(rs) + return + + # C++ torchbind engine: re-materialize the Python-side implicit cache + # wrapper before dispatch (saving the prior wrapper synchronously, to + # mirror ``TRTRuntimeConfig.set_settings`` on the Python runtime). + rs_for_dispatch, needs_load = self._materialize_cpp_implicit_handle(rs) + + from torch_tensorrt.runtime._runtime_cache import _to_torchbind_handle + + cache_arg = _to_torchbind_handle(rs_for_dispatch.runtime_cache) + self.engine.update_runtime_settings( + rs_for_dispatch.dynamic_shapes_kernel_specialization_strategy, + rs_for_dispatch.cuda_graph_strategy, + cache_arg, + ) + # The C++ engine's `update_runtime_settings` materializes the + # IRuntimeCache inside the torchbind handle. Load the on-disk bytes + # (filelocked) so the new cache starts warm. + if needs_load and self._implicit_cache_handle is not None: + try: + self._implicit_cache_handle.load() + except Exception as e: + logger.debug(f"Failed to load implicit runtime cache: {e}") + + def _materialize_cpp_implicit_handle( + self, rs: RuntimeSettings + ) -> Tuple[RuntimeSettings, bool]: + """Mirror of ``TRTRuntimeConfig._apply_settings`` for the cpp engine path. 
+ + When ``rs.runtime_cache`` is a path string, builds a torchbind + ``RuntimeCacheHandle`` + Python wrapper and stashes the wrapper on + ``self._implicit_cache_handle`` so its ``__del__`` saves the cache + on engine destruction. Synchronously saves any prior implicit wrapper + before replacing it -- matches the explicit-save semantic in + :py:meth:`TRTRuntimeConfig.set_settings` so a settings swap doesn't + silently drop kernels JIT-compiled into the prior cache. + + Returns ``(rs_for_dispatch, needs_load)``: ``rs_for_dispatch`` has the + path string replaced with the torchbind handle (so dispatch passes the + same handle through to the C++ engine); ``needs_load`` indicates the + caller should call ``self._implicit_cache_handle.load()`` after the + engine has materialized its IRuntimeCache via the handle. + """ + from torch_tensorrt.runtime._runtime_cache import RuntimeCacheHandle + + old = self._implicit_cache_handle # type: ignore[has-type] + rc = rs.runtime_cache + if isinstance(rc, str) and rc: + # No-op fast path: if the prior wrapper is already pointed at the + # same disk path and still holds its torchbind sibling, reuse it. + # Without this the cpp ``set_settings`` sees a *different* + # ``runtime_cache.get()`` pointer every call and invalidates the + # execution context even when the user passed identical settings. 
+ if ( + old is not None + and getattr(old, "path", None) == rc + and old._torchbind is not None + ): + rs_for_dispatch = rs.merge(runtime_cache=old._torchbind) + return rs_for_dispatch, False + tb = torch.classes.tensorrt.RuntimeCacheHandle(rc) + new = RuntimeCacheHandle(path=rc, autosave_on_del=True, torchbind_handle=tb) + self._implicit_cache_handle = new + rs_for_dispatch = rs.merge(runtime_cache=tb) + needs_load = True + else: + self._implicit_cache_handle = None + rs_for_dispatch = rs + needs_load = False + if old is not None and old is not self._implicit_cache_handle: + try: + old.save() + except Exception as e: + logger.warning( + f"Failed to save prior implicit runtime cache on swap: {e}" + ) + return rs_for_dispatch, needs_load + def setup_engine(self) -> None: """ Setup engine for a module which has deferred engine setup. @@ -324,11 +399,24 @@ def setup_engine(self) -> None: self.engine = TRTEngine( self._pack_engine_info(), profile_execution=self.profiling_enabled, + runtime_settings=self._runtime_settings, ) self.execute_engine_op = torch.ops.tensorrt.execute_engine_python else: self.engine = torch.classes.tensorrt.Engine(self._pack_engine_info()) self.execute_engine_op = torch.ops.tensorrt.execute_engine + # ``_dispatch_runtime_settings_to_engine`` (cpp branch) handles the + # str-path → torchbind handle + Python wrapper materialization, the + # dispatch to the C++ engine, AND the on-disk load. Initialize the + # attribute first so the helper's access is well-defined. + self._implicit_cache_handle: Any = None + self._dispatch_runtime_settings_to_engine(self._runtime_settings) + # Reflect the substituted handle back onto self._runtime_settings + # so subsequent reads (CM snapshot/restore) carry the same object. 
+ if self._implicit_cache_handle is not None: + self._runtime_settings = self._runtime_settings.merge( + runtime_cache=self._implicit_cache_handle._torchbind + ) # requires_native_multidevice is set by the C++ constructor from the serialized REQUIRES_NATIVE_MULTIDEVICE_IDX field. if self.engine.requires_native_multidevice: @@ -432,10 +520,15 @@ def set_extra_state(self, state: SerializedTorchTensorRTModuleFmt) -> None: getattr(self.settings, "use_python_runtime", False) or not ENABLED_FEATURES.torch_tensorrt_runtime ) + # RuntimeSettings are NOT serialized; restore defaults. Caller can + # reapply via ``compiled.set_runtime_settings(...)`` or a CM after load. + self._runtime_settings = RuntimeSettings() if self._use_python_runtime: from torch_tensorrt.dynamo.runtime._TRTEngine import TRTEngine - self.engine = TRTEngine(serialized_engine_info) + self.engine = TRTEngine( + serialized_engine_info, runtime_settings=self._runtime_settings + ) self.execute_engine_op = torch.ops.tensorrt.execute_engine_python else: self.engine = torch.classes.tensorrt.Engine(serialized_engine_info) diff --git a/py/torch_tensorrt/dynamo/runtime/_serialized_engine_layout.py b/py/torch_tensorrt/dynamo/runtime/_serialized_engine_layout.py index c0bc6653b9..d4f31ba8a8 100644 --- a/py/torch_tensorrt/dynamo/runtime/_serialized_engine_layout.py +++ b/py/torch_tensorrt/dynamo/runtime/_serialized_engine_layout.py @@ -37,11 +37,6 @@ class SerializedInfoIndex(IntEnum): REQUIRES_OUTPUT_ALLOCATOR_IDX = 9 RESOURCE_ALLOCATION_STRATEGY_IDX = 10 REQUIRES_NATIVE_MULTIDEVICE_IDX = 11 - # HAS_RUNTIME_CFG_IDX gates the next three slots. When "0", their values are ignored. 
- HAS_RUNTIME_CFG_IDX = 12 - RUNTIME_CACHE_PATH_IDX = 13 - DYNAMIC_SHAPES_KERNEL_STRATEGY_IDX = 14 - CUDA_GRAPH_STRATEGY_IDX = 15 # Module-level aliases for backward compatibility and concise access @@ -57,12 +52,6 @@ class SerializedInfoIndex(IntEnum): REQUIRES_OUTPUT_ALLOCATOR_IDX = SerializedInfoIndex.REQUIRES_OUTPUT_ALLOCATOR_IDX RESOURCE_ALLOCATION_STRATEGY_IDX = SerializedInfoIndex.RESOURCE_ALLOCATION_STRATEGY_IDX REQUIRES_NATIVE_MULTIDEVICE_IDX = SerializedInfoIndex.REQUIRES_NATIVE_MULTIDEVICE_IDX -HAS_RUNTIME_CFG_IDX = SerializedInfoIndex.HAS_RUNTIME_CFG_IDX -RUNTIME_CACHE_PATH_IDX = SerializedInfoIndex.RUNTIME_CACHE_PATH_IDX -DYNAMIC_SHAPES_KERNEL_STRATEGY_IDX = ( - SerializedInfoIndex.DYNAMIC_SHAPES_KERNEL_STRATEGY_IDX -) -CUDA_GRAPH_STRATEGY_IDX = SerializedInfoIndex.CUDA_GRAPH_STRATEGY_IDX SERIALIZATION_LEN = len(SerializedInfoIndex) SERIALIZED_ENGINE_BINDING_DELIM = "%" @@ -84,10 +73,6 @@ class SerializedInfoIndex(IntEnum): ("REQUIRES_OUTPUT_ALLOCATOR_IDX", "REQUIRES_OUTPUT_ALLOCATOR_IDX", int), ("RESOURCE_ALLOCATION_STRATEGY_IDX", "RESOURCE_ALLOCATION_STRATEGY_IDX", int), ("REQUIRES_NATIVE_MULTIDEVICE_IDX", "REQUIRES_NATIVE_MULTIDEVICE_IDX", int), - ("HAS_RUNTIME_CFG_IDX", "HAS_RUNTIME_CFG_IDX", int), - ("RUNTIME_CACHE_PATH_IDX", "RUNTIME_CACHE_PATH_IDX", int), - ("DYNAMIC_SHAPES_KERNEL_STRATEGY_IDX", "DYNAMIC_SHAPES_KERNEL_STRATEGY_IDX", int), - ("CUDA_GRAPH_STRATEGY_IDX", "CUDA_GRAPH_STRATEGY_IDX", int), ("SERIALIZATION_LEN", "SERIALIZATION_LEN", int), ("SERIALIZED_ENGINE_BINDING_DELIM", "SERIALIZED_ENGINE_BINDING_DELIM", str), ("SERIALIZED_RT_DEVICE_DELIM", "SERIALIZED_RT_DEVICE_DELIM", str), diff --git a/py/torch_tensorrt/runtime/__init__.py b/py/torch_tensorrt/runtime/__init__.py index 7283ca0f33..82457130f8 100644 --- a/py/torch_tensorrt/runtime/__init__.py +++ b/py/torch_tensorrt/runtime/__init__.py @@ -10,4 +10,11 @@ from torch_tensorrt.runtime._multi_device_safe_mode import set_multi_device_safe_mode from 
torch_tensorrt.runtime._output_allocator import enable_output_allocator from torch_tensorrt.runtime._pre_allocated_outputs import enable_pre_allocated_outputs +from torch_tensorrt.runtime._runtime_cache import RuntimeCacheHandle, runtime_cache +from torch_tensorrt.runtime._runtime_config import ( + RuntimeSettings, + runtime_config, + set_cuda_graph_strategy, + set_dynamic_shapes_kernel_strategy, +) from torch_tensorrt.runtime._weight_streaming import weight_streaming diff --git a/py/torch_tensorrt/runtime/_runtime_cache.py b/py/torch_tensorrt/runtime/_runtime_cache.py new file mode 100644 index 0000000000..d66e606bf0 --- /dev/null +++ b/py/torch_tensorrt/runtime/_runtime_cache.py @@ -0,0 +1,314 @@ +"""Runtime cache handle + ``runtime_cache()`` context manager. + +The handle wraps a ``trt.IRuntimeCache`` plus optional disk-backing config. +Used by: + +* The runtime ``cache()`` CM to attach a SHARED cache across one or more + modules' engines. +* ``RuntimeSettings(runtime_cache=...)`` for compile-time hints (string path + ⇒ engine creates an implicit per-engine handle; ``RuntimeCacheHandle`` ⇒ + external shared handle attached directly). + +File I/O lives entirely on the Python side under a ``filelock`` (this module). +The C++-side ``torch.classes.tensorrt.RuntimeCacheHandle`` is a passive +shared_ptr wrapper used only to cross the Python/C++ boundary. +""" + +from __future__ import annotations + +import logging +import os +import shutil +import threading +from typing import Any, Optional, Sequence, Union + +import torch +import torch_tensorrt + +logger = logging.getLogger(__name__) + +_FILELOCK_TIMEOUT_S = 10.0 + + +class RuntimeCacheHandle: + """Wraps a ``trt.IRuntimeCache`` (or a torchbind sibling) + optional disk path. + + Three construction patterns differ in *who else holds a reference*, which + drives the ``autosave_on_del`` flag: + + 1. 
**Engine-implicit** (compile-time hint): when an engine sees + ``RuntimeSettings(runtime_cache="/path")``, the engine's + ``TRTRuntimeConfig`` materializes a handle internally with + ``autosave_on_del=True``. No other Python object holds the handle, + so ``__del__`` writes the cache to disk on the engine's last release. + + 2. **Runtime CM** (shared, multi-engine): the :func:`runtime_cache` CM + constructs a handle with ``autosave_on_del=False`` and explicitly + calls ``handle.save()`` on ``__exit__``. The handle's ``__del__`` + no-ops since the CM already saved. + + 3. **User-constructed** (advanced): hand-built handles default to + ``autosave_on_del=False`` so save timing stays under the user's + control. Opt in with ``RuntimeCacheHandle(path=..., autosave_on_del=True)`` + for with-block-style autosave on hand-built handles. + + Bytes are sourced from whichever of ``_cache`` (Python pybind + ``trt.IRuntimeCache``) or ``_torchbind`` (TorchBind + ``RuntimeCacheHandle`` sibling) is populated; the Python runtime path + populates the former, the C++ runtime path populates the latter. + """ + + def __init__( + self, + cache: Any = None, + path: str = "", + autosave_on_del: bool = False, + torchbind_handle: Any = None, + ) -> None: + # ``cache`` is a ``trt.IRuntimeCache`` once materialized (Python rt). + # ``torchbind_handle`` is a ``torch.classes.tensorrt.RuntimeCacheHandle`` + # for the C++ runtime path, exposing serialize/deserialize as tensors. + # Exactly zero or one is populated for a given handle. + self._cache = cache + self._torchbind = torchbind_handle + self.path = path + self.autosave_on_del = autosave_on_del + self._lock = threading.Lock() + + @property + def cache(self) -> Any: + """The underlying Python pybind ``trt.IRuntimeCache``. ``None`` if not yet materialized or if backed by a torchbind sibling.""" + return self._cache + + def ensure_cache(self, runtime_config: Any) -> Any: + """Idempotent. 
First caller materializes via ``runtime_config.create_runtime_cache()``. + + Only meaningful for the Python-runtime path (``_cache``). The C++ + runtime materializes its cache inside the engine and exposes bytes + through the torchbind sibling. + """ + with self._lock: + if self._cache is None: + self._cache = runtime_config.create_runtime_cache() + return self._cache + + def _read_bytes(self) -> Optional[bytes]: + """Serialize whichever of ``_cache`` or ``_torchbind`` is populated.""" + if self._cache is not None: + host_mem = self._cache.serialize() + if host_mem is None or host_mem.nbytes == 0: + return None + return bytes(memoryview(host_mem)) + if self._torchbind is not None and self._torchbind.has_cache(): + tensor = self._torchbind.serialize() + if tensor.numel() == 0: + return None + return bytes(tensor.cpu().contiguous().numpy()) + return None + + def _write_bytes(self, data: bytes) -> None: + """Deserialize ``data`` into whichever of ``_cache`` or ``_torchbind`` is populated.""" + if self._cache is not None: + self._cache.deserialize(data) + return + if self._torchbind is not None and self._torchbind.has_cache(): + tensor = torch.frombuffer(bytearray(data), dtype=torch.uint8) + self._torchbind.deserialize(tensor) + return + + def load(self, path: Optional[str] = None) -> None: + """Read bytes from disk and deserialize into the underlying cache. + + No-op if no cache backing is present, the resolved path is empty, or + the file doesn't exist (first run). Caller must ensure no enqueue is + concurrently writing (the CM enforces this by ordering load before + engine attach; ``ensure_cache`` is called inside engine setup). 
+ """ + target = path if path is not None else self.path + if not target: + return + if self._cache is None and self._torchbind is None: + return + if not os.path.exists(target): + return # first run; nothing to load + from filelock import FileLock + + with FileLock(target + ".lock").acquire(timeout=_FILELOCK_TIMEOUT_S): + with open(target, "rb") as f: + data = f.read() + if data: + self._write_bytes(data) + logger.debug(f"Loaded runtime cache from {target} ({len(data)} bytes)") + + def save(self, path: Optional[str] = None) -> None: + """Serialize the underlying cache and write to disk under a filelock. + + No-op if path is empty or the cache wasn't materialized. Caller must + ensure no enqueue is concurrently writing (the CM detaches the cache + from all engines before calling save in ``__exit__``). + """ + target = path if path is not None else self.path + if not target: + return + data = self._read_bytes() + if not data: + return + from filelock import FileLock + + parent = os.path.dirname(target) + if parent: + os.makedirs(parent, exist_ok=True) + tmp = target + ".tmp" + with FileLock(target + ".lock").acquire(timeout=_FILELOCK_TIMEOUT_S): + with open(tmp, "wb") as f: + f.write(data) + shutil.move(tmp, target) + logger.debug(f"Saved runtime cache to {target} ({len(data)} bytes)") + + def __del__(self) -> None: + # Best-effort autosave for engine-implicit handles. The CM disables + # this (``autosave_on_del=False``) since it saves on ``__exit__``; + # user-constructed handles default to disabled so save timing stays + # under the user's control. ``__del__`` can fire during interpreter + # shutdown when imports/filesystem ops fail unpredictably -- swallow. + if self.autosave_on_del and self.path: + try: + self.save() + except Exception: + pass + + def __eq__(self, other: object) -> bool: + # Identity equality so passing the same handle twice through + # update_runtime_settings is a fast-path no-op. 
+ return self is other + + def __hash__(self) -> int: + return id(self) + + def __repr__(self) -> str: + return ( + f"RuntimeCacheHandle(path={self.path!r}, " + f"autosave_on_del={self.autosave_on_del}, " + f"materialized={self._cache is not None or self._torchbind is not None})" + ) + + +class _RuntimeCacheContextManager: + """``with runtime_cache(target, path) as rc:`` -- shared cache CM. + + Bootstraps an ``IRuntimeCache`` from one of the engines under target, + wraps it in a :class:`RuntimeCacheHandle`, loads from disk, attaches to + all engines under all listed targets for the duration of the block, and + saves on exit if ``autosave``. + """ + + def __init__( + self, + target_or_targets: Union["torch.nn.Module", Sequence["torch.nn.Module"]], + path: str = "", + autosave: bool = True, + ) -> None: + if isinstance(target_or_targets, torch.nn.Module): + self._targets: tuple[torch.nn.Module, ...] = (target_or_targets,) + else: + self._targets = tuple(target_or_targets) + self.path = path + self.autosave = autosave + self.handle: Optional[RuntimeCacheHandle] = None + self._inner_cm: Any = None + + def __enter__(self) -> RuntimeCacheHandle: + # Defer imports to avoid a circular dependency: + # _runtime_cache -> _runtime_config -> _TorchTensorRTModule -> (indirect) _runtime_cache. + from torch_tensorrt.dynamo.runtime._TorchTensorRTModule import ( + TorchTensorRTModule, + ) + from torch_tensorrt.dynamo.runtime._TRTEngine import TRTEngine + from torch_tensorrt.runtime._runtime_config import runtime_config + + # 1. Find a bootstrap engine to materialize the cache from. Try all + # targets in order; first TRT submodule wins. 
+ bootstrap_engine = None + for target in self._targets: + for _, mod in target.named_modules(): + if isinstance(mod, TorchTensorRTModule) and isinstance( + mod.engine, TRTEngine + ): + bootstrap_engine = mod.engine + break + if bootstrap_engine is not None: + break + if bootstrap_engine is None: + raise RuntimeError( + "runtime_cache() requires at least one TorchTensorRTModule under " + "the target(s) using the Python TRT runtime." + ) + + # 2. Materialize the cache via the bootstrap engine's runtime_config. + # (The cache returned is free-floating; ownership transfers to the handle.) + # ``autosave_on_del=False`` because the CM saves explicitly on ``__exit__``; + # letting ``__del__`` also save would double-write when ``rc`` falls out + # of scope after the with-block. + cache_obj = bootstrap_engine.runtime_config.create_runtime_cache() + self.handle = RuntimeCacheHandle( + cache=cache_obj, path=self.path, autosave_on_del=False + ) + + # 3. Load from disk if path was given. + self.handle.load() + + # 4. Apply the handle to ALL engines under target(s) via runtime_config CM. + self._inner_cm = runtime_config(list(self._targets), runtime_cache=self.handle) + self._inner_cm.__enter__() + return self.handle + + def __exit__(self, *args: Any) -> None: + if self._inner_cm is not None: + self._inner_cm.__exit__(*args) + if self.autosave and self.path and self.handle is not None: + self.handle.save() + + +def runtime_cache( + target_or_targets: Union["torch.nn.Module", Sequence["torch.nn.Module"]], + path: str = "", + autosave: bool = True, +) -> _RuntimeCacheContextManager: + """Context manager that attaches a shared runtime cache to all engines + under ``target_or_targets`` for the duration of the ``with`` block. + + Yields the :class:`RuntimeCacheHandle` for inspection or explicit + ``handle.save()`` calls (e.g., for mid-block checkpointing -- caller is + responsible for ``torch.cuda.synchronize()`` first). 
+ """ + return _RuntimeCacheContextManager(target_or_targets, path, autosave) + + +# When the C++ Torch-TensorRT runtime is loaded, we ALSO expose +# ``torch.classes.tensorrt.RuntimeCacheHandle`` as the canonical +# cross-language handle. The Python class above is the user-facing API; +# at dispatch time the Python module converts to/from the torchbind class as +# needed (see ``TorchTensorRTModule.set_runtime_settings``). +def _to_torchbind_handle( + rc: Union[None, str, "RuntimeCacheHandle", Any], +) -> Any: + """Convert a Python-side ``runtime_cache`` value to a torchbind handle + suitable for ``torch.classes.tensorrt.Engine.update_runtime_settings(...)``. + + Returns ``None`` if no runtime cache is requested. Raises if the C++ + runtime isn't loaded (caller shouldn't dispatch to a C++ engine in that + case anyway). Already-torchbind handles (``torch.ScriptObject``) are passed + through unchanged so callers can pre-stash a handle on the module and + share it across dispatch calls. + """ + if rc is None: + return None + if not torch_tensorrt.ENABLED_FEATURES.torch_tensorrt_runtime: + raise RuntimeError( + "torch_tensorrt C++ runtime is not available; cannot construct " + "torch.classes.tensorrt.RuntimeCacheHandle" + ) + if isinstance(rc, torch.ScriptObject): + return rc + path = rc if isinstance(rc, str) else rc.path + return torch.classes.tensorrt.RuntimeCacheHandle(path) diff --git a/py/torch_tensorrt/runtime/_runtime_config.py b/py/torch_tensorrt/runtime/_runtime_config.py new file mode 100644 index 0000000000..d18e3e346f --- /dev/null +++ b/py/torch_tensorrt/runtime/_runtime_config.py @@ -0,0 +1,447 @@ +"""Runtime settings + the TRTRuntimeConfig shim + the ``runtime_config`` CM. + +This module groups three closely related concepts together: + +* :class:`RuntimeSettings` -- the user-facing, frozen dataclass of runtime-only + knobs sampled at IExecutionContext creation (cuda_graph_strategy, + dynamic_shapes_kernel_specialization_strategy, runtime_cache). 
+* :class:`TRTRuntimeConfig` -- the engine-internal shim that owns a + :class:`RuntimeSettings` and the live ``trt.IRuntimeConfig`` derived from + it. All ``ENABLED_FEATURES.tensorrt_rtx`` branching lives inside; callers in + ``_TRTEngine`` and ``_TorchTensorRTModule`` stay uniform. Mirrors the C++ + ``torch_tensorrt::core::runtime::TRTRuntimeConfig`` struct. +* :func:`runtime_config` -- the runtime-mode context manager that toggles + settings on every TRT submodule under a target for the duration of a + ``with`` block. + +Three ways to use ``RuntimeSettings``: + +1. **Compile-time hint** (recommended fast path) -- prime the engine with the + desired initial values so no CM enter/exit recreate is needed:: + + compiled = torchtrt.compile( + model, ..., + runtime_settings=RuntimeSettings(cuda_graph_strategy="whole_graph_capture"), + ) + +2. **Runtime context manager** -- toggle settings inside a ``with`` block. +3. **Programmatic** -- call ``module.set_runtime_settings(rs)`` directly. + +``RuntimeSettings`` is intentionally NOT part of ``CompilationSettings`` and is +NOT serialized into the engine tuple (per GitHub pytorch/TensorRT#4310). It's +purely an in-memory initialization parameter / runtime override state. +""" + +from __future__ import annotations + +import dataclasses +import logging +import warnings +from dataclasses import dataclass, field +from typing import TYPE_CHECKING, Any, Dict, Optional, Sequence, Tuple, Union + +import torch +from torch_tensorrt._features import ENABLED_FEATURES +from torch_tensorrt.dynamo._defaults import ( + CUDA_GRAPH_STRATEGY, + DYNAMIC_SHAPES_KERNEL_SPECIALIZATION_STRATEGY, + RUNTIME_CACHE_PATH, +) + +# Single-shot guard for the non-RTX construction warning -- emit once per +# process, not once per RuntimeSettings instance. 
+_NON_RTX_WARNING_EMITTED = False + +if TYPE_CHECKING: + from torch_tensorrt.runtime._runtime_cache import RuntimeCacheHandle + +logger = logging.getLogger(__name__) + +# Validation maps for the dataclass post-init. The TRT enum mappings live +# inside ``TRTRuntimeConfig._apply_settings`` so non-RTX imports don't trip +# over RTX-only enum symbols at module load time. +_DYNAMIC_SHAPES_KERNEL_STRATEGY_MAP: Dict[str, int] = { + "lazy": 0, + "eager": 1, + "none": 2, +} +_CUDA_GRAPH_STRATEGY_MAP: Dict[str, int] = { + "disabled": 0, + "whole_graph_capture": 1, +} + + +@dataclass(frozen=True) +class RuntimeSettings: + """Per-engine runtime-only knobs sampled at IExecutionContext creation. + + Fields: + dynamic_shapes_kernel_specialization_strategy: ``"lazy" | "eager" | "none"``. + TRT-RTX-only; no-op on standard TensorRT. + cuda_graph_strategy: ``"disabled" | "whole_graph_capture"``. TRT-RTX-only. + runtime_cache: ``None``, a disk path string, or a + :class:`RuntimeCacheHandle`. ``None`` ⇒ no cache attached. A + string is honored at engine construction time and primes a + per-engine disk-backed cache (engine owns the implicit handle and + it saves on ``__del__``). A handle is the shared-cache form, + typically obtained from :func:`torch_tensorrt.runtime.runtime_cache` + -- multiple engines attaching the same handle share one + ``IRuntimeCache``. + + Equality compares all fields; for ``runtime_cache``, handle equality is + by identity (same handle ⇒ same cache). 
+ """ + + dynamic_shapes_kernel_specialization_strategy: str = ( + DYNAMIC_SHAPES_KERNEL_SPECIALIZATION_STRATEGY + ) + cuda_graph_strategy: str = CUDA_GRAPH_STRATEGY + runtime_cache: Optional[Union[str, "RuntimeCacheHandle"]] = field( # noqa: F821 + default_factory=lambda: RUNTIME_CACHE_PATH + ) + + def __post_init__(self) -> None: + if ( + self.dynamic_shapes_kernel_specialization_strategy + not in _DYNAMIC_SHAPES_KERNEL_STRATEGY_MAP + ): + raise ValueError( + "Invalid dynamic_shapes_kernel_specialization_strategy: " + f"{self.dynamic_shapes_kernel_specialization_strategy!r}. " + f"Expected one of {list(_DYNAMIC_SHAPES_KERNEL_STRATEGY_MAP)}." + ) + if self.cuda_graph_strategy not in _CUDA_GRAPH_STRATEGY_MAP: + raise ValueError( + f"Invalid cuda_graph_strategy: {self.cuda_graph_strategy!r}. " + f"Expected one of {list(_CUDA_GRAPH_STRATEGY_MAP)}." + ) + # RuntimeSettings only takes effect on TRT-RTX builds. Warn once per + # process on regular TRT so users don't silently expect cache / + # strategy plumbing to do anything. + global _NON_RTX_WARNING_EMITTED + if not ENABLED_FEATURES.tensorrt_rtx and not _NON_RTX_WARNING_EMITTED: + warnings.warn( + "RuntimeSettings is only honored on TRT-RTX builds; " + "constructing it on regular TensorRT has no effect.", + UserWarning, + stacklevel=2, + ) + _NON_RTX_WARNING_EMITTED = True + + def merge(self, **overrides: Any) -> "RuntimeSettings": + """Return a new ``RuntimeSettings`` with ``overrides`` applied on top of self.""" + unknown = set(overrides) - {f.name for f in dataclasses.fields(self)} + if unknown: + raise TypeError( + f"Unknown RuntimeSettings field(s): {sorted(unknown)}. " + f"Valid fields: {[f.name for f in dataclasses.fields(self)]}." + ) + return dataclasses.replace(self, **overrides) + + +class TRTRuntimeConfig: + """Owns ``RuntimeSettings`` + the live ``trt.IRuntimeConfig`` for one engine. 
+ + All TRT-RTX feature-flag branching lives in this class -- callers in + ``_TRTEngine`` and ``_TorchTensorRTModule`` stay uniform. Mirrors the C++ + ``TRTRuntimeConfig`` struct. + + On non-RTX builds, ``self._live`` stays ``None`` and the strategy / + runtime-cache plumbing is short-circuited; ``create_execution_context`` + picks the legacy ``cuda_engine.create_execution_context(strategy)`` + overload. + """ + + def __init__(self, settings: Optional[RuntimeSettings] = None) -> None: + self._settings: RuntimeSettings = settings or RuntimeSettings() + # Live trt.IRuntimeConfig (RTX) or None (non-RTX / pre-init). + self._live: Any = None + # Engine-implicit RuntimeCacheHandle when settings.runtime_cache is a + # string path; None when external handle / no cache. + self._implicit_cache_handle: Any = None + + @property + def settings(self) -> RuntimeSettings: + """The current ``RuntimeSettings``. Mutate only via :meth:`set_settings`.""" + return self._settings + + @property + def implicit_cache_handle(self) -> Any: + """The engine-implicit ``RuntimeCacheHandle`` if any, else None. + + Set when ``settings.runtime_cache`` is a path string. The handle's + ``__del__`` persists kernels JIT-compiled during the engine's lifetime + when ``autosave_on_del`` is True (the default for implicit handles). + """ + return self._implicit_cache_handle + + def set_settings(self, new: RuntimeSettings) -> bool: + """Apply ``new`` settings. Returns True iff the value actually changed. + + On change, the prior implicit handle (if any) is saved before being + replaced, the live ``IRuntimeConfig`` is invalidated, and callers + should recreate the ``IExecutionContext``. 
+ """ + if new == self._settings: + return False + prior = self._implicit_cache_handle + if prior is not None: + try: + prior.save() + except Exception as e: # never raise from setting swap + logger.warning(f"Failed to save implicit runtime cache on swap: {e}") + self._settings = new + self._live = None + self._implicit_cache_handle = None + return True + + def ensure_initialized(self, cuda_engine: Any) -> None: + """Lazy-create the live ``trt.IRuntimeConfig`` and apply settings. + + No-op on non-TRT-RTX builds, where there is no ``IRuntimeConfig`` to + configure. + """ + if not ENABLED_FEATURES.tensorrt_rtx: + return + if self._live is not None: + return + self._live = cuda_engine.create_runtime_config() + self._apply_settings() + + def reset(self) -> None: + """Drop the live ``IRuntimeConfig``; the next ``ensure_initialized`` rebuilds.""" + self._live = None + self._implicit_cache_handle = None + + def create_execution_context( + self, + cuda_engine: Any, + allocation_strategy: Any, + ) -> Any: + """Lazy-init + create a fresh ``IExecutionContext``. + + Picks the right ``cuda_engine.create_execution_context`` overload + (``IRuntimeConfig`` vs ``ExecutionContextAllocationStrategy``) so + callers stay free of any ``ENABLED_FEATURES.tensorrt_rtx`` branching. + """ + if ENABLED_FEATURES.tensorrt_rtx: + self.ensure_initialized(cuda_engine) + assert self._live is not None + self._live.set_execution_context_allocation_strategy(allocation_strategy) + return cuda_engine.create_execution_context(self._live) + return cuda_engine.create_execution_context(allocation_strategy) + + def uses_internal_capture(self, cudagraphs_enabled: bool) -> bool: + """Returns True if TRT-RTX owns capture/replay for the current settings. + + Caller should then bypass its own ``torch.cuda.CUDAGraph`` capture + around ``execute_async_v3``. Always False on non-RTX builds. 
+ """ + if not ENABLED_FEATURES.tensorrt_rtx: + return False + return self._settings.cuda_graph_strategy != "disabled" or cudagraphs_enabled + + def is_monolithic_capturable( + self, + has_dynamic_inputs: bool, + context: Any, + stream: Any, + ) -> bool: + """Returns True iff this engine can be safely wrapped by an outer monolithic capture. + + Mirrors C++ ``TRTRuntimeConfig::is_monolithic_capturable``. Non-RTX + builds always return True. + """ + if not ENABLED_FEATURES.tensorrt_rtx: + return True + if not context.is_stream_capturable(stream.cuda_stream): + return False + # "lazy" kernel specialization swaps specialized kernels mid-run when an + # input has a dynamic dimension; for static-shape engines the kernels + # are fixed at setup and the captured graph stays valid. + return not ( + self._settings.dynamic_shapes_kernel_specialization_strategy == "lazy" + and has_dynamic_inputs + ) + + # ------------------------------------------------------------------ + # Internal: apply self._settings to self._live + # ------------------------------------------------------------------ + + def _apply_settings(self) -> None: + """Apply ``self._settings`` to the live ``trt.IRuntimeConfig``. + + Resolves ``runtime_cache``: + - ``None`` ⇒ no cache attached. + - ``str`` path ⇒ create an engine-implicit ``RuntimeCacheHandle`` here + and attach. Engine owns lifecycle; handle's ``__del__`` saves. + - ``RuntimeCacheHandle`` ⇒ external; caller owns lifecycle. + """ + # Deferred imports: trt is import-aliased to tensorrt_rtx on RTX builds, + # and _runtime_cache imports this module's RuntimeSettings. 
+ import tensorrt as trt + from torch_tensorrt.runtime._runtime_cache import RuntimeCacheHandle + + self._live.dynamic_shapes_kernel_specialization_strategy = ( + self._to_trt_ds_strategy(trt) + ) + logger.info( + "Dynamic shapes kernel specialization strategy: " + f"{self._settings.dynamic_shapes_kernel_specialization_strategy}" + ) + self._live.cuda_graph_strategy = self._to_trt_cg_strategy(trt) + logger.info(f"CUDA graph strategy: {self._settings.cuda_graph_strategy}") + + rc = self._settings.runtime_cache + if rc is None: + self._implicit_cache_handle = None + logger.debug( + "Runtime cache disabled (no RuntimeCacheHandle / path provided)." + ) + elif isinstance(rc, str): + cache = self._live.create_runtime_cache() + self._implicit_cache_handle = RuntimeCacheHandle( + cache=cache, path=rc, autosave_on_del=True + ) + try: + self._implicit_cache_handle.load() + except Exception as e: + logger.warning(f"Failed to load implicit runtime cache: {e}") + self._live.set_runtime_cache(cache) + else: + # External RuntimeCacheHandle. Caller owns lifecycle; the handle + # holds the IRuntimeCache reference. + cache = rc.ensure_cache(self._live) + self._implicit_cache_handle = None + self._live.set_runtime_cache(cache) + logger.info("TensorRT-RTX runtime config configured") + + def _to_trt_ds_strategy(self, trt: Any) -> Any: + return { + "lazy": trt.DynamicShapesKernelSpecializationStrategy.LAZY, + "eager": trt.DynamicShapesKernelSpecializationStrategy.EAGER, + "none": trt.DynamicShapesKernelSpecializationStrategy.NONE, + }[self._settings.dynamic_shapes_kernel_specialization_strategy] + + def _to_trt_cg_strategy(self, trt: Any) -> Any: + return { + "disabled": trt.CudaGraphStrategy.DISABLED, + "whole_graph_capture": trt.CudaGraphStrategy.WHOLE_GRAPH_CAPTURE, + }[self._settings.cuda_graph_strategy] + + +# --------------------------------------------------------------------------- +# runtime_config(...) 
CM and factory +# --------------------------------------------------------------------------- + + +class _RuntimeConfigContextManager: + """Pool CM that applies ``RuntimeSettings`` overrides to every TRT submodule. + + Walks ``named_modules()`` once on enter, snapshots prior settings per + engine, calls ``mod.set_runtime_settings(merged)`` per engine. Restores on + exit using the same snapshot dict. + + Yields the target (or tuple of targets) so users can write + ``with runtime_config(model, ...) as m: m(*inputs)``. + """ + + def __init__( + self, + target_or_targets: Union["torch.nn.Module", Sequence["torch.nn.Module"]], + **overrides: Any, + ) -> None: + # Validate keys against RuntimeSettings field names (typo => raise here, + # not silently no-op later). + valid_fields = {f.name for f in dataclasses.fields(RuntimeSettings)} + unknown = set(overrides) - valid_fields + if unknown: + raise TypeError( + f"Unknown RuntimeSettings field(s): {sorted(unknown)}. " + f"Valid fields: {sorted(valid_fields)}." + ) + + if isinstance(target_or_targets, torch.nn.Module): + self._targets: Tuple[torch.nn.Module, ...] = (target_or_targets,) + self._yield_tuple = False + else: + self._targets = tuple(target_or_targets) + self._yield_tuple = True + self._overrides = overrides + # Engine ↔ prior RuntimeSettings snapshot; populated on enter. + self._saved: Dict[Any, RuntimeSettings] = {} + + def __enter__(self) -> Union["torch.nn.Module", Tuple["torch.nn.Module", ...]]: + # Deferred import to avoid a circular dependency at module-load time. + from torch_tensorrt.dynamo.runtime._TorchTensorRTModule import ( + TorchTensorRTModule, + ) + + for target in self._targets: + for _, mod in target.named_modules(): + if isinstance(mod, TorchTensorRTModule) and mod.engine is not None: + current = mod.runtime_settings + if mod in self._saved: + # The same TRTModule appears under multiple targets in the + # list (or the tree contains a cycle). Don't snapshot twice. 
+ continue + self._saved[mod] = current + merged = current.merge(**self._overrides) + mod.set_runtime_settings(merged) + return self._targets if self._yield_tuple else self._targets[0] + + def __exit__(self, *args: Any) -> None: + for mod, prior in self._saved.items(): + mod.set_runtime_settings(prior) + + +def runtime_config( + target_or_targets: Union["torch.nn.Module", Sequence["torch.nn.Module"]], + **overrides: Any, +) -> _RuntimeConfigContextManager: + """Context manager that applies ``RuntimeSettings`` overrides to all TRT + engines under ``target_or_targets`` for the duration of the ``with`` block. + + Accepts the same kwargs as :class:`RuntimeSettings` fields. The pool + semantics collapse N knob changes into one ``update_runtime_settings`` call + per engine, which means exactly two ``IExecutionContext`` recreates per + engine (one on enter, one on exit) regardless of how many overrides are + passed. + + Yields the target module (single form) or a tuple of targets (list form), + by-reference -- same object the caller passed in. + """ + return _RuntimeConfigContextManager(target_or_targets, **overrides) + + +# --------------------------------------------------------------------------- +# Sugar wrappers for the two strategy knobs +# --------------------------------------------------------------------------- + + +def set_dynamic_shapes_kernel_strategy( + target_or_targets: Union["torch.nn.Module", Sequence["torch.nn.Module"]], + strategy: str, +) -> _RuntimeConfigContextManager: + """Context manager that sets the dynamic-shapes kernel specialization + strategy on all TRT engines under ``target_or_targets``. + + Accepts ``"lazy"``, ``"eager"``, or ``"none"``. Delegates to + :func:`runtime_config`. 
+ """ + return runtime_config( + target_or_targets, dynamic_shapes_kernel_specialization_strategy=strategy + ) + + +def set_cuda_graph_strategy( + target_or_targets: Union["torch.nn.Module", Sequence["torch.nn.Module"]], + strategy: str, +) -> _RuntimeConfigContextManager: + """Context manager that sets the cuda-graph strategy on all TRT engines + under ``target_or_targets``. + + Accepts ``"disabled"`` or ``"whole_graph_capture"``. Delegates to + :func:`runtime_config`. + """ + return runtime_config(target_or_targets, cuda_graph_strategy=strategy) diff --git a/tests/py/dynamo/runtime/test_000_runtime_cache.py b/tests/py/dynamo/runtime/test_000_runtime_cache.py index 2e9855b9a4..6dd98f630a 100644 --- a/tests/py/dynamo/runtime/test_000_runtime_cache.py +++ b/tests/py/dynamo/runtime/test_000_runtime_cache.py @@ -9,8 +9,9 @@ from parameterized import parameterized from torch.testing._internal.common_utils import TestCase, run_tests from torch_tensorrt._features import ENABLED_FEATURES -from torch_tensorrt.dynamo._defaults import RUNTIME_CACHE_PATH, TIMING_CACHE_PATH +from torch_tensorrt.dynamo._defaults import TIMING_CACHE_PATH from torch_tensorrt.dynamo.utils import COSINE_THRESHOLD, cosine_similarity +from torch_tensorrt.runtime import RuntimeSettings, runtime_cache class SimpleModel(torch.nn.Module): @@ -28,28 +29,35 @@ def forward(self, x): def _fresh_conv_model_and_inputs(seed=0): - """Deterministic ConvModel + input pair for end-to-end cache tests on either runtime.""" torch.manual_seed(seed) return ConvModel().eval().cuda(), [torch.randn(2, 3, 16, 16).cuda()] def _compile(model, inputs, *, use_python_runtime, runtime_cache_path=None): - """Compile ``model`` through either runtime. 
Returns the compiled module.""" - kwargs = { - "ir": "dynamo", - "inputs": inputs, - "use_python_runtime": use_python_runtime, - "min_block_size": 1, - } - if runtime_cache_path is not None: - kwargs["runtime_cache_path"] = runtime_cache_path - compiled = torchtrt.compile(model, **kwargs) + """Compile ``model`` through either runtime. + + ``runtime_cache_path``, when supplied, is threaded as a compile-time hint via + ``runtime_settings=RuntimeSettings(runtime_cache=path)`` (per-engine cache). + """ + rs = ( + RuntimeSettings(runtime_cache=runtime_cache_path) + if runtime_cache_path is not None + else None + ) + compiled = torchtrt.compile( + model, + ir="dynamo", + inputs=inputs, + use_python_runtime=use_python_runtime, + min_block_size=1, + runtime_settings=rs, + ) torch._dynamo.reset() return compiled def _compile_simple(runtime_cache_path=None): - """Compile SimpleModel on the Python runtime (used by introspection setup tests).""" + """Compile SimpleModel on the Python runtime (used by introspection tests).""" model = SimpleModel().eval().cuda() inputs = [torch.randn(2, 3).cuda()] return ( @@ -64,13 +72,6 @@ def _compile_simple(runtime_cache_path=None): def _find_python_trt_engine(compiled): - """Return the Python ``TRTEngine`` instance from a compiled module, if any. - - The C++ and Python runtimes are now both driven through ``TorchTensorRTModule`` - (``use_python_runtime`` selects which backend is constructed). Tests that target - Python-runtime introspection use this helper; C++-runtime tests rely on - externally observable behavior (cache file on disk, inference correctness). - """ from torch_tensorrt.dynamo.runtime._TorchTensorRTModule import TorchTensorRTModule from torch_tensorrt.dynamo.runtime._TRTEngine import TRTEngine @@ -80,8 +81,6 @@ def _find_python_trt_engine(compiled): return None -# Parameterize end-to-end cache persistence tests over both runtime paths. The C++ -# variant is skipped inside the test body when the C++ runtime is not available. 
_RUNTIMES = [("python", True), ("cpp", False)] @@ -95,303 +94,225 @@ def _skip_if_cpp_unavailable(testcase, use_python_runtime): "Runtime cache is only available with TensorRT-RTX", ) class TestRuntimeCacheSetup(TestCase): - """Tests that runtime config and cache are correctly created for RTX.""" + """Tests that runtime config and per-engine cache are correctly created for RTX.""" def test_runtime_config_created(self): compiled, _ = _compile_simple() engine = _find_python_trt_engine(compiled) - self.assertIsNotNone(engine, "No Python TRTEngine found in compiled model") - self.assertIsNotNone( - engine.runtime_config, "runtime_config should be set for RTX" - ) - self.assertIsNotNone( - engine.runtime_cache, "runtime_cache should be set for RTX" - ) + self.assertIsNotNone(engine) + self.assertIsNotNone(engine.runtime_config) def test_context_created_successfully(self): - compiled, inputs = _compile_simple() + compiled, _ = _compile_simple() engine = _find_python_trt_engine(compiled) - self.assertIsNotNone(engine.context, "execution context should be created") - # Verify inference works - output = compiled(*[inp.clone() for inp in inputs]) - self.assertEqual(output.shape, inputs[0].shape) + self.assertIsNotNone(engine.context) + + def test_default_uses_temp_path_implicit_handle(self): + """Default RuntimeSettings points runtime_cache at the per-user temp file + (see _defaults.RUNTIME_CACHE_PATH); the engine creates an implicit handle.""" + from torch_tensorrt.dynamo._defaults import RUNTIME_CACHE_PATH - def test_runtime_cache_path_default(self): compiled, _ = _compile_simple() engine = _find_python_trt_engine(compiled) - self.assertEqual(engine.settings.runtime_cache_path, RUNTIME_CACHE_PATH) - - def test_runtime_cache_path_custom(self): - cache_dir = tempfile.mkdtemp() - try: - custom_path = os.path.join(cache_dir, "my_cache.bin") - compiled, _ = _compile_simple(runtime_cache_path=custom_path) + self.assertIsNotNone(engine._implicit_cache_handle) + 
self.assertEqual(engine._implicit_cache_handle.path, RUNTIME_CACHE_PATH) + + def test_implicit_cache_handle_for_path_hint(self): + """Passing a path string in RuntimeSettings.runtime_cache creates an implicit handle.""" + with tempfile.TemporaryDirectory() as tmp: + path = os.path.join(tmp, "rc.bin") + compiled, _ = _compile_simple(runtime_cache_path=path) engine = _find_python_trt_engine(compiled) - self.assertEqual(engine.settings.runtime_cache_path, custom_path) - finally: - shutil.rmtree(cache_dir, ignore_errors=True) + self.assertIsNotNone(engine._implicit_cache_handle) + self.assertEqual(engine._implicit_cache_handle.path, path) @unittest.skipIf( not ENABLED_FEATURES.tensorrt_rtx, - "Runtime cache is only available with TensorRT-RTX", + "Runtime cache persistence is RTX-only", ) class TestRuntimeCachePersistence(TestCase): - """Load-on-setup / save-on-destructor contract, exercised on both runtimes.""" - - def setUp(self): - self.cache_dir = tempfile.mkdtemp() - self.cache_path = os.path.join(self.cache_dir, "runtime_cache.bin") - - def tearDown(self): - shutil.rmtree(self.cache_dir, ignore_errors=True) + """End-to-end: compile with a cache path, infer, destroy, reload, infer again.""" @parameterized.expand(_RUNTIMES) def test_cache_saved_on_del(self, _name, use_python_runtime): _skip_if_cpp_unavailable(self, use_python_runtime) - model, inputs = _fresh_conv_model_and_inputs() - compiled = _compile( - model, - inputs, - use_python_runtime=use_python_runtime, - runtime_cache_path=self.cache_path, - ) - _ = compiled(*[inp.clone() for inp in inputs]) - self.assertFalse( - os.path.isfile(self.cache_path), - "Cache should not exist before module cleanup", - ) - del compiled - gc.collect() - self.assertTrue( - os.path.isfile(self.cache_path), - "Cache file should be created after module cleanup", - ) - - @parameterized.expand(_RUNTIMES) - def test_cache_file_nonempty(self, _name, use_python_runtime): - _skip_if_cpp_unavailable(self, use_python_runtime) - model, 
inputs = _fresh_conv_model_and_inputs() - compiled = _compile( - model, - inputs, - use_python_runtime=use_python_runtime, - runtime_cache_path=self.cache_path, - ) - _ = compiled(*[inp.clone() for inp in inputs]) - del compiled - gc.collect() - self.assertGreater( - os.path.getsize(self.cache_path), - 0, - "Cache file should have nonzero size", - ) - - @parameterized.expand(_RUNTIMES) - def test_cache_roundtrip(self, _name, use_python_runtime): - """Populate + save, then recompile and confirm correctness against eager output.""" - _skip_if_cpp_unavailable(self, use_python_runtime) - model, inputs = _fresh_conv_model_and_inputs() - with torch.no_grad(): - ref_output = model(*inputs) - - compiled1 = _compile( - model, - inputs, - use_python_runtime=use_python_runtime, - runtime_cache_path=self.cache_path, - ) - out1 = compiled1(*[inp.clone() for inp in inputs]) - self.assertGreater( - cosine_similarity(ref_output, out1), - COSINE_THRESHOLD, - "First compiled output should match eager", - ) - del compiled1 - gc.collect() - self.assertTrue(os.path.isfile(self.cache_path)) - - compiled2 = _compile( - model, - inputs, - use_python_runtime=use_python_runtime, - runtime_cache_path=self.cache_path, - ) - out2 = compiled2(*[inp.clone() for inp in inputs]) - self.assertGreater( - cosine_similarity(ref_output, out2), - COSINE_THRESHOLD, - "Second compiled output (warm cache) should still match eager", - ) + with tempfile.TemporaryDirectory() as tmp: + path = os.path.join(tmp, "rc.bin") + model, inputs = _fresh_conv_model_and_inputs(seed=42) + compiled = _compile( + model, + inputs, + use_python_runtime=use_python_runtime, + runtime_cache_path=path, + ) + _ = compiled(*inputs) + del compiled + gc.collect() + self.assertTrue( + os.path.exists(path), + f"Implicit cache handle should have saved to {path} on engine __del__", + ) @parameterized.expand(_RUNTIMES) - def test_save_creates_directory(self, _name, use_python_runtime): + def 
test_set_runtime_settings_saves_prior_cache_on_swap( + self, _name, use_python_runtime + ): + """Re-pointing ``runtime_cache`` via ``set_runtime_settings`` must save + the prior implicit cache before swapping. Mirrors the explicit-save + semantic in :py:meth:`TRTRuntimeConfig.set_settings` for the Python + runtime; the C++ runtime path replicates it via + :py:meth:`TorchTensorRTModule._materialize_cpp_implicit_handle`. + """ _skip_if_cpp_unavailable(self, use_python_runtime) - nested_path = os.path.join(self.cache_dir, "a", "b", "c", "runtime_cache.bin") - model, inputs = _fresh_conv_model_and_inputs() - compiled = _compile( - model, - inputs, - use_python_runtime=use_python_runtime, - runtime_cache_path=nested_path, - ) - _ = compiled(*[inp.clone() for inp in inputs]) - del compiled - gc.collect() - self.assertTrue( - os.path.isfile(nested_path), - "Save should create intermediate directories", + from torch_tensorrt.dynamo.runtime._TorchTensorRTModule import ( + TorchTensorRTModule, ) - -@unittest.skipIf( - not ENABLED_FEATURES.tensorrt_rtx, - "Runtime cache is only available with TensorRT-RTX", -) -class TestRuntimeCacheConcurrency(TestCase): - """Tests that file locking works for concurrent access.""" - - def setUp(self): - self.cache_dir = tempfile.mkdtemp() - self.cache_path = os.path.join(self.cache_dir, "runtime_cache.bin") - - def tearDown(self): - shutil.rmtree(self.cache_dir, ignore_errors=True) - - def test_filelock_works(self): - """Verify that filelock can be acquired on the cache path after save.""" - compiled, inputs = _compile_simple(runtime_cache_path=self.cache_path) - _ = compiled(*[inp.clone() for inp in inputs]) - del compiled - gc.collect() - self.assertTrue(os.path.isfile(self.cache_path)) - # Verify we can acquire a lock on the same path (no deadlock) - from filelock import FileLock - - lock = FileLock(self.cache_path + ".lock") - with lock.acquire(timeout=5): - data = open(self.cache_path, "rb").read() - self.assertGreater(len(data), 0) - - 
def test_sequential_save_load(self): - """Two modules saving and loading from the same path should not corrupt data.""" - # First module saves - compiled1, inputs = _compile_simple(runtime_cache_path=self.cache_path) - _ = compiled1(*[inp.clone() for inp in inputs]) - del compiled1 - gc.collect() - size1 = os.path.getsize(self.cache_path) - - # Second module saves (overwrites) - compiled2, inputs = _compile_simple(runtime_cache_path=self.cache_path) - _ = compiled2(*[inp.clone() for inp in inputs]) - del compiled2 - gc.collect() - size2 = os.path.getsize(self.cache_path) - - self.assertGreater(size1, 0) - self.assertGreater(size2, 0) + with tempfile.TemporaryDirectory() as tmp: + path_a = os.path.join(tmp, "cache_A.bin") + path_b = os.path.join(tmp, "cache_B.bin") + model, inputs = _fresh_conv_model_and_inputs(seed=42) + compiled = _compile( + model, + inputs, + use_python_runtime=use_python_runtime, + runtime_cache_path=path_a, + ) + _ = compiled(*inputs) + # Sanity: no file written yet (nothing has saved). + self.assertFalse(os.path.exists(path_a)) + + # Walk to the inner TorchTensorRTModule(s) and swap the cache path + # directly -- the outer GraphModule doesn't carry `set_runtime_settings`, + # and we want a *permanent* swap (the runtime_config CM restores on + # exit, which would mask the save-on-swap signal we're after). The + # walk is wrapped in a helper so the loop variable doesn't outlive + # the call and keep the inner module alive past ``del compiled``. 
+ def _swap_all(target: torch.nn.Module, new_rs: RuntimeSettings) -> int: + count = 0 + for _, mod in target.named_modules(): + if isinstance(mod, TorchTensorRTModule): + mod.set_runtime_settings(new_rs) + count += 1 + return count + + swapped = _swap_all(compiled, RuntimeSettings(runtime_cache=path_b)) + self.assertGreater(swapped, 0, "Expected at least one TorchTensorRTModule") + self.assertTrue( + os.path.exists(path_a), + f"Prior implicit cache should have been saved to {path_a} " + "synchronously on set_runtime_settings swap", + ) + _ = compiled(*inputs) + del compiled + gc.collect() + self.assertTrue( + os.path.exists(path_b), + f"New implicit cache should have saved to {path_b} on engine " + "__del__ after the swap", + ) @unittest.skipIf( not ENABLED_FEATURES.tensorrt_rtx, - "Timing cache skip is only relevant for TensorRT-RTX", + "runtime_cache CM is RTX-only", ) -class TestTimingCacheSkipped(TestCase): - """Tests that timing cache is correctly skipped for RTX builds.""" +class TestRuntimeCacheContextManager(TestCase): + """Tests for the runtime_cache(target, path) shared-cache CM.""" - def setUp(self): - # Clean up any pre-existing timing cache - if os.path.isfile(TIMING_CACHE_PATH): - os.remove(TIMING_CACHE_PATH) - - def test_no_timing_cache_file(self): + def test_with_cache_loads_and_saves(self): compiled, inputs = _compile_simple() - _ = compiled(*[inp.clone() for inp in inputs]) - self.assertFalse( - os.path.isfile(TIMING_CACHE_PATH), - "Timing cache should NOT be created for RTX builds", - ) - - def test_timing_cache_skip_logged(self): - with self.assertLogs( - "torch_tensorrt.dynamo.conversion._TRTInterpreter", level="INFO" - ) as cm: - compiled, inputs = _compile_simple() - _ = compiled(*[inp.clone() for inp in inputs]) - self.assertTrue( - any("Skipping timing cache" in msg for msg in cm.output), - f"Expected 'Skipping timing cache' log message, got: {cm.output}", - ) - - -@unittest.skipIf( - ENABLED_FEATURES.tensorrt_rtx, - "This test verifies 
standard TRT behavior (non-RTX)", -) -class TestNonRTXUnchanged(TestCase): - """Tests that standard TRT behavior is unaffected by the runtime cache changes.""" - - def test_no_runtime_config_for_standard_trt(self): - compiled, _ = _compile_simple() - engine = _find_python_trt_engine(compiled) - if engine is not None: - # The TRT-RTX runtime cache machinery is exposed via the - # ``runtime_config`` / ``runtime_cache`` attributes on the Python - # engine. On non-RTX builds neither should be populated. - self.assertIsNone( - engine.runtime_config, - "runtime_config should be None for standard TRT", - ) - self.assertIsNone( - engine.runtime_cache, - "runtime_cache should be None for standard TRT", - ) - - def test_timing_cache_still_created(self): - # Clean up any pre-existing timing cache - if os.path.isfile(TIMING_CACHE_PATH): - os.remove(TIMING_CACHE_PATH) + with tempfile.TemporaryDirectory() as tmp: + path = os.path.join(tmp, "shared.bin") + with runtime_cache(compiled, path) as rc: + self.assertIsNotNone(rc) + self.assertEqual(rc.path, path) + _ = compiled(*inputs) + # autosave on exit + self.assertTrue(os.path.exists(path)) + + def test_with_cache_in_memory_only(self): + """path='' means in-memory only; no disk artifact after exit.""" compiled, inputs = _compile_simple() - _ = compiled(*[inp.clone() for inp in inputs]) - self.assertTrue( - os.path.isfile(TIMING_CACHE_PATH), - "Timing cache should still be created for standard TRT", - ) - - -@unittest.skipIf( - not ENABLED_FEATURES.torch_tensorrt_runtime, - "C++ runtime is not available", -) -class TestSerializationIndices(TestCase): - """The HAS_RUNTIME_CFG flag + TRTRuntimeConfig slots are present on both backends.""" - - def test_indices_match_python_layout(self): - from torch_tensorrt.dynamo.runtime._serialized_engine_layout import ( - CUDA_GRAPH_STRATEGY_IDX, - DYNAMIC_SHAPES_KERNEL_STRATEGY_IDX, - HAS_RUNTIME_CFG_IDX, - RUNTIME_CACHE_PATH_IDX, - SERIALIZATION_LEN, - ) - - 
self.assertEqual(int(torch.ops.tensorrt.SERIALIZATION_LEN()), SERIALIZATION_LEN) - self.assertEqual( - int(torch.ops.tensorrt.HAS_RUNTIME_CFG_IDX()), int(HAS_RUNTIME_CFG_IDX) - ) - self.assertEqual( - int(torch.ops.tensorrt.RUNTIME_CACHE_PATH_IDX()), - int(RUNTIME_CACHE_PATH_IDX), - ) - self.assertEqual( - int(torch.ops.tensorrt.DYNAMIC_SHAPES_KERNEL_STRATEGY_IDX()), - int(DYNAMIC_SHAPES_KERNEL_STRATEGY_IDX), - ) - self.assertEqual( - int(torch.ops.tensorrt.CUDA_GRAPH_STRATEGY_IDX()), - int(CUDA_GRAPH_STRATEGY_IDX), - ) + with tempfile.TemporaryDirectory() as tmp: + with runtime_cache(compiled, "") as rc: + self.assertEqual(rc.path, "") + _ = compiled(*inputs) + self.assertFalse(os.listdir(tmp), "No files should be created for path=''") + + def test_shared_cache_pointer_across_modules(self): + """Two modules sharing one runtime_cache handle reference the same IRuntimeCache.""" + compiled_a, inputs_a = _compile_simple() + compiled_b, inputs_b = _compile_simple() + eng_a = _find_python_trt_engine(compiled_a) + eng_b = _find_python_trt_engine(compiled_b) + with tempfile.TemporaryDirectory() as tmp: + path = os.path.join(tmp, "shared.bin") + with runtime_cache([compiled_a, compiled_b], path) as rc: + self.assertIs(eng_a.runtime_settings.runtime_cache, rc) + self.assertIs(eng_b.runtime_settings.runtime_cache, rc) + _ = compiled_a(*inputs_a) + _ = compiled_b(*inputs_b) + self.assertTrue(os.path.exists(path)) + + def test_runtime_cache_on_empty_target_raises(self): + """A target with no TRT submodules raises a clear error on enter.""" + empty = torch.nn.Linear(3, 3).cuda() + with tempfile.TemporaryDirectory() as tmp: + path = os.path.join(tmp, "rc.bin") + with self.assertRaises(RuntimeError): + with runtime_cache(empty, path): + pass + + def test_cm_does_not_double_save_on_rc_gc(self): + """CM yields handle with autosave_on_del=False; only one save happens. 
+ + Regression: if the CM-yielded handle had autosave_on_del=True, the + handle's __del__ would re-save after the CM's __exit__ already wrote + the file. We disable autosave_on_del on CM-created handles to avoid + that double-write. + """ + compiled, inputs = _compile_simple() + save_calls = [] + with tempfile.TemporaryDirectory() as tmp: + path = os.path.join(tmp, "shared.bin") + with runtime_cache(compiled, path) as rc: + # CM-created handle must not autosave on del (CM saves explicitly). + self.assertFalse(rc.autosave_on_del) + original_save = rc.save + + def _tracking_save(p=None): + save_calls.append(p) + return original_save(p) + + rc.save = _tracking_save # type: ignore[method-assign] + _ = compiled(*inputs) + # CM.__exit__ saves exactly once; rc going out of scope triggers + # __del__ but autosave_on_del is False, so no second save. + del rc + gc.collect() + self.assertEqual(len(save_calls), 1, f"Expected one save, got {save_calls}") + self.assertTrue(os.path.exists(path)) + + +class TestRuntimeCacheHandleAutosave(TestCase): + """Whitebox tests for RuntimeCacheHandle.autosave_on_del semantics.""" + + def test_user_built_handle_no_autosave_by_default(self): + """Hand-built handle defaults to autosave_on_del=False; nothing on GC.""" + from torch_tensorrt.runtime._runtime_cache import RuntimeCacheHandle + + with tempfile.TemporaryDirectory() as tmp: + path = os.path.join(tmp, "rc.bin") + handle = RuntimeCacheHandle(path=path) + self.assertFalse(handle.autosave_on_del) + del handle + gc.collect() + self.assertFalse( + os.path.exists(path), + "User-built handle with autosave_on_del=False should not save on GC", + ) if __name__ == "__main__": diff --git a/tests/py/dynamo/runtime/test_001_cuda_graph_strategy.py b/tests/py/dynamo/runtime/test_001_cuda_graph_strategy.py index 29d38d60b1..fb91200e5e 100644 --- a/tests/py/dynamo/runtime/test_001_cuda_graph_strategy.py +++ b/tests/py/dynamo/runtime/test_001_cuda_graph_strategy.py @@ -5,7 +5,7 @@ from parameterized 
import parameterized from torch.testing._internal.common_utils import TestCase, run_tests from torch_tensorrt._features import ENABLED_FEATURES -from torch_tensorrt.dynamo._settings import CompilationSettings +from torch_tensorrt.runtime import RuntimeSettings _RUNTIMES = [("python", True), ("cpp", False)] @@ -25,7 +25,7 @@ def forward(self, x): def _compile_conv(strategy, *, use_python_runtime): - """Compile CudaGraphConvModel through the selected runtime with the given strategy.""" + """Compile CudaGraphConvModel with the given cuda_graph_strategy hint.""" model = CudaGraphConvModel().eval().cuda() inputs = [torch.randn(2, 3, 16, 16).cuda()] compiled = torchtrt.compile( @@ -35,7 +35,7 @@ def _compile_conv(strategy, *, use_python_runtime): enabled_precisions={torch.float32}, use_python_runtime=use_python_runtime, min_block_size=1, - cuda_graph_strategy=strategy, + runtime_settings=RuntimeSettings(cuda_graph_strategy=strategy), ) torch._dynamo.reset() return compiled, inputs @@ -46,8 +46,8 @@ def forward(self, x): return torch.relu(x) + 1.0 -def _compile_simple(**extra_kwargs): - """Helper: compile SimpleModel with dynamic shapes and Python runtime.""" +def _compile_simple(*, runtime_settings=None): + """Compile SimpleModel with dynamic shapes and the Python runtime.""" model = SimpleModel().eval().cuda() inputs = [ torchtrt.Input( @@ -57,25 +57,20 @@ def _compile_simple(**extra_kwargs): dtype=torch.float32, ) ] - kwargs = { - "ir": "dynamo", - "inputs": inputs, - "enabled_precisions": {torch.float32}, - "use_python_runtime": True, - "min_block_size": 1, - } - kwargs.update(extra_kwargs) - compiled = torchtrt.compile(model, **kwargs) + compiled = torchtrt.compile( + model, + ir="dynamo", + inputs=inputs, + enabled_precisions={torch.float32}, + use_python_runtime=True, + min_block_size=1, + runtime_settings=runtime_settings, + ) torch._dynamo.reset() return compiled def _find_python_trt_engine(compiled): - """Walk the compiled graph module and return the Python 
``TRTEngine`` instance. - - The C++ and Python runtimes are now both driven through ``TorchTensorRTModule``; - ``use_python_runtime=True`` causes ``module.engine`` to be a Python ``TRTEngine``. - """ from torch_tensorrt.dynamo.runtime._TorchTensorRTModule import TorchTensorRTModule from torch_tensorrt.dynamo.runtime._TRTEngine import TRTEngine @@ -93,63 +88,47 @@ class TestCudaGraphStrategySetup(TestCase): """Tests that cuda_graph_strategy is correctly applied on TRT-RTX.""" def test_default_strategy_is_disabled(self): - import tensorrt as trt - compiled = _compile_simple() engine = _find_python_trt_engine(compiled) self.assertIsNotNone(engine, "No Python TRTEngine found") - self.assertIsNotNone( - engine.runtime_config, "runtime_config should be set for RTX" - ) - self.assertEqual( - engine.runtime_config.cuda_graph_strategy, - trt.CudaGraphStrategy.DISABLED, - ) + self.assertEqual(engine.runtime_settings.cuda_graph_strategy, "disabled") - def test_whole_graph_capture_strategy(self): - import tensorrt as trt - - compiled = _compile_simple(cuda_graph_strategy="whole_graph_capture") + def test_whole_graph_capture_strategy_via_compile_hint(self): + compiled = _compile_simple( + runtime_settings=RuntimeSettings(cuda_graph_strategy="whole_graph_capture"), + ) engine = _find_python_trt_engine(compiled) self.assertIsNotNone(engine) self.assertEqual( - engine.runtime_config.cuda_graph_strategy, - trt.CudaGraphStrategy.WHOLE_GRAPH_CAPTURE, + engine.runtime_settings.cuda_graph_strategy, "whole_graph_capture" ) - def test_rtx_native_flag_set(self): - compiled = _compile_simple(cuda_graph_strategy="whole_graph_capture") + def test_rtx_native_flag_tracks_strategy(self): + compiled = _compile_simple( + runtime_settings=RuntimeSettings(cuda_graph_strategy="whole_graph_capture") + ) engine = _find_python_trt_engine(compiled) - self.assertIsNotNone(engine) self.assertTrue(engine._rtx_native_cudagraphs) - def test_rtx_native_flag_disabled(self): - compiled = 
_compile_simple(cuda_graph_strategy="disabled") - engine = _find_python_trt_engine(compiled) - self.assertIsNotNone(engine) - self.assertFalse(engine._rtx_native_cudagraphs) + compiled2 = _compile_simple( + runtime_settings=RuntimeSettings(cuda_graph_strategy="disabled") + ) + engine2 = _find_python_trt_engine(compiled2) + self.assertFalse(engine2._rtx_native_cudagraphs) - def test_inference_with_each_strategy(self): - for strategy in ("disabled", "whole_graph_capture"): - with self.subTest(strategy=strategy): - compiled = _compile_simple(cuda_graph_strategy=strategy) - engine = _find_python_trt_engine(compiled) - self.assertIsNotNone( - engine.context, - f"Execution context should be created for {strategy}", - ) - for bs in (1, 2, 4): - output = compiled(torch.randn(bs, 3).cuda()) - self.assertEqual(output.shape, (bs, 3)) - - def test_setting_in_compilation_settings(self): - for strategy in ("disabled", "whole_graph_capture"): - settings = CompilationSettings(cuda_graph_strategy=strategy) - self.assertEqual(settings.cuda_graph_strategy, strategy) - - def test_default_compilation_settings(self): - settings = CompilationSettings() - self.assertEqual(settings.cuda_graph_strategy, "disabled") + def test_runtime_cm_overrides_strategy(self): + compiled = _compile_simple() + engine = _find_python_trt_engine(compiled) + self.assertEqual(engine.runtime_settings.cuda_graph_strategy, "disabled") + with torchtrt.runtime.set_cuda_graph_strategy(compiled, "whole_graph_capture"): + self.assertEqual( + engine.runtime_settings.cuda_graph_strategy, "whole_graph_capture" + ) + for bs in (1, 2, 4): + output = compiled(torch.randn(bs, 3).cuda()) + self.assertEqual(output.shape, (bs, 3)) + # Restored on exit. 
+ self.assertEqual(engine.runtime_settings.cuda_graph_strategy, "disabled") @unittest.skipIf( @@ -166,17 +145,14 @@ def tearDown(self): torchtrt.runtime.set_cudagraphs_mode(False) def test_rtx_native_bypasses_manual_capture(self): - compiled = _compile_simple(cuda_graph_strategy="whole_graph_capture") + compiled = _compile_simple( + runtime_settings=RuntimeSettings(cuda_graph_strategy="whole_graph_capture"), + ) engine = _find_python_trt_engine(compiled) self.assertIsNotNone(engine) - torchtrt.runtime.set_cudagraphs_mode(True) - - # Run inference a few times to ensure capture would have happened for _ in range(3): compiled(torch.randn(2, 3).cuda()) - - # Manual cudagraph should NOT have been recorded (RTX handles it natively) self.assertFalse( isinstance(engine.cudagraph, torch.cuda.CUDAGraph), "Manual CUDA graph should not be recorded when RTX native is active", @@ -185,28 +161,17 @@ def test_rtx_native_bypasses_manual_capture(self): def test_subgraph_mode_always_uses_rtx_native(self): """Even with cuda_graph_strategy=disabled, SUBGRAPH mode on RTX should override to RTX-native because manual capture is not safe.""" - compiled = _compile_simple(cuda_graph_strategy="disabled") + compiled = _compile_simple() engine = _find_python_trt_engine(compiled) self.assertIsNotNone(engine) - # Initially, _rtx_native_cudagraphs is False (disabled strategy) self.assertFalse(engine._rtx_native_cudagraphs) - torchtrt.runtime.set_cudagraphs_mode(True) - - # Run inference -- should trigger override to RTX-native for _ in range(3): compiled(torch.randn(2, 3).cuda()) - - # Should have been overridden to RTX-native self.assertTrue( engine._rtx_native_cudagraphs, "RTX-native should be enabled automatically in SUBGRAPH mode", ) - # Manual cudagraph should NOT have been recorded - self.assertFalse( - isinstance(engine.cudagraph, torch.cuda.CUDAGraph), - "Manual CUDA graph should not be recorded on RTX", - ) @unittest.skipIf( @@ -217,10 +182,11 @@ class 
TestMonolithicCapturability(TestCase): """Tests for _is_monolithic_capturable() and related logic.""" def test_lazy_strategy_not_monolithic_capturable(self): - """Lazy kernel specialization makes monolithic capture unsafe.""" compiled = _compile_simple( - cuda_graph_strategy="disabled", - dynamic_shapes_kernel_specialization_strategy="lazy", + runtime_settings=RuntimeSettings( + cuda_graph_strategy="disabled", + dynamic_shapes_kernel_specialization_strategy="lazy", + ), ) engine = _find_python_trt_engine(compiled) self.assertIsNotNone(engine) @@ -228,163 +194,19 @@ def test_lazy_strategy_not_monolithic_capturable(self): self.assertFalse(engine._is_monolithic_capturable(stream)) def test_eager_strategy_monolithic_capturable(self): - """Eager strategy with capturable stream should be monolithic capturable.""" compiled = _compile_simple( - cuda_graph_strategy="disabled", - dynamic_shapes_kernel_specialization_strategy="eager", + runtime_settings=RuntimeSettings( + cuda_graph_strategy="disabled", + dynamic_shapes_kernel_specialization_strategy="eager", + ), ) engine = _find_python_trt_engine(compiled) self.assertIsNotNone(engine) stream = torch.cuda.Stream() - # is_stream_capturable depends on engine properties. - # With eager strategy, the strategy check passes. 
- if engine.context.is_stream_capturable(stream.cuda_stream): - self.assertTrue(engine._is_monolithic_capturable(stream)) - - def test_none_strategy_monolithic_capturable(self): - """None strategy (always fallback) should be monolithic capturable.""" - compiled = _compile_simple( - cuda_graph_strategy="disabled", - dynamic_shapes_kernel_specialization_strategy="none", - ) - engine = _find_python_trt_engine(compiled) - self.assertIsNotNone(engine) - stream = torch.cuda.Stream() - if engine.context.is_stream_capturable(stream.cuda_stream): - self.assertTrue(engine._is_monolithic_capturable(stream)) - - -@unittest.skipIf( - not ENABLED_FEATURES.tensorrt_rtx, - "Context recreation tests require TensorRT-RTX", -) -class TestContextRecreation(TestCase): - """Tests for _enable_rtx_native_cudagraphs() context recreation.""" - - def test_enable_rtx_native_recreates_context(self): - """Calling _enable_rtx_native_cudagraphs recreates the execution context.""" - import tensorrt as trt - - compiled = _compile_simple(cuda_graph_strategy="disabled") - engine = _find_python_trt_engine(compiled) - self.assertIsNotNone(engine) - self.assertFalse(engine._rtx_native_cudagraphs) - - old_context_id = id(engine.context) - engine._enable_rtx_native_cudagraphs() - - self.assertTrue(engine._rtx_native_cudagraphs) - self.assertNotEqual( - id(engine.context), - old_context_id, - "Context should be recreated", - ) - self.assertEqual( - engine.runtime_config.cuda_graph_strategy, - trt.CudaGraphStrategy.WHOLE_GRAPH_CAPTURE, - ) - - def test_explicit_whole_graph_capture_no_override_needed(self): - """With explicit whole_graph_capture, SUBGRAPH mode should not - need to override (already RTX-native).""" - compiled = _compile_simple(cuda_graph_strategy="whole_graph_capture") - engine = _find_python_trt_engine(compiled) - self.assertIsNotNone(engine) - self.assertTrue(engine._rtx_native_cudagraphs) - - old_context_id = id(engine.context) - - torchtrt.runtime.set_cudagraphs_mode(True) - 
compiled(torch.randn(2, 3).cuda()) - torchtrt.runtime.set_cudagraphs_mode(False) - - # Context should NOT have been recreated (was already RTX-native) - self.assertEqual( - id(engine.context), - old_context_id, - "Context should not be recreated if already RTX-native", - ) - - -@unittest.skipIf( - not ENABLED_FEATURES.tensorrt_rtx, - "Cudagraph mode toggle tests require TensorRT-RTX", -) -class TestCudagraphModeToggle(TestCase): - """Tests for toggling cudagraph mode with RTX-native.""" - - def setUp(self): - torchtrt.runtime.set_cudagraphs_mode(False) - - def tearDown(self): - torchtrt.runtime.set_cudagraphs_mode(False) - - def test_cudagraphs_off_after_rtx_native_override(self): - """After RTX-native override, disabling cudagraphs should still - produce correct results (RTX-native continues transparently).""" - compiled = _compile_simple(cuda_graph_strategy="disabled") - - torchtrt.runtime.set_cudagraphs_mode(True) - compiled(torch.randn(2, 3).cuda()) # triggers override - - torchtrt.runtime.set_cudagraphs_mode(False) - - # Should still work -- RTX-native is transparent - for bs in (1, 2, 4): - output = compiled(torch.randn(bs, 3).cuda()) - self.assertEqual(output.shape, (bs, 3)) - - def test_no_cudagraphs_with_whole_graph_capture(self): - """With cuda_graph_strategy='whole_graph_capture' but no - set_cudagraphs_mode, RTX-native runs transparently.""" - compiled = _compile_simple(cuda_graph_strategy="whole_graph_capture") - engine = _find_python_trt_engine(compiled) - self.assertTrue(engine._rtx_native_cudagraphs) - - # No set_cudagraphs_mode(True) -- RTX-native still active transparently - for bs in (1, 2, 4): - output = compiled(torch.randn(bs, 3).cuda()) - self.assertEqual(output.shape, (bs, 3)) - - def test_toggle_on_off_on(self): - """Toggle cudagraphs on -> off -> on, verify correctness each time.""" - compiled = _compile_simple(cuda_graph_strategy="disabled") - inp = torch.randn(2, 3).cuda() - - # Phase 1: on - torchtrt.runtime.set_cudagraphs_mode(True) - 
out1 = compiled(inp) - self.assertEqual(out1.shape, (2, 3)) - - # Phase 2: off - torchtrt.runtime.set_cudagraphs_mode(False) - out2 = compiled(inp) - self.assertEqual(out2.shape, (2, 3)) - - # Phase 3: on again - torchtrt.runtime.set_cudagraphs_mode(True) - out3 = compiled(inp) - self.assertEqual(out3.shape, (2, 3)) - - -@unittest.skipIf( - ENABLED_FEATURES.tensorrt_rtx, - "This test verifies standard TRT behavior (non-RTX)", -) -class TestCudaGraphStrategyNonRTX(TestCase): - """Tests that the setting is ignored on non-RTX builds.""" - - def test_setting_ignored_on_non_rtx(self): - compiled = _compile_simple(cuda_graph_strategy="whole_graph_capture") - engine = _find_python_trt_engine(compiled) - if engine is not None: - self.assertIsNone( - engine.runtime_config, - "runtime_config should be None for standard TRT", - ) - self.assertFalse(engine._rtx_native_cudagraphs) - output = compiled(torch.randn(2, 3).cuda()) - self.assertEqual(output.shape, (2, 3)) + # The exact result depends on engine stream-capturability; just verify + # the lazy-gating doesn't fire on eager. 
+ with torch.cuda.stream(stream): + _ = engine._is_monolithic_capturable(stream) _STRATEGY_RUNTIME_MATRIX = [ @@ -401,65 +223,23 @@ def test_setting_ignored_on_non_rtx(self): class TestCudaGraphStrategyInference(TestCase): """End-to-end: compile + infer with each strategy on both runtime paths.""" - def tearDown(self): - torchtrt.runtime.set_cudagraphs_mode(False) - @parameterized.expand(_STRATEGY_RUNTIME_MATRIX) def test_strategy_inference(self, strategy, _runtime_name, use_python_runtime): _skip_if_cpp_unavailable(self, use_python_runtime) compiled, inputs = _compile_conv( strategy, use_python_runtime=use_python_runtime ) - y = compiled(*[inp.clone() for inp in inputs]) + y = compiled(*inputs) self.assertEqual(tuple(y.shape), (2, 8, 16, 16)) self.assertTrue(torch.isfinite(y).all().item()) - @parameterized.expand(_RUNTIMES) - def test_whole_graph_capture_with_subgraph_cudagraphs( - self, _name, use_python_runtime - ): - """Subgraph cudagraph mode + RTX strategy: RTX-native should take over without errors.""" - _skip_if_cpp_unavailable(self, use_python_runtime) - compiled, inputs = _compile_conv( - "whole_graph_capture", use_python_runtime=use_python_runtime - ) - torchtrt.runtime.set_cudagraphs_mode(True) - y = compiled(*[inp.clone() for inp in inputs]) - self.assertEqual(tuple(y.shape), (2, 8, 16, 16)) - self.assertTrue(torch.isfinite(y).all().item()) - - @parameterized.expand(_RUNTIMES) - def test_repeated_inference(self, _name, use_python_runtime): - """Repeated inference exercises the RTX-native capture/replay path.""" - _skip_if_cpp_unavailable(self, use_python_runtime) - compiled, inputs = _compile_conv( - "whole_graph_capture", use_python_runtime=use_python_runtime - ) - ref = compiled(*[inp.clone() for inp in inputs]) - for _ in range(4): - out = compiled(*[inp.clone() for inp in inputs]) - self.assertEqual(out.shape, ref.shape) - self.assertTrue(torch.isfinite(out).all().item()) - class TestCudaGraphStrategyInvalidValue(TestCase): - """Invalid strategy 
names are rejected at TorchTensorRTModule.__init__ on any backend.""" + """Invalid strategy names are rejected at ``RuntimeSettings.__post_init__``.""" - @parameterized.expand(_RUNTIMES) - def test_invalid_strategy_raises(self, _name, use_python_runtime): - _skip_if_cpp_unavailable(self, use_python_runtime) - model = CudaGraphConvModel().eval().cuda() - inputs = [torch.randn(2, 3, 16, 16).cuda()] - with self.assertRaises((ValueError, RuntimeError)): - torchtrt.compile( - model, - ir="dynamo", - inputs=inputs, - enabled_precisions={torch.float32}, - use_python_runtime=use_python_runtime, - min_block_size=1, - cuda_graph_strategy="not_a_real_strategy", - ) + def test_invalid_strategy_raises_at_construction(self): + with self.assertRaises(ValueError): + RuntimeSettings(cuda_graph_strategy="not_a_real_strategy") if __name__ == "__main__": diff --git a/tests/py/dynamo/runtime/test_001_dynamic_shapes_kernel_strategy.py b/tests/py/dynamo/runtime/test_001_dynamic_shapes_kernel_strategy.py index 408fb159fc..e126aee4d5 100644 --- a/tests/py/dynamo/runtime/test_001_dynamic_shapes_kernel_strategy.py +++ b/tests/py/dynamo/runtime/test_001_dynamic_shapes_kernel_strategy.py @@ -5,7 +5,7 @@ from parameterized import parameterized from torch.testing._internal.common_utils import TestCase, run_tests from torch_tensorrt._features import ENABLED_FEATURES -from torch_tensorrt.dynamo._settings import CompilationSettings +from torch_tensorrt.runtime import RuntimeSettings _STRATEGIES = [("lazy",), ("eager",), ("none",)] @@ -34,7 +34,7 @@ def _skip_if_cpp_unavailable(testcase, use_python_runtime): def _compile_dynamic_conv(strategy, *, use_python_runtime): - """Compile DynamicConvModel through the selected runtime with the given strategy.""" + """Compile DynamicConvModel with the given strategy as a compile-time hint.""" model = DynamicConvModel().eval().cuda() inp = torchtrt.Input( min_shape=(1, 3, 16, 16), @@ -49,14 +49,16 @@ def _compile_dynamic_conv(strategy, *, use_python_runtime): 
enabled_precisions={torch.float32}, use_python_runtime=use_python_runtime, min_block_size=1, - dynamic_shapes_kernel_specialization_strategy=strategy, + runtime_settings=RuntimeSettings( + dynamic_shapes_kernel_specialization_strategy=strategy, + ), ) torch._dynamo.reset() return compiled -def _compile_simple(**extra_kwargs): - """Helper: compile SimpleModel with dynamic shapes and Python runtime.""" +def _compile_simple(*, runtime_settings=None): + """Compile SimpleModel with dynamic shapes and Python runtime.""" model = SimpleModel().eval().cuda() inputs = [ torchtrt.Input( @@ -66,14 +68,14 @@ def _compile_simple(**extra_kwargs): dtype=torch.float32, ) ] - kwargs = { - "ir": "dynamo", - "inputs": inputs, - "use_python_runtime": True, - "min_block_size": 1, - } - kwargs.update(extra_kwargs) - compiled = torchtrt.compile(model, **kwargs) + compiled = torchtrt.compile( + model, + ir="dynamo", + inputs=inputs, + use_python_runtime=True, + min_block_size=1, + runtime_settings=runtime_settings, + ) torch._dynamo.reset() return compiled @@ -101,94 +103,79 @@ class TestDynamicShapesKernelStrategySetup(TestCase): """Tests that the dynamic shapes kernel specialization strategy is correctly applied.""" def test_default_strategy_is_lazy(self): - import tensorrt as trt - compiled = _compile_simple() engine = _find_python_trt_engine(compiled) self.assertIsNotNone(engine, "No Python TRTEngine found") - self.assertIsNotNone( - engine.runtime_config, "runtime_config should be set for RTX" - ) self.assertEqual( - engine.runtime_config.dynamic_shapes_kernel_specialization_strategy, - trt.DynamicShapesKernelSpecializationStrategy.LAZY, + engine.runtime_settings.dynamic_shapes_kernel_specialization_strategy, + "lazy", ) - def test_eager_strategy(self): - import tensorrt as trt - + def test_eager_strategy_via_compile_hint(self): compiled = _compile_simple( - dynamic_shapes_kernel_specialization_strategy="eager" + runtime_settings=RuntimeSettings( + 
dynamic_shapes_kernel_specialization_strategy="eager" + ) ) engine = _find_python_trt_engine(compiled) self.assertIsNotNone(engine) self.assertEqual( - engine.runtime_config.dynamic_shapes_kernel_specialization_strategy, - trt.DynamicShapesKernelSpecializationStrategy.EAGER, + engine.runtime_settings.dynamic_shapes_kernel_specialization_strategy, + "eager", ) - def test_none_strategy(self): - import tensorrt as trt - - compiled = _compile_simple(dynamic_shapes_kernel_specialization_strategy="none") + def test_none_strategy_via_compile_hint(self): + compiled = _compile_simple( + runtime_settings=RuntimeSettings( + dynamic_shapes_kernel_specialization_strategy="none" + ) + ) engine = _find_python_trt_engine(compiled) self.assertIsNotNone(engine) self.assertEqual( - engine.runtime_config.dynamic_shapes_kernel_specialization_strategy, - trt.DynamicShapesKernelSpecializationStrategy.NONE, + engine.runtime_settings.dynamic_shapes_kernel_specialization_strategy, + "none", + ) + + def test_runtime_cm_overrides_strategy(self): + """`set_dynamic_shapes_kernel_strategy` CM overrides the active strategy.""" + compiled = _compile_simple() + engine = _find_python_trt_engine(compiled) + self.assertEqual( + engine.runtime_settings.dynamic_shapes_kernel_specialization_strategy, + "lazy", + ) + with torchtrt.runtime.set_dynamic_shapes_kernel_strategy(compiled, "eager"): + self.assertEqual( + engine.runtime_settings.dynamic_shapes_kernel_specialization_strategy, + "eager", + ) + for bs in (1, 2, 4): + output = compiled(torch.randn(bs, 3).cuda()) + self.assertEqual(output.shape, (bs, 3)) + # Restored on exit. 
+ self.assertEqual( + engine.runtime_settings.dynamic_shapes_kernel_specialization_strategy, + "lazy", ) def test_context_created_with_each_strategy(self): for strategy in ("lazy", "eager", "none"): with self.subTest(strategy=strategy): compiled = _compile_simple( - dynamic_shapes_kernel_specialization_strategy=strategy + runtime_settings=RuntimeSettings( + dynamic_shapes_kernel_specialization_strategy=strategy + ) ) engine = _find_python_trt_engine(compiled) self.assertIsNotNone( engine.context, f"Execution context should be created for {strategy}", ) - # Test inference with multiple dynamic batch sizes for bs in (1, 2, 4): output = compiled(torch.randn(bs, 3).cuda()) self.assertEqual(output.shape, (bs, 3)) - def test_setting_in_compilation_settings(self): - for strategy in ("lazy", "eager", "none"): - settings = CompilationSettings( - dynamic_shapes_kernel_specialization_strategy=strategy - ) - self.assertEqual( - settings.dynamic_shapes_kernel_specialization_strategy, strategy - ) - - def test_default_compilation_settings(self): - settings = CompilationSettings() - self.assertEqual(settings.dynamic_shapes_kernel_specialization_strategy, "lazy") - - -@unittest.skipIf( - ENABLED_FEATURES.tensorrt_rtx, - "This test verifies standard TRT behavior (non-RTX)", -) -class TestDynamicShapesKernelStrategyNonRTX(TestCase): - """Tests that the setting is ignored on non-RTX builds.""" - - def test_setting_ignored_on_non_rtx(self): - compiled = _compile_simple( - dynamic_shapes_kernel_specialization_strategy="eager" - ) - engine = _find_python_trt_engine(compiled) - if engine is not None: - self.assertIsNone( - engine.runtime_config, - "runtime_config should be None for standard TRT", - ) - # Inference should still work - output = compiled(torch.randn(2, 3).cuda()) - self.assertEqual(output.shape, (2, 3)) - _STRATEGY_RUNTIME_MATRIX = [ (strategy, runtime_name, use_python_runtime) @@ -227,26 +214,11 @@ def test_dynamic_shape_with_eager(self, _name, use_python_runtime): class 
TestDynamicShapesKernelStrategyInvalidValue(TestCase): - """Invalid strategy names are rejected at TorchTensorRTModule.__init__ on any backend.""" + """Invalid strategy names are rejected at ``RuntimeSettings.__post_init__``.""" - @parameterized.expand(_RUNTIMES) - def test_invalid_strategy_raises(self, _name, use_python_runtime): - _skip_if_cpp_unavailable(self, use_python_runtime) - model = DynamicConvModel().eval().cuda() - inp = torchtrt.Input( - min_shape=(1, 3, 16, 16), - opt_shape=(2, 3, 16, 16), - max_shape=(4, 3, 16, 16), - dtype=torch.float32, - ) - with self.assertRaises((ValueError, RuntimeError)): - torchtrt.compile( - model, - ir="dynamo", - inputs=[inp], - enabled_precisions={torch.float32}, - use_python_runtime=use_python_runtime, - min_block_size=1, + def test_invalid_strategy_raises_at_construction(self): + with self.assertRaises(ValueError): + RuntimeSettings( dynamic_shapes_kernel_specialization_strategy="not_a_real_strategy", ) diff --git a/tests/py/dynamo/runtime/test_004_runtime_settings.py b/tests/py/dynamo/runtime/test_004_runtime_settings.py new file mode 100644 index 0000000000..75f3e1023f --- /dev/null +++ b/tests/py/dynamo/runtime/test_004_runtime_settings.py @@ -0,0 +1,313 @@ +"""Whitebox tests for the RuntimeSettings data model + dispatch.""" + +import dataclasses +import unittest + +import torch +import torch_tensorrt as torchtrt +from parameterized import parameterized +from torch.testing._internal.common_utils import TestCase, run_tests +from torch_tensorrt._features import ENABLED_FEATURES +from torch_tensorrt.runtime import ( + RuntimeCacheHandle, + RuntimeSettings, + runtime_config, +) + + +class SimpleModel(torch.nn.Module): + def forward(self, x): + return torch.relu(x) + 1.0 + + +def _compile_simple(*, runtime_settings=None, use_python_runtime=True): + model = SimpleModel().eval().cuda() + inputs = [ + torchtrt.Input( + min_shape=(1, 3), + opt_shape=(2, 3), + max_shape=(4, 3), + dtype=torch.float32, + ) + ] + compiled = 
torchtrt.compile( + model, + ir="dynamo", + inputs=inputs, + use_python_runtime=use_python_runtime, + min_block_size=1, + runtime_settings=runtime_settings, + ) + torch._dynamo.reset() + return compiled + + +_RUNTIMES = [("python", True), ("cpp", False)] + + +def _skip_if_cpp_unavailable(testcase, use_python_runtime): + if not use_python_runtime and not ENABLED_FEATURES.torch_tensorrt_runtime: + testcase.skipTest("C++ runtime is not available") + + +class TestRuntimeSettingsDataModel(TestCase): + """Pure dataclass behavior; no engine compile required.""" + + def test_defaults_are_valid(self): + from torch_tensorrt.dynamo._defaults import RUNTIME_CACHE_PATH + + rs = RuntimeSettings() + self.assertEqual(rs.dynamic_shapes_kernel_specialization_strategy, "lazy") + self.assertEqual(rs.cuda_graph_strategy, "disabled") + # Defaults to the per-user temp path from _defaults.py (mirrors ENGINE_CACHE_DIR). + self.assertEqual(rs.runtime_cache, RUNTIME_CACHE_PATH) + + def test_invalid_ds_strategy_raises_at_post_init(self): + with self.assertRaises(ValueError): + RuntimeSettings(dynamic_shapes_kernel_specialization_strategy="bogus") + + def test_invalid_cg_strategy_raises_at_post_init(self): + with self.assertRaises(ValueError): + RuntimeSettings(cuda_graph_strategy="bogus") + + def test_frozen(self): + rs = RuntimeSettings() + with self.assertRaises(dataclasses.FrozenInstanceError): + rs.cuda_graph_strategy = "whole_graph_capture" + + def test_merge_overrides(self): + rs = RuntimeSettings() + new = rs.merge(cuda_graph_strategy="whole_graph_capture") + self.assertEqual(new.cuda_graph_strategy, "whole_graph_capture") + # Original unchanged (frozen + replace). 
+ self.assertEqual(rs.cuda_graph_strategy, "disabled") + + def test_merge_unknown_key_raises(self): + rs = RuntimeSettings() + with self.assertRaises(TypeError): + rs.merge(not_a_real_field=True) + + def test_equality_compares_all_fields(self): + a = RuntimeSettings(cuda_graph_strategy="whole_graph_capture") + b = RuntimeSettings(cuda_graph_strategy="whole_graph_capture") + c = RuntimeSettings(cuda_graph_strategy="disabled") + self.assertEqual(a, b) + self.assertNotEqual(a, c) + + def test_runtime_cache_as_path_string(self): + rs = RuntimeSettings(runtime_cache="/tmp/whatever.bin") + self.assertEqual(rs.runtime_cache, "/tmp/whatever.bin") + + +@unittest.skipIf( + not ENABLED_FEATURES.tensorrt_rtx, + "RuntimeSettings dispatch is exercised on TRT-RTX", +) +class TestRuntimeSettingsCompileTimeHint(TestCase): + """Verify the compile-time hint primes the engine without a CM.""" + + def test_compile_hint_sets_engine_settings(self): + rs = RuntimeSettings(cuda_graph_strategy="whole_graph_capture") + compiled = _compile_simple(runtime_settings=rs) + from torch_tensorrt.dynamo.runtime._TorchTensorRTModule import ( + TorchTensorRTModule, + ) + + for _, mod in compiled.named_modules(): + if isinstance(mod, TorchTensorRTModule): + self.assertEqual( + mod.runtime_settings.cuda_graph_strategy, "whole_graph_capture" + ) + + def test_runtime_config_cm_restores_on_exit(self): + compiled = _compile_simple() + from torch_tensorrt.dynamo.runtime._TorchTensorRTModule import ( + TorchTensorRTModule, + ) + + mod = next( + m for _, m in compiled.named_modules() if isinstance(m, TorchTensorRTModule) + ) + prior = mod.runtime_settings + with runtime_config(compiled, cuda_graph_strategy="whole_graph_capture"): + self.assertEqual( + mod.runtime_settings.cuda_graph_strategy, "whole_graph_capture" + ) + self.assertEqual(mod.runtime_settings, prior) + + +@unittest.skipIf( + not ENABLED_FEATURES.tensorrt_rtx, + "Multi-target tests require TRT-RTX", +) +class 
TestMultiTargetRuntimeConfig(TestCase): + """`runtime_config([a, b], ...)` applies to engines under both targets.""" + + def test_multi_target_runtime_config(self): + model_a = _compile_simple() + model_b = _compile_simple() + with runtime_config( + [model_a, model_b], cuda_graph_strategy="whole_graph_capture" + ) as (m_a, m_b): + self.assertIs(m_a, model_a) + self.assertIs(m_b, model_b) + from torch_tensorrt.dynamo.runtime._TorchTensorRTModule import ( + TorchTensorRTModule, + ) + + for target in (model_a, model_b): + for _, mod in target.named_modules(): + if isinstance(mod, TorchTensorRTModule): + self.assertEqual( + mod.runtime_settings.cuda_graph_strategy, + "whole_graph_capture", + ) + + +class TestRuntimeConfigInvalidKey(TestCase): + """Typo in a CM key should raise at construction, not silently no-op.""" + + def test_unknown_kwarg_raises(self): + # Use a Module that's not a TorchTensorRTModule -- we just need the + # CM constructor to run; __enter__ won't find any engines. + target = torch.nn.Linear(3, 3) + with self.assertRaises(TypeError): + runtime_config(target, not_a_real_field=True) + + +@unittest.skipIf( + not ENABLED_FEATURES.tensorrt_rtx, + "Lazy IExecutionContext count is meaningful only on TRT-RTX", +) +class TestLazyExecutionContextCreation(TestCase): + """Regression guard: each setup creates exactly one IExecutionContext. + + On RTX, ``createExecutionContext`` JIT-compiles the specialized kernel set, + so a redundant create doubles a non-trivial chunk of setup latency. The + historical cpp-runtime path did two creates per engine setup -- one in the + torchbind ctor with defaults, one in the post-construction + ``update_runtime_settings`` dispatch. The lazy-create policy collapses these + into a single create at first execute. 
+ """ + + def _walk_engines(self, compiled): + from torch_tensorrt.dynamo.runtime._TorchTensorRTModule import ( + TorchTensorRTModule, + ) + + for _, mod in compiled.named_modules(): + if isinstance(mod, TorchTensorRTModule): + yield mod + + def _skip_if_cpp_unavailable(self, use_python_runtime): + if not use_python_runtime and not ENABLED_FEATURES.torch_tensorrt_runtime: + self.skipTest("C++ runtime is not available") + + @parameterized.expand(_RUNTIMES) + def test_one_context_create_with_default_settings(self, _name, use_python_runtime): + self._skip_if_cpp_unavailable(use_python_runtime) + compiled = _compile_simple(use_python_runtime=use_python_runtime) + ttrt_modules = list(self._walk_engines(compiled)) + self.assertTrue(ttrt_modules, "Expected at least one TorchTensorRTModule") + # Setup itself must not have created the context yet on the cpp path + # (Python runtime engine constructs the context inside its own __init__ + # *after* settings are applied, so 1 is correct there too). + for mod in ttrt_modules: + n = mod.engine.num_execution_contexts_created() + if use_python_runtime: + # Python runtime threads runtime_settings into the engine ctor + # directly, so the single create lives there. + self.assertEqual(n, 1, f"Python runtime expected 1 create, got {n}") + else: + # Cpp path defers until first execute. + self.assertEqual( + n, 0, f"Cpp runtime expected 0 creates at setup, got {n}" + ) + + inputs = [torch.randn(2, 3).cuda()] + _ = compiled(*inputs) + for mod in ttrt_modules: + n = mod.engine.num_execution_contexts_created() + self.assertEqual( + n, 1, f"Expected exactly 1 create after first execute, got {n}" + ) + + @parameterized.expand(_RUNTIMES) + def test_one_context_create_with_compile_time_settings( + self, _name, use_python_runtime + ): + """User-passed RuntimeSettings at compile time must not double-create. 
+ + This is the regression case the lazy-create refactor addresses on cpp: + old behaviour did ctor-create-with-defaults + dispatch-recreate. The + observable count after first execute must be 1. + """ + self._skip_if_cpp_unavailable(use_python_runtime) + rs = RuntimeSettings(cuda_graph_strategy="whole_graph_capture") + compiled = _compile_simple( + runtime_settings=rs, use_python_runtime=use_python_runtime + ) + ttrt_modules = list(self._walk_engines(compiled)) + self.assertTrue(ttrt_modules) + inputs = [torch.randn(2, 3).cuda()] + _ = compiled(*inputs) + for mod in ttrt_modules: + n = mod.engine.num_execution_contexts_created() + self.assertEqual( + n, + 1, + f"Setup + first execute must perform exactly 1 createExecutionContext; " + f"got {n} on {'python' if use_python_runtime else 'cpp'} runtime", + ) + + @parameterized.expand(_RUNTIMES) + def test_set_runtime_settings_lazy_recreate(self, _name, use_python_runtime): + """Changing settings invalidates the context but the recreate is lazy: + the count bumps on the next execute, not on the set call.""" + self._skip_if_cpp_unavailable(use_python_runtime) + compiled = _compile_simple(use_python_runtime=use_python_runtime) + ttrt_modules = list(self._walk_engines(compiled)) + inputs = [torch.randn(2, 3).cuda()] + _ = compiled(*inputs) + for mod in ttrt_modules: + self.assertEqual(mod.engine.num_execution_contexts_created(), 1) + + new_rs = RuntimeSettings(cuda_graph_strategy="whole_graph_capture") + for mod in ttrt_modules: + mod.set_runtime_settings(new_rs) + # set itself does not eagerly recreate on the cpp path. 
+ if not use_python_runtime: + self.assertEqual(mod.engine.num_execution_contexts_created(), 1) + + _ = compiled(*inputs) + for mod in ttrt_modules: + n = mod.engine.num_execution_contexts_created() + self.assertEqual( + n, + 2, + f"Expected exactly 2 creates after settings flip + execute, got {n}", + ) + + @parameterized.expand(_RUNTIMES) + def test_no_op_settings_change_does_not_recreate(self, _name, use_python_runtime): + """Re-applying the same RuntimeSettings is a no-op: no invalidate, no + recreate, count is stable across follow-up executes.""" + self._skip_if_cpp_unavailable(use_python_runtime) + rs = RuntimeSettings(cuda_graph_strategy="whole_graph_capture") + compiled = _compile_simple( + runtime_settings=rs, use_python_runtime=use_python_runtime + ) + ttrt_modules = list(self._walk_engines(compiled)) + inputs = [torch.randn(2, 3).cuda()] + _ = compiled(*inputs) + baseline = [mod.engine.num_execution_contexts_created() for mod in ttrt_modules] + + for mod in ttrt_modules: + mod.set_runtime_settings(rs) # identical to existing + _ = compiled(*inputs) + for mod, prior in zip(ttrt_modules, baseline): + self.assertEqual(mod.engine.num_execution_contexts_created(), prior) + + +if __name__ == "__main__": + run_tests()