From 9da33ceebab33402f588ab6eb0f98eb8d5c642c5 Mon Sep 17 00:00:00 2001 From: Aditya Date: Mon, 9 Mar 2026 19:00:02 +0530 Subject: [PATCH] feat: Add HardSigmoid, HardSwish, and stabilized Softplus operators Port of root-project/root#20933, #20944, and #21092. - HardSigmoid: stateful operator with alpha/beta attributes (ONNX Opset 6+) - HardSwish: inline generation with split topology (ONNX Opset 14) - Softplus: numerical stability fix using log1p and threshold at 20.0f All use hexfloat constants for bit-exact reproducibility. --- src/SOFIE_core/CMakeLists.txt | 2 + .../inc/SOFIE/ROperator_BasicUnary.hxx | 12 +++- .../inc/SOFIE/ROperator_HardSigmoid.hxx | 70 +++++++++++++++++++ .../inc/SOFIE/ROperator_HardSwish.hxx | 70 +++++++++++++++++++ src/SOFIE_parsers/CMakeLists.txt | 2 + src/SOFIE_parsers/src/ParseBasicUnary.cxx | 5 ++ src/SOFIE_parsers/src/ParseHardSigmoid.cxx | 47 +++++++++++++ src/SOFIE_parsers/src/ParseHardSwish.cxx | 35 ++++++++++ src/SOFIE_parsers/src/RModelParser_ONNX.cxx | 6 ++ 9 files changed, 247 insertions(+), 2 deletions(-) create mode 100644 src/SOFIE_core/inc/SOFIE/ROperator_HardSigmoid.hxx create mode 100644 src/SOFIE_core/inc/SOFIE/ROperator_HardSwish.hxx create mode 100644 src/SOFIE_parsers/src/ParseHardSigmoid.cxx create mode 100644 src/SOFIE_parsers/src/ParseHardSwish.cxx diff --git a/src/SOFIE_core/CMakeLists.txt b/src/SOFIE_core/CMakeLists.txt index 84a6658..4e6310d 100644 --- a/src/SOFIE_core/CMakeLists.txt +++ b/src/SOFIE_core/CMakeLists.txt @@ -47,6 +47,8 @@ set(source_headers SOFIE/ROperator_Erf.hxx SOFIE/ROperator_Swish.hxx SOFIE/ROperator_Elu.hxx + SOFIE/ROperator_HardSigmoid.hxx + SOFIE/ROperator_HardSwish.hxx SOFIE/ROperator_Comparision.hxx SOFIE/ROperator_EyeLike.hxx SOFIE/ROperator_Range.hxx diff --git a/src/SOFIE_core/inc/SOFIE/ROperator_BasicUnary.hxx b/src/SOFIE_core/inc/SOFIE/ROperator_BasicUnary.hxx index c18c17e..e6d4a25 100644 --- a/src/SOFIE_core/inc/SOFIE/ROperator_BasicUnary.hxx +++ b/src/SOFIE_core/inc/SOFIE/ROperator_BasicUnary.hxx @@ -8,7 +8,7 @@ namespace SOFIE { -enum class EBasicUnaryOperator { kReciprocal, kSqrt , kNeg, kExp, kLog, kSin, kCos, kAbs }; +enum class EBasicUnaryOperator { kReciprocal, kSqrt , kNeg, kExp, kLog, kSin, kCos, kAbs, kSoftplus }; template struct UnaryOpTraits { @@ -62,6 +62,14 @@ struct UnaryOpTraits { static std::string Op(const std::string &X) { return "std::abs(" + X + ")"; } }; +template +struct UnaryOpTraits { + static std::string Name() { return "Softplus"; } + static std::string Op(const std::string &X) { + return "((" + X + " >= 0x1.4000000000000p+4f) ? " + X + " : std::log1p(std::exp(" + X + ")))"; + } +}; + template class ROperator_BasicUnary final : public ROperator { private: @@ -108,7 +116,7 @@ public: } std::vector GetStdLibs() override { - if (Op == EBasicUnaryOperator::kSqrt || Op == EBasicUnaryOperator::kExp || Op == EBasicUnaryOperator::kLog) { + if (Op == EBasicUnaryOperator::kSqrt || Op == EBasicUnaryOperator::kExp || Op == EBasicUnaryOperator::kLog || Op == EBasicUnaryOperator::kSoftplus) { { return { std::string("cmath") }; } else { return {}; diff --git a/src/SOFIE_core/inc/SOFIE/ROperator_HardSigmoid.hxx b/src/SOFIE_core/inc/SOFIE/ROperator_HardSigmoid.hxx new file mode 100644 index 0000000..4df403d --- /dev/null +++ b/src/SOFIE_core/inc/SOFIE/ROperator_HardSigmoid.hxx @@ -0,0 +1,70 @@ +#ifndef SOFIE_ROPERATOR_HARDSIGMOID +#define SOFIE_ROPERATOR_HARDSIGMOID + +#include +#include +#include + +#include + +namespace SOFIE { + +template +class ROperator_HardSigmoid final : public ROperator +{ + +private: + + std::string fNX; + std::string fNY; + std::vector fShape; + float fAlpha; + float fBeta; + +public: + ROperator_HardSigmoid(){} + ROperator_HardSigmoid(std::string nameX, std::string nameY, float alpha, float beta): + fNX(UTILITY::Clean_name(nameX)), fNY(UTILITY::Clean_name(nameY)), fAlpha(alpha), fBeta(beta){ + fInputTensorNames = { fNX }; + fOutputTensorNames = { fNY }; + } + + std::vector TypeInference(std::vector input) override { + return input; + } + + std::vector> ShapeInference(std::vector> input) override { + return input; + } + + void Initialize(RModel& model) override { + if (!model.CheckIfTensorAlreadyExist(fNX)){ + throw std::runtime_error("SOFIE HardSigmoid Op Input Tensor " + fNX + " is not found in model"); + } + fShape = model.GetTensorShape(fNX); + model.AddIntermediateTensor(fNY, model.GetTensorType(fNX), fShape); + } + + std::string Generate(std::string OpName) override { + OpName = "op_" + OpName; + if (fShape.empty()){ + throw std::runtime_error("SOFIE HardSigmoid operator called to Generate without being initialized first"); + } + std::stringstream out; + size_t length = ConvertShapeToLength(fShape); + + // HardSigmoid: y = max(0, min(1, alpha * x + beta)) + out << "\n//------ HardSigmoid\n"; + out << SP << "for (int id = 0; id < " << length << " ; id++){\n"; + out << SP << SP << "tensor_" << fNY << "[id] = std::fmax(0x0p+0f, std::fmin(0x1p+0f, " + << fAlpha << "f * tensor_" << fNX << "[id] + " << fBeta << "f));\n"; + out << SP << "}\n"; + return out.str(); + } + + std::vector GetStdLibs() override { return { std::string("cmath") };} +}; + +} // namespace SOFIE + +#endif \ No newline at end of file diff --git a/src/SOFIE_core/inc/SOFIE/ROperator_HardSwish.hxx b/src/SOFIE_core/inc/SOFIE/ROperator_HardSwish.hxx new file mode 100644 index 0000000..4c83c9a --- /dev/null +++ b/src/SOFIE_core/inc/SOFIE/ROperator_HardSwish.hxx @@ -0,0 +1,70 @@ +#ifndef SOFIE_ROPERATOR_HARDSWISH +#define SOFIE_ROPERATOR_HARDSWISH + +#include +#include +#include + +#include + +namespace SOFIE { + +template +class ROperator_HardSwish final : public ROperator +{ + +private: + + std::string fNX; + std::string fNY; + std::vector fShape; + +public: + ROperator_HardSwish(){} + ROperator_HardSwish(std::string nameX, std::string nameY): + fNX(UTILITY::Clean_name(nameX)), fNY(UTILITY::Clean_name(nameY)){ + fInputTensorNames = { fNX }; + fOutputTensorNames = { fNY }; + } + + std::vector TypeInference(std::vector input) override { + return input; + } + + std::vector> ShapeInference(std::vector> input) override { + return input; + } + + void Initialize(RModel& model) override { + if (!model.CheckIfTensorAlreadyExist(fNX)){ + throw std::runtime_error("SOFIE HardSwish Op Input Tensor " + fNX + " is not found in model"); + } + fShape = model.GetTensorShape(fNX); + model.AddIntermediateTensor(fNY, model.GetTensorType(fNX), fShape); + } + + std::string Generate(std::string OpName) override { + OpName = "op_" + OpName; + if (fShape.empty()){ + throw std::runtime_error("SOFIE HardSwish operator called to Generate without being initialized first"); + } + std::stringstream out; + size_t length = ConvertShapeToLength(fShape); + + // HardSwish: y = x * max(0, min(1, x/6 + 0.5)) + // Split topology for debuggability + out << "\n//------ HardSwish\n"; + out << SP << "for (int id = 0; id < " << length << " ; id++){\n"; + out << SP << SP << "float h = 0x1.5555555555555p-3f * tensor_" << fNX << "[id] + 0x1p-1f;\n"; + out << SP << SP << "tensor_" << fNY << "[id] = tensor_" << fNX + << "[id] * std::fmax(0x0p+0f, std::fmin(0x1p+0f, h));\n"; + out << SP << "}\n"; + return out.str(); + } + + std::vector GetStdLibs() override { return { std::string("cmath") };} +}; + +} // namespace SOFIE + +#endif \ No newline at end of file diff --git a/src/SOFIE_parsers/CMakeLists.txt b/src/SOFIE_parsers/CMakeLists.txt index 379b7d7..8a7d5d3 100644 --- a/src/SOFIE_parsers/CMakeLists.txt +++ b/src/SOFIE_parsers/CMakeLists.txt @@ -62,6 +62,8 @@ set(sources_cxx src/ParseExpand.cxx src/ParseGather.cxx src/ParseElu.cxx + src/ParseHardSigmoid.cxx + src/ParseHardSwish.cxx src/ParseFuseConvAdd.cxx src/ParseFuseConvTransposeAdd.cxx src/ParseFuseGemmRelu.cxx diff --git a/src/SOFIE_parsers/src/ParseBasicUnary.cxx b/src/SOFIE_parsers/src/ParseBasicUnary.cxx index 1470f26..40f5822 100644 --- a/src/SOFIE_parsers/src/ParseBasicUnary.cxx +++ b/src/SOFIE_parsers/src/ParseBasicUnary.cxx @@ -79,5 +79,10 @@ ParserFuncSignature ParseAbs = [](RModelParser_ONNX &parser, const onnx::NodePro return ParseBasicUnary(parser, nodeproto); }; +// Parse Softplus +ParserFuncSignature ParseSoftplus = [](RModelParser_ONNX &parser, const onnx::NodeProto &nodeproto) { + return ParseBasicUnary(parser, nodeproto); +}; + } // namespace SOFIE diff --git a/src/SOFIE_parsers/src/ParseHardSigmoid.cxx b/src/SOFIE_parsers/src/ParseHardSigmoid.cxx new file mode 100644 index 0000000..625e496 --- /dev/null +++ b/src/SOFIE_parsers/src/ParseHardSigmoid.cxx @@ -0,0 +1,47 @@ +#include "SOFIE/RModelParser_ONNX.hxx" +#include "SOFIE/ROperator_HardSigmoid.hxx" +#include "onnx_proto3.pb.h" + +namespace SOFIE { + +ParserFuncSignature ParseHardSigmoid = [](RModelParser_ONNX &parser, const onnx::NodeProto &nodeproto) { + ETensorType input_type; + + // ONNX spec defaults: alpha=0.2, beta=0.5 + float alpha = 0.2f; + float beta = 0.5f; + + for (int_t i = 0; i < nodeproto.attribute_size(); i++) { + std::string attribute_name = nodeproto.attribute(i).name(); + if (attribute_name == "alpha") + alpha = nodeproto.attribute(i).f(); + else if (attribute_name == "beta") + beta = nodeproto.attribute(i).f(); + } + + auto input_name = nodeproto.input(0); + if (parser.IsRegisteredTensorType(input_name)) { + input_type = parser.GetTensorType(input_name); + } else { + throw std::runtime_error("TMVA::SOFIE ONNX Parser HardSigmoid op has input tensor " + input_name + + " but its type is not yet registered"); + } + + std::unique_ptr op; + std::string output_name = nodeproto.output(0); + + switch (input_type) { + case ETensorType::FLOAT: op.reset(new ROperator_HardSigmoid(input_name, output_name, alpha, beta)); break; + default: + throw std::runtime_error("TMVA::SOFIE - Unsupported - Operator HardSigmoid does not yet support input type " + + std::to_string(static_cast(input_type))); + } + + if (!parser.IsRegisteredTensorType(output_name)) { + parser.RegisterTensorType(output_name, input_type); + } + + return op; +}; + +} // namespace SOFIE \ No newline at end of file diff --git a/src/SOFIE_parsers/src/ParseHardSwish.cxx b/src/SOFIE_parsers/src/ParseHardSwish.cxx new file mode 100644 index 0000000..21fc398 --- /dev/null +++ b/src/SOFIE_parsers/src/ParseHardSwish.cxx @@ -0,0 +1,35 @@ +#include "SOFIE/RModelParser_ONNX.hxx" +#include "SOFIE/ROperator_HardSwish.hxx" +#include "onnx_proto3.pb.h" + +namespace SOFIE { + +ParserFuncSignature ParseHardSwish = [](RModelParser_ONNX &parser, const onnx::NodeProto &nodeproto) { + ETensorType input_type; + + auto input_name = nodeproto.input(0); + if (parser.IsRegisteredTensorType(input_name)) { + input_type = parser.GetTensorType(input_name); + } else { + throw std::runtime_error("TMVA::SOFIE ONNX Parser HardSwish op has input tensor " + input_name + + " but its type is not yet registered"); + } + + std::unique_ptr op; + std::string output_name = nodeproto.output(0); + + switch (input_type) { + case ETensorType::FLOAT: op.reset(new ROperator_HardSwish(input_name, output_name)); break; + default: + throw std::runtime_error("TMVA::SOFIE - Unsupported - Operator HardSwish does not yet support input type " + + std::to_string(static_cast(input_type))); + } + + if (!parser.IsRegisteredTensorType(output_name)) { + parser.RegisterTensorType(output_name, input_type); + } + + return op; +}; + +} // namespace SOFIE \ No newline at end of file diff --git a/src/SOFIE_parsers/src/RModelParser_ONNX.cxx b/src/SOFIE_parsers/src/RModelParser_ONNX.cxx index 68662ae..dfbb776 100644 --- a/src/SOFIE_parsers/src/RModelParser_ONNX.cxx +++ b/src/SOFIE_parsers/src/RModelParser_ONNX.cxx @@ -75,6 +75,9 @@ extern ParserFuncSignature ParseLayerNormalization; extern ParserFuncSignature ParseGather; extern ParserFuncSignature ParseErf; extern ParserFuncSignature ParseElu; +extern ParserFuncSignature ParseHardSigmoid; +extern ParserFuncSignature ParseHardSwish; +extern ParserFuncSignature ParseSoftplus; extern ParserFuncSignature ParseEyeLike; extern ParserFuncSignature ParseRange; extern ParserFuncSignature ParseTopK; @@ -219,6 +222,9 @@ RModelParser_ONNX::RModelParser_ONNX() noexcept : fOperatorsMapImpl(std::make_un RegisterOperator("Gather", ParseGather); RegisterOperator("Erf", ParseErf); RegisterOperator("Elu", ParseElu); + RegisterOperator("HardSigmoid", ParseHardSigmoid); + RegisterOperator("HardSwish", ParseHardSwish); + RegisterOperator("Softplus", ParseSoftplus); RegisterOperator("EyeLike", ParseEyeLike); RegisterOperator("Range", ParseRange); RegisterOperator("TopK", ParseTopK);