Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions onnxruntime/core/flatbuffers/flatbuffers_utils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -282,6 +282,7 @@ Status LoadValueInfoOrtFormat(const fbs::ValueInfo& fbs_value_info,
Status LoadOpsetImportOrtFormat(const flatbuffers::Vector<flatbuffers::Offset<fbs::OperatorSetId>>* fbs_op_set_ids,
std::unordered_map<std::string, int>& domain_to_version) {
ORT_RETURN_IF(nullptr == fbs_op_set_ids, "Model must have opset imports. Invalid ORT format model.");
ORT_RETURN_IF_ERROR(ValidateRequiredTableOffsets(fbs_op_set_ids, "opset import"));

domain_to_version.clear();
domain_to_version.reserve(fbs_op_set_ids->size());
Expand Down
18 changes: 18 additions & 0 deletions onnxruntime/core/flatbuffers/flatbuffers_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,24 @@ onnxruntime::common::Status LoadOpsetImportOrtFormat(
const flatbuffers::Vector<flatbuffers::Offset<fbs::OperatorSetId>>* fbs_op_set_ids,
std::unordered_map<std::string, int>& domain_to_version);

template <typename T>
inline onnxruntime::common::Status ValidateRequiredTableOffsets(
const flatbuffers::Vector<flatbuffers::Offset<T>>* fbs_entries,
const char* entry_description) {
if (fbs_entries == nullptr) {
return onnxruntime::common::Status::OK();
}

const auto* raw_offsets = reinterpret_cast<const uint8_t*>(fbs_entries->Data());
for (flatbuffers::uoffset_t i = 0; i < fbs_entries->size(); ++i) {
const auto entry_offset =
flatbuffers::ReadScalar<flatbuffers::uoffset_t>(raw_offsets + i * sizeof(flatbuffers::uoffset_t));
ORT_RETURN_IF(entry_offset == 0, "Null ", entry_description, " entry. Invalid ORT format model.");
}

return onnxruntime::common::Status::OK();
}

// check if filename ends in .ort
bool IsOrtFormatModel(const PathString& filename);

Expand Down
82 changes: 73 additions & 9 deletions onnxruntime/core/graph/graph.cc
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
#include <cassert>
#include <fstream>
#include <iostream>
#include <limits>
#include <numeric>
#include <queue>
#include <stack>
Expand Down Expand Up @@ -871,7 +872,15 @@
if (fbs_edges) {
for (const auto* fbs_edge : *fbs_edges) {
ORT_RETURN_IF(nullptr == fbs_edge, "Node::LoadEdgesFromOrtFormat, edge is missing for ", dst_name);
edge_set.emplace(*graph.GetNode(fbs_edge->node_index()), fbs_edge->src_arg_index(), fbs_edge->dst_arg_index());
const auto edge_node_index = fbs_edge->node_index();
ORT_RETURN_IF(edge_node_index >= static_cast<uint32_t>(graph.MaxNodeIndex()),
"Node::LoadEdgesFromOrtFormat, ", dst_name, " has out-of-range node index ",
edge_node_index, ". Invalid ORT format model.");
const auto* edge_node = graph.GetNode(edge_node_index);
ORT_RETURN_IF(edge_node == nullptr,
"Node::LoadEdgesFromOrtFormat, ", dst_name, " references missing node ",
edge_node_index, ". Invalid ORT format model.");
edge_set.emplace(*edge_node, fbs_edge->src_arg_index(), fbs_edge->dst_arg_index());
}
}
return Status::OK();
Expand Down Expand Up @@ -6520,8 +6529,10 @@

// Initializers
auto fbs_initializers = fbs_graph.initializers();
ORT_RETURN_IF_ERROR(fbs::utils::ValidateRequiredTableOffsets(fbs_initializers, "initializer"));
#if !defined(DISABLE_SPARSE_TENSORS)
auto fbs_sparse_initializers = fbs_graph.sparse_initializers();
ORT_RETURN_IF_ERROR(fbs::utils::ValidateRequiredTableOffsets(fbs_sparse_initializers, "sparse initializer"));
flatbuffers::uoffset_t map_size = (fbs_initializers != nullptr ? fbs_initializers->size() : 0U) +
(fbs_sparse_initializers != nullptr ? fbs_sparse_initializers->size() : 0U);
#else
Expand Down Expand Up @@ -6591,23 +6602,70 @@
// NodeArgs
auto fbs_node_args = fbs_graph.node_args();
if (fbs_node_args) {
ORT_RETURN_IF_ERROR(fbs::utils::ValidateRequiredTableOffsets(fbs_node_args, "node arg"));
node_args_.reserve(fbs_node_args->size());
for (const auto* fbs_value_info : *fbs_node_args) {
ORT_RETURN_IF(nullptr == fbs_value_info, "NodeArg is missing. Invalid ORT format model.");
NodeArgInfo node_arg_info;
ORT_RETURN_IF_ERROR(fbs::utils::LoadValueInfoOrtFormat(*fbs_value_info, node_arg_info));
const auto* name = fbs_value_info->name();
ORT_RETURN_IF(name == nullptr, "NodeArg name is missing. Invalid ORT format model.");
node_args_[name->str()] = std::make_unique<NodeArg>(std::move(node_arg_info));
const auto inserted = node_args_.emplace(name->str(), std::make_unique<NodeArg>(std::move(node_arg_info)));
ORT_RETURN_IF(!inserted.second, "Duplicate NodeArg name '", name->str(), "'. Invalid ORT format model.");
}
}

// Nodes
//
// Since we access a node using its index, we need to have nodes_ with size max_node_index to avoid
// out of bounds access.
nodes_.resize(fbs_graph.max_node_index());
// Since we access a node using its index, we need to have nodes_ with a size that covers all
// referenced indices. We compute the required slot count from actual node and edge data rather
// than trusting the serialized max_node_index field.
auto* fbs_nodes = fbs_graph.nodes();
ORT_RETURN_IF_ERROR(fbs::utils::ValidateRequiredTableOffsets(fbs_nodes, "node"));
auto* fbs_node_edges = fbs_graph.node_edges();
ORT_RETURN_IF_ERROR(fbs::utils::ValidateRequiredTableOffsets(fbs_node_edges, "node edge"));

size_t required_node_slot_count = 0;
const auto update_required_node_slot_count = [&required_node_slot_count](uint32_t node_index) -> Status {
ORT_RETURN_IF(node_index == std::numeric_limits<uint32_t>::max(),
"Node index is out of range. Invalid ORT format model.");
const auto node_slot_count = static_cast<size_t>(node_index) + 1U;
required_node_slot_count = std::max(required_node_slot_count, node_slot_count);
return Status::OK();
};

if (fbs_nodes != nullptr) {
for (const auto* fbs_node : *fbs_nodes) {
ORT_RETURN_IF(nullptr == fbs_node, "Node is missing. Invalid ORT format model.");
ORT_RETURN_IF_ERROR(update_required_node_slot_count(fbs_node->index()));
}
}

if (fbs_node_edges != nullptr) {
for (const auto* fbs_node_edge : *fbs_node_edges) {
ORT_RETURN_IF(nullptr == fbs_node_edge, "NodeEdge is missing. Invalid ORT format model.");
ORT_RETURN_IF_ERROR(update_required_node_slot_count(fbs_node_edge->node_index()));
}
}

// Sanity bound: reject buffers where a crafted node index would cause a multi-gigabyte
// allocation. After graph optimizations, ORT preserves original node indices (leaving holes
// in the nodes_ vector), so max_node_index can legitimately be much larger than the number
// of remaining nodes. Use a generous multiplier plus an absolute cap to accommodate this
// sparsity while still blocking adversarial amplification.
const size_t total_entries = (fbs_nodes != nullptr ? fbs_nodes->size() : 0U) +
(fbs_node_edges != nullptr ? fbs_node_edges->size() : 0U);
constexpr size_t kMinSlotCap = 1024; // allow small models without penalty
constexpr size_t kAbsoluteSlotCap = 10000000; // ~80 MB of unique_ptr<Node> on 64-bit
const size_t slot_cap = std::min(kAbsoluteSlotCap, std::max(kMinSlotCap, total_entries * 64U));

Check warning on line 6660 in onnxruntime/core/graph/graph.cc

View workflow job for this annotation

GitHub Actions / Optional Lint C++

[cpplint] reported by reviewdog 🐶 Add #include <algorithm> for min [build/include_what_you_use] [4] Raw Output: onnxruntime/core/graph/graph.cc:6660: Add #include <algorithm> for min [build/include_what_you_use] [4]
ORT_RETURN_IF(required_node_slot_count > slot_cap,
"Node index ", required_node_slot_count - 1,
" is unreasonably large relative to the number of entries (",
total_entries, "). Invalid ORT format model.");

ORT_RETURN_IF(fbs_graph.max_node_index() < required_node_slot_count,
"Serialized max node index is smaller than the required node slot count. Invalid ORT format model.");
nodes_.resize(required_node_slot_count);
Comment thread
tianleiwu marked this conversation as resolved.

// It is possible to have no nodes in the model. Most likely scenario is the subgraph of an If Node
// where the subgraph returns a Constant node. The Constant node will be lifted to an initializer by ORT
Expand All @@ -6617,18 +6675,22 @@
ORT_RETURN_IF(nullptr == fbs_node, "Node is missing. Invalid ORT format model.");
std::unique_ptr<Node> node;
ORT_RETURN_IF_ERROR(Node::LoadFromOrtFormat(*fbs_node, *this, load_options, logger_, node));
ORT_RETURN_IF(node->Index() >= fbs_graph.max_node_index(), "Node index is out of range");
ORT_RETURN_IF(node->Index() >= nodes_.size(), "Node index is out of range");
ORT_RETURN_IF(nodes_[node->Index()] != nullptr,
"Duplicate node index ", node->Index(), ". Invalid ORT format model.");
nodes_[node->Index()] = std::move(node);
++num_of_nodes_;
}
}

// NodeEdges
auto* fbs_node_edges = fbs_graph.node_edges();
if (fbs_node_edges != nullptr) {
for (const auto* fbs_node_edge : *fbs_node_edges) {
ORT_RETURN_IF(nullptr == fbs_node_edge, "NodeEdge is missing. Invalid ORT format model.");
ORT_RETURN_IF(fbs_node_edge->node_index() >= fbs_graph.max_node_index(), "Node index is out of range");
ORT_RETURN_IF(fbs_node_edge->node_index() >= nodes_.size(), "Node index is out of range");
ORT_RETURN_IF(nodes_[fbs_node_edge->node_index()] == nullptr,
"NodeEdge references missing node ", fbs_node_edge->node_index(),
". Invalid ORT format model.");
ORT_RETURN_IF_ERROR(nodes_[fbs_node_edge->node_index()]->LoadEdgesFromOrtFormat(*fbs_node_edge, *this));
Comment thread
tianleiwu marked this conversation as resolved.
}
}
Expand All @@ -6640,7 +6702,9 @@
node_args.reserve(fbs_node_args->size());
for (const auto* fbs_node_arg_name : *fbs_node_args) {
ORT_RETURN_IF(nullptr == fbs_node_arg_name, "NodeArg Name is missing. Invalid ORT format model.");
gsl::not_null<NodeArg*> node_arg = GetNodeArg(fbs_node_arg_name->str());
auto* node_arg = GetNodeArg(fbs_node_arg_name->str());
ORT_RETURN_IF(node_arg == nullptr, "Graph references unknown NodeArg '", fbs_node_arg_name->str(),
"'. Invalid ORT format model.");
node_args.push_back(node_arg);
}
}
Expand Down
10 changes: 10 additions & 0 deletions onnxruntime/core/graph/graph_flatbuffers_utils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -372,6 +372,16 @@ Status LoadInitializerOrtFormat(const fbs::Tensor& fbs_tensor, TensorProto& init
} else {
const auto* fbs_raw_data = fbs_tensor.raw_data();
if (fbs_raw_data) {
size_t expected_num_bytes = 0;
ORT_RETURN_IF_ERROR(GetSizeInBytesFromFbsTensor(fbs_tensor, expected_num_bytes));
const auto* fbs_name = fbs_tensor.name();
const char* tensor_name = fbs_name ? fbs_name->c_str() : "<unnamed>";
ORT_RETURN_IF(
fbs_raw_data->size() != expected_num_bytes,
"Initializer raw data size mismatch for tensor '", tensor_name,
"'. Expected ", expected_num_bytes, " bytes but found ", fbs_raw_data->size(),
". Invalid ORT format model.");

if (load_options.can_use_flatbuffer_for_initializers && fbs_raw_data->size() > 127) {
static_assert(sizeof(void*) <= sizeof(ExternalDataInfo::OFFSET_TYPE));
const void* data_offset = fbs_raw_data->Data();
Expand Down
1 change: 1 addition & 0 deletions onnxruntime/core/graph/model.cc
Original file line number Diff line number Diff line change
Expand Up @@ -983,6 +983,7 @@ common::Status Model::LoadFromOrtFormat(const fbs::Model& fbs_model,

// Load the model metadata
if (const auto* fbs_metadata_props = fbs_model.metadata_props()) {
ORT_RETURN_IF_ERROR(fbs::utils::ValidateRequiredTableOffsets(fbs_metadata_props, "metadata property"));
model->model_metadata_.reserve(fbs_metadata_props->size());
for (const auto* prop : *fbs_metadata_props) {
ORT_RETURN_IF(nullptr == prop, "Null entry in metadata_props. Invalid ORT format model.");
Expand Down
131 changes: 131 additions & 0 deletions onnxruntime/test/framework/ort_model_only_test.cc
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#include <algorithm>

#include "core/flatbuffers/ort_format_version.h"
#include "core/flatbuffers/schema/ort.fbs.h"
#include "core/framework/data_types.h"
#include "core/framework/tensorprotoutils.h"
Expand Down Expand Up @@ -41,6 +44,46 @@ struct OrtModelTestInfo {
TransformerLevel optimization_level = TransformerLevel::Level3;
};

namespace {

flatbuffers::Offset<fbs::TypeInfo> CreateFloatTensorTypeInfo(flatbuffers::FlatBufferBuilder& builder,
int64_t dim_value) {
const auto dim_value_off =
fbs::CreateDimensionValue(builder, fbs::DimensionValueType::VALUE, dim_value);
std::vector<flatbuffers::Offset<fbs::Dimension>> dims{
fbs::CreateDimension(builder, dim_value_off)};
const auto shape = fbs::CreateShapeDirect(builder, &dims);
const auto tensor_type =
fbs::CreateTensorTypeAndShape(builder, fbs::TensorDataType::FLOAT, shape);
return fbs::CreateTypeInfoDirect(builder, nullptr, fbs::TypeInfoValue::tensor_type, tensor_type.Union());
}

std::vector<uint8_t> BuildOrtModelBuffer(
const std::function<flatbuffers::Offset<fbs::Graph>(flatbuffers::FlatBufferBuilder&)>& create_graph) {
flatbuffers::FlatBufferBuilder builder;

const auto graph = create_graph(builder);
std::vector<flatbuffers::Offset<fbs::OperatorSetId>> opset_imports{
fbs::CreateOperatorSetIdDirect(builder, "", 18)};
const auto model = fbs::CreateModelDirect(builder, 8, &opset_imports, "ort-model-test", "1", "",
1, "", graph, "");
const auto session = fbs::CreateInferenceSessionDirect(builder,
std::to_string(kOrtModelVersion).c_str(), model);
fbs::FinishInferenceSessionBuffer(builder, session);

return std::vector<uint8_t>(builder.GetBufferPointer(), builder.GetBufferPointer() + builder.GetSize());
}

Status LoadOrtBuffer(const std::vector<uint8_t>& buffer) {
SessionOptions so;
ORT_RETURN_IF_ERROR(so.config_options.AddConfigEntry(kOrtSessionOptionsConfigLoadModelFormat, "ORT"));

InferenceSessionWrapper session_object{so, GetEnvironment()};
return session_object.Load(buffer.data(), static_cast<int>(buffer.size()));
}

} // namespace

static void RunOrtModel(const OrtModelTestInfo& test_info) {
SessionOptions so;
so.session_logid = test_info.logid;
Expand Down Expand Up @@ -85,6 +128,94 @@ static void RunOrtModel(const OrtModelTestInfo& test_info) {
test_info.output_verifier(fetches);
}

TEST(OrtModelTest, RejectsInitializerRawDataSizeMismatch) {
const auto buffer = BuildOrtModelBuffer([](flatbuffers::FlatBufferBuilder& builder) {
std::vector<int64_t> dims{1};
std::vector<uint8_t> raw_data(sizeof(float) * 2, 0);
std::vector<flatbuffers::Offset<fbs::Tensor>> initializers{
fbs::CreateTensorDirect(builder, "bad_initializer", "", &dims, fbs::TensorDataType::FLOAT, &raw_data)};
return fbs::CreateGraphDirect(builder, &initializers);
});

const auto status = LoadOrtBuffer(buffer);
ASSERT_FALSE(status.IsOK());
EXPECT_THAT(status.ErrorMessage(), testing::HasSubstr("raw data size mismatch"));
}

TEST(OrtModelTest, RejectsNullNodeArgTableEntry) {
auto buffer = BuildOrtModelBuffer([](flatbuffers::FlatBufferBuilder& builder) {
std::vector<flatbuffers::Offset<fbs::ValueInfo>> node_args{
fbs::CreateValueInfoDirect(builder, "X", "", CreateFloatTensorTypeInfo(builder, 1))};
return fbs::CreateGraphDirect(builder, nullptr, &node_args);
});

const auto* fbs_session = fbs::GetInferenceSession(buffer.data());
ASSERT_NE(fbs_session, nullptr);
const auto* fbs_node_args = fbs_session->model()->graph()->node_args();
ASSERT_NE(fbs_node_args, nullptr);

auto* raw_offsets = const_cast<uint8_t*>(reinterpret_cast<const uint8_t*>(fbs_node_args->Data()));
std::fill_n(raw_offsets, sizeof(flatbuffers::uoffset_t), 0);

const auto status = LoadOrtBuffer(buffer);
ASSERT_FALSE(status.IsOK());
EXPECT_THAT(status.ErrorMessage(),
testing::AnyOf(testing::HasSubstr("Null node arg entry"),
testing::HasSubstr("verification failed")));
}

TEST(OrtModelTest, RejectsDanglingNodeEdge) {
const auto buffer = BuildOrtModelBuffer([](flatbuffers::FlatBufferBuilder& builder) {
std::vector<flatbuffers::Offset<fbs::NodeEdge>> node_edges{
fbs::CreateNodeEdgeDirect(builder, 0)};
return fbs::CreateGraphDirect(builder, nullptr, nullptr, nullptr, 1, &node_edges);
});

const auto status = LoadOrtBuffer(buffer);
ASSERT_FALSE(status.IsOK());
EXPECT_THAT(status.ErrorMessage(), testing::HasSubstr("references missing node"));
}
Comment thread
tianleiwu marked this conversation as resolved.

TEST(OrtModelTest, RejectsAdversarialLargeNodeIndex) {
// A single node with a huge index should be rejected to prevent memory amplification.
const auto buffer = BuildOrtModelBuffer([](flatbuffers::FlatBufferBuilder& builder) {
const uint32_t huge_index = 100'000'000;
std::vector<flatbuffers::Offset<fbs::Node>> nodes{
fbs::CreateNodeDirect(builder, "n", "", "", 1, huge_index, "Identity")};
return fbs::CreateGraphDirect(builder, nullptr, nullptr, &nodes, huge_index + 1);
});

const auto status = LoadOrtBuffer(buffer);
ASSERT_FALSE(status.IsOK());
EXPECT_THAT(status.ErrorMessage(), testing::HasSubstr("unreasonably large"));
}

TEST(OrtModelTest, RejectsInvalidEdgeEndNodeIndex) {
// An EdgeEnd referencing a non-existent node should be rejected gracefully
// rather than crashing via ORT_ENFORCE or nullptr dereference.
const auto buffer = BuildOrtModelBuffer([](flatbuffers::FlatBufferBuilder& builder) {
// Create a valid node at index 0 with empty inputs/outputs so it passes node-loading validation.
std::vector<flatbuffers::Offset<flatbuffers::String>> empty_args;
std::vector<int32_t> empty_arg_counts;
std::vector<flatbuffers::Offset<fbs::Node>> nodes{
fbs::CreateNodeDirect(builder, "n0", "", "", 1, 0, "Identity",
fbs::NodeType::Primitive, nullptr,
&empty_args, &empty_args, nullptr,
&empty_arg_counts, &empty_args)};
// Create a NodeEdge for node 0 with an input edge referencing non-existent node 99
std::vector<fbs::EdgeEnd> input_edges{fbs::EdgeEnd(99, 0, 0)};
std::vector<flatbuffers::Offset<fbs::NodeEdge>> node_edges{
fbs::CreateNodeEdgeDirect(builder, 0, &input_edges)};
return fbs::CreateGraphDirect(builder, nullptr, nullptr, &nodes, 100, &node_edges);
});

const auto status = LoadOrtBuffer(buffer);
ASSERT_FALSE(status.IsOK());
EXPECT_THAT(status.ErrorMessage(),
testing::AnyOf(testing::HasSubstr("out-of-range node index"),
testing::HasSubstr("references missing node")));
}

#if !defined(ORT_MINIMAL_BUILD)
// Keep the CompareTypeProtos in case we need debug the difference
/*
Expand Down
Loading