Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions core/runtime/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,7 @@ cc_library(
"DeviceList.cpp",
"Platform.cpp",
"RTDevice.cpp",
"RuntimeSettings.cpp",
"TRTEngine.cpp",
"TRTEngineProfiler.cpp",
"TRTRuntimeConfig.cpp",
Expand All @@ -96,6 +97,7 @@ cc_library(
hdrs = [
"Platform.h",
"RTDevice.h",
"RuntimeSettings.h",
"TRTEngine.h",
"TRTEngineProfiler.h",
"TRTRuntimeConfig.h",
Expand Down Expand Up @@ -158,6 +160,7 @@ cc_library(
hdrs = [
"Platform.h",
"RTDevice.h",
"RuntimeSettings.h",
"TRTEngine.h",
"TRTEngineProfiler.h",
"TensorRTBindingNames.h",
Expand All @@ -174,6 +177,7 @@ filegroup(
srcs = [
"Platform.h",
"RTDevice.h",
"RuntimeSettings.h",
"TRTEngine.h",
"TRTEngineProfiler.h",
"TRTRuntimeConfig.h",
Expand Down
89 changes: 89 additions & 0 deletions core/runtime/RuntimeSettings.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
#include "core/runtime/RuntimeSettings.h"

#include <cstring>
#include <sstream>
#include <tuple>

#include "core/util/prelude.h"

namespace torch_tensorrt {
namespace core {
namespace runtime {

// ---- RuntimeCacheHandle methods ---------------------------------------------
//
// The ``#ifdef TRT_MAJOR_RTX`` is intentionally confined to this translation
// unit: the public header advertises a uniform interface (always-callable
// methods that simply degrade to no-ops on non-RTX builds), and the JIT-binding
// registration file (``register_jit_hooks.cpp``) calls these as plain member
// references with zero conditional compilation.

at::Tensor RuntimeCacheHandle::serialize() const {
auto opts = at::TensorOptions().dtype(at::kByte);
#ifdef TRT_MAJOR_RTX
if (!cache) {
return at::empty({0}, opts);
}
auto host_mem = make_trt(cache->serialize());
if (!host_mem) {
return at::empty({0}, opts);
}
auto tensor = at::empty({static_cast<int64_t>(host_mem->size())}, opts);
std::memcpy(tensor.data_ptr(), host_mem->data(), host_mem->size());
return tensor;
#else
return at::empty({0}, opts);
#endif
}

void RuntimeCacheHandle::deserialize(TORCHTRT_UNUSED at::Tensor data) {
#ifdef TRT_MAJOR_RTX
if (data.numel() == 0 || !cache) {
return;
}
auto contig = data.contiguous().to(at::kCPU);
cache->deserialize(contig.data_ptr(), static_cast<size_t>(contig.numel()));
#endif
}

bool RuntimeCacheHandle::has_cache() const {
#ifdef TRT_MAJOR_RTX
return cache != nullptr;
#else
return false;
#endif
}

// ---- RuntimeSettings methods ------------------------------------------------

bool RuntimeSettings::operator==(RuntimeSettings const& other) const noexcept {
// ``runtime_cache`` compares by pointer identity: passing the same handle
// twice through ``update_runtime_settings`` is a no-op. Hoisted into locals
// because ``std::tie`` requires lvalues.
auto* this_cache = runtime_cache.get();
auto* other_cache = other.runtime_cache.get();
return std::tie(dynamic_shapes_kernel_specialization_strategy, cuda_graph_strategy, this_cache) ==
std::tie(other.dynamic_shapes_kernel_specialization_strategy, other.cuda_graph_strategy, other_cache);
}

std::string RuntimeSettings::to_str() const {
std::ostringstream os;
os << "Dynamic Shapes Kernel Strategy: " << dynamic_shapes_kernel_specialization_strategy << std::endl;
os << "CUDA Graph Strategy: " << cuda_graph_strategy << std::endl;
if (runtime_cache) {
auto const& p = runtime_cache->path;
os << "Runtime Cache: " << (p.empty() ? "<in-memory shared>" : p) << std::endl;
} else {
os << "Runtime Cache: <engine-local, in-memory>" << std::endl;
}
return os.str();
}

std::ostream& operator<<(std::ostream& os, RuntimeSettings const& rs) {
os << rs.to_str();
return os;
}

} // namespace runtime
} // namespace core
} // namespace torch_tensorrt
77 changes: 77 additions & 0 deletions core/runtime/RuntimeSettings.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
#pragma once

#include <memory>
#include <ostream>
#include <string>

#include "ATen/core/Tensor.h"
#include "ATen/core/ivalue.h"
#include "NvInfer.h"
#include "torch/custom_class.h"

namespace torch_tensorrt {
namespace core {
namespace runtime {

// A passive wrapper around an ``IRuntimeCache``. Registered as a torchbind class
// so it can be passed by ``c10::intrusive_ptr`` across the Python/C++ boundary;
// the same handle gives both runtimes the same underlying ``IRuntimeCache*``.
//
// File I/O lives on the Python side (filelock + on-disk persistence via
// the ``serialize`` / ``deserialize`` members below). The C++ struct is purely
// a holder; ``path`` is informational and is not consulted by the C++ runtime.
struct RuntimeCacheHandle : public torch::CustomClassHolder {
std::string path;

#ifdef TRT_MAJOR_RTX
// The actual TensorRT runtime cache. The first engine that attaches this handle
// materializes it via ``IRuntimeConfig::createRuntimeCache()`` and writes the
// shared_ptr here; subsequent engines reuse the same pointer for true sharing.
std::shared_ptr<nvinfer1::IRuntimeCache> cache;
#endif

explicit RuntimeCacheHandle(std::string p = "") : path(std::move(p)) {}

// Expose the underlying ``IRuntimeCache`` bytes for the Python side to persist
// under filelock. Returns an empty uint8 tensor when no cache is attached, or
// on non-RTX builds.
//
// ``at::Tensor`` is used (rather than ``std::string``) because TorchBind
// forces ``std::string`` to round-trip through Python ``str`` (UTF-8), and
// serialized cache bytes are not valid UTF-8.
[[nodiscard]] at::Tensor serialize() const;

// Inverse of ``serialize``. Expects a uint8 ``at::Tensor``. No-op for empty
// input, when the underlying ``IRuntimeCache`` has not been materialized yet,
// or on non-RTX builds.
void deserialize(at::Tensor data);

// True iff an engine has populated the underlying ``IRuntimeCache``.
// Always false on non-RTX builds.
[[nodiscard]] bool has_cache() const;
};

// Per-engine runtime-only knobs sampled at IExecutionContext creation.
//
// ``RuntimeSettings`` is a plain struct (not a torchbind class) because we
// flatten it into positional args at the torchbind boundary -- TorchBind can't
// carry a dataclass natively. Equality is value-by-value; the cache field
// compares by pointer identity (same handle -> same cache).
struct RuntimeSettings {
std::string dynamic_shapes_kernel_specialization_strategy = "lazy";
std::string cuda_graph_strategy = "disabled";
Comment on lines +61 to +62
Copy link
Copy Markdown
Owner Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we need to store strings? Can't we store ints here since they will be converted to enumerations anyway? This way we just static cast them to get the enumeration

Copy link
Copy Markdown
Owner Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Tradeoff:

  • Ints: type-safe; static_cast<nvinfer1::DynamicShapesKernelSpecializationStrategy>(settings_.foo) is a no-op; one validation layer.
  • Strings: human-readable to_str() (logged via operator<<) is useful for support reports / debug logs; identical Python-and-C++ value-form for the canonical setting; reasonably grep-able through the codebase.

Today the layout is "strings everywhere, two validation points": Python RuntimeSettings.__post_init__ validates against the string→int map, then TRTRuntimeConfig::ensure_initialized re-validates inside to_trt_ds_strategy / to_trt_cg_strategy. Moving to int internally means: keep the Python-facing API as strings, convert once at the torchbind boundary (update_runtime_settings lambda), store ints, skip the second validation. The cost is that to_str() needs a reverse-lookup table to print "lazy" / "whole_graph_capture".

Happy to push that through if you prefer ints — I held off here because the current form keeps logs human-readable on both sides of the boundary. WDYT?

c10::intrusive_ptr<RuntimeCacheHandle> runtime_cache = nullptr;

bool operator==(RuntimeSettings const& other) const noexcept;
bool operator!=(RuntimeSettings const& other) const noexcept {
return !(*this == other);
}

[[nodiscard]] std::string to_str() const;
};

std::ostream& operator<<(std::ostream& os, RuntimeSettings const& rs);

} // namespace runtime
} // namespace core
} // namespace torch_tensorrt
Loading
Loading