From 41fa7ab0e037a12d5337a70d8ef2bdd5b6a0a1af Mon Sep 17 00:00:00 2001 From: Aditya Date: Mon, 9 Mar 2026 12:50:08 +0530 Subject: [PATCH] [tmva [sofie] Add ONNX Gelu operator support Implement the Gelu operator for the SOFIE inference engine with ONNX Opset 20 specifications. Port of root-project/root#20876. The implementation uses the exact Gelu formula: y = 0.5 * x * (1+ erf( x / sqrt(2))) Uses compile-time hexfloat constants for 1/sqrt(2) and 0.5 to ensure bit-exact reproducibility. Updates RModelParser_ONNX to register the Gelu operator node. It explicitly rejects the approximate attribute (tanh approximation) as it is not yet supported. --- src/SOFIE_core/CMakeLists.txt | 1 + src/SOFIE_core/inc/SOFIE/ROperator_Gelu.hxx | 74 +++++++++++++++++++++ src/SOFIE_parsers/CMakeLists.txt | 1 + src/SOFIE_parsers/src/ParseGelu.cxx | 45 +++++++++++++ src/SOFIE_parsers/src/RModelParser_ONNX.cxx | 2 + 5 files changed, 123 insertions(+) create mode 100644 src/SOFIE_core/inc/SOFIE/ROperator_Gelu.hxx create mode 100644 src/SOFIE_parsers/src/ParseGelu.cxx diff --git a/src/SOFIE_core/CMakeLists.txt b/src/SOFIE_core/CMakeLists.txt index 84a6658..201b48d 100644 --- a/src/SOFIE_core/CMakeLists.txt +++ b/src/SOFIE_core/CMakeLists.txt @@ -47,6 +47,7 @@ set(source_headers SOFIE/ROperator_Erf.hxx SOFIE/ROperator_Swish.hxx SOFIE/ROperator_Elu.hxx + SOFIE/ROperator_Gelu.hxx SOFIE/ROperator_Comparision.hxx SOFIE/ROperator_EyeLike.hxx SOFIE/ROperator_Range.hxx diff --git a/src/SOFIE_core/inc/SOFIE/ROperator_Gelu.hxx b/src/SOFIE_core/inc/SOFIE/ROperator_Gelu.hxx new file mode 100644 index 0000000..0828670 --- /dev/null +++ b/src/SOFIE_core/inc/SOFIE/ROperator_Gelu.hxx @@ -0,0 +1,74 @@ +#ifndef SOFIE_ROPERATOR_GELU +#define SOFIE_ROPERATOR_GELU + +#include "SOFIE_common.hxx" +#include "ROperator.hxx" +#include "RModel.hxx" + +#include + +namespace SOFIE{ + +template +class ROperator_GELU final : public ROperator +{ + +private: + + std::string fNX; + std::string fNY; + std::vector fShape; + +public: + ROperator_GELU(){} + ROperator_GELU(std::string nameX, std::string nameY): + fNX(UTILITY::Clean_name(nameX)), fNY(UTILITY::Clean_name(nameY)){ + fInputTensorNames = { fNX }; + fOutputTensorNames = { fNY }; + } + + std::vector TypeInference(std::vector input) override { + return input; + } + + std::vector> ShapeInference(std::vector> input) override { + return input; + } + + void Initialize(RModel& model) override { + //input must be a graph input, or already initialized intermediate tensor + if (!model.CheckIfTensorAlreadyExist(fNX)){ + throw std::runtime_error("TMVA SOFIE GELU Op Input Tensor " + fNX + " is not found in model"); + } + fShape = model.GetTensorShape(fNX); + model.AddIntermediateTensor(fNY, model.GetTensorType(fNX), fShape); + } + + std::string Generate(std::string OpName) override { + OpName = "op_" + OpName; + if (fShape.empty()){ + throw std::runtime_error("TMVA SOFIE GELU operator called to Generate without being initialized first"); + } + std::stringstream out; + size_t length = ConvertShapeToLength(fShape); + + // GELU exact formula: y = 0.5 * x * (1 + erf(x / sqrt(2))) + // Using hexfloat for compile-time precision: + // 0x1.6a09e667f3bcdp-1 = 1/sqrt(2) = 0.7071067811865476 (exact to be precise) + // 0x1.0000000000000p-1 = 0.5 + + out << "\n//------ GELU\n"; + out << SP << "for (int id = 0; id < " << length << " ; id++){\n"; + out << SP << SP << "tensor_" << fNY << "[id] = 0x1.0000000000000p-1 * tensor_" << fNX + << "[id] * (1.0 + std::erf(tensor_" << fNX << "[id] * 0x1.6a09e667f3bcdp-1));\n"; + out << SP << "}\n"; + return out.str(); + } + + std::vector GetStdLibs() override { return { std::string("cmath") };} +}; + +}// namespace SOFIE + + +#endif //SOFIE_ROPERATOR_GELU \ No newline at end of file diff --git a/src/SOFIE_parsers/CMakeLists.txt b/src/SOFIE_parsers/CMakeLists.txt index 379b7d7..4cf13e2 100644 --- a/src/SOFIE_parsers/CMakeLists.txt +++ b/src/SOFIE_parsers/CMakeLists.txt @@ -62,6 +62,7 @@ set(sources_cxx src/ParseExpand.cxx src/ParseGather.cxx src/ParseElu.cxx + src/ParseGelu.cxx src/ParseFuseConvAdd.cxx src/ParseFuseConvTransposeAdd.cxx src/ParseFuseGemmRelu.cxx diff --git a/src/SOFIE_parsers/src/ParseGelu.cxx b/src/SOFIE_parsers/src/ParseGelu.cxx new file mode 100644 index 0000000..8b5ffdc --- /dev/null +++ b/src/SOFIE_parsers/src/ParseGelu.cxx @@ -0,0 +1,45 @@ +#include "SOFIE/RModelParser_ONNX.hxx" +#include "SOFIE/ROperator_Gelu.hxx" +#include "onnx_proto3.pb.h" + +namespace SOFIE { + +ParserFuncSignature ParseGelu = [](RModelParser_ONNX &parser, const onnx::NodeProto &nodeproto) { + ETensorType input_type; + + // Check for unsupported tanh approximation attribute + for (int_t i = 0; i < nodeproto.attribute_size(); i++) { + std::string attribute_name = nodeproto.attribute(i).name(); + if (attribute_name == "approximate") { + if (nodeproto.attribute(i).s() == "tanh") { + throw std::runtime_error("TMVA::SOFIE ONNX Parser Gelu tanh approximation not implemented"); + } + } + } + + auto input_name = nodeproto.input(0); + if (parser.IsRegisteredTensorType(input_name)) { + input_type = parser.GetTensorType(input_name); + } else { + throw std::runtime_error("TMVA::SOFIE ONNX Parser Gelu op has input tensor " + input_name + + " but its type is not yet registered"); + } + + std::unique_ptr op; + std::string output_name = nodeproto.output(0); + + switch (input_type) { + case ETensorType::FLOAT: op.reset(new ROperator_GELU(input_name, output_name)); break; + default: + throw std::runtime_error("TMVA::SOFIE - Unsupported - Operator Gelu does not yet support input type " + + std::to_string(static_cast(input_type))); + } + + if (!parser.IsRegisteredTensorType(output_name)) { + parser.RegisterTensorType(output_name, input_type); + } + + return op; +}; + +} // namespace SOFIE \ No newline at end of file diff --git a/src/SOFIE_parsers/src/RModelParser_ONNX.cxx b/src/SOFIE_parsers/src/RModelParser_ONNX.cxx index 68662ae..e384618 100644 --- a/src/SOFIE_parsers/src/RModelParser_ONNX.cxx +++ b/src/SOFIE_parsers/src/RModelParser_ONNX.cxx @@ -75,6 +75,7 @@ extern ParserFuncSignature ParseLayerNormalization; extern ParserFuncSignature ParseGather; extern ParserFuncSignature ParseErf; extern ParserFuncSignature ParseElu; +extern ParserFuncSignature ParseGelu; extern ParserFuncSignature ParseEyeLike; extern ParserFuncSignature ParseRange; extern ParserFuncSignature ParseTopK; @@ -219,6 +220,7 @@ RModelParser_ONNX::RModelParser_ONNX() noexcept : fOperatorsMapImpl(std::make_un RegisterOperator("Gather", ParseGather); RegisterOperator("Erf", ParseErf); RegisterOperator("Elu", ParseElu); + RegisterOperator("Gelu", ParseGelu); RegisterOperator("EyeLike", ParseEyeLike); RegisterOperator("Range", ParseRange); RegisterOperator("TopK", ParseTopK);