diff --git a/csrc/fusion.cpp b/csrc/fusion.cpp index 116d6e03e95..1b9b8a8150b 100644 --- a/csrc/fusion.cpp +++ b/csrc/fusion.cpp @@ -346,7 +346,18 @@ void Fusion::removeExpr(Expr* expr) { } } - ir_container()->removeExpr(expr); + auto* c = ir_container(); + auto expr_in_deque = std::find_if( + c->exprs_up_.begin(), + c->exprs_up_.end(), + [expr](std::unique_ptr& expr_up) { + return expr_up.get() == expr; + }); + NVF_ERROR( + expr_in_deque != c->exprs_up_.end(), + "Wanted to remove an expression but its unique ptr is missing."); + c->exprs_.erase(expr); + c->exprs_up_.erase(expr_in_deque); } void Fusion::removeVal(Val* val) { @@ -396,7 +407,17 @@ void Fusion::removeVal(Val* val) { for (auto e : exprs_to_remove) { removeExpr(e); } - ir_container()->removeVal(val); + + auto* c = ir_container(); + auto val_in_deque = std::find_if( + c->vals_up_.begin(), + c->vals_up_.end(), + [val](std::unique_ptr& val_up) { return val_up.get() == val; }); + NVF_ERROR( + val_in_deque != c->vals_up_.end(), + "Wanted to remove a value but its unique ptr is missing."); + c->vals_.erase(val); + c->vals_up_.erase(val_in_deque); invalidateTvsAndUses(); } @@ -910,7 +931,10 @@ void Fusion::registerVal(Val* val) { val->fusion() == this, val, " was not found in the active fusion."); } - ir_container()->registerVal(val); + auto* c = ir_container(); + c->vals_up_.emplace_back(val); + c->vals_.insert(val); + val->setName(IrContainerPasskey(), c->getValName(val->vtype())); } void Fusion::registerExpr(Expr* expr) { @@ -923,7 +947,10 @@ void Fusion::registerExpr(Expr* expr) { expr->fusion() == this, expr, " was not found in the active fusion."); } - ir_container()->registerExpr(expr); + auto* c = ir_container(); + c->exprs_up_.emplace_back(expr); + c->exprs_.insert(expr); + expr->setName(IrContainerPasskey(), c->getExprName()); for (Val* input : expr->inputs()) { assertInContainer(input, "Input to expr is invalid, "); diff --git a/csrc/ir/container.cpp b/csrc/ir/container.cpp index 5cd5f6ca36f..f487a26fbcf 100644 --- a/csrc/ir/container.cpp +++ b/csrc/ir/container.cpp @@ -115,66 +115,6 @@ IrContainer::~IrContainer() { clear(); } -void IrContainer::removeExpr(Expr* expr) { - NVF_ERROR( - exprs_.find(expr) != exprs_.end(), - "Wanted to remove an expression but it doesn't exist in this container."); - auto expr_in_deque = std::find_if( - exprs_up_.begin(), - exprs_up_.end(), - [expr](std::unique_ptr& expr_up) { return expr_up.get() == expr; }); - - NVF_ERROR( - expr_in_deque != exprs_up_.end(), - "Wanted to remove an expression but its unique ptr is missing."); - - exprs_.erase(expr); - exprs_up_.erase(expr_in_deque); -} - -//! Completely remove val from the fusion, break all dependencies associated -//! with it -void IrContainer::removeVal(Val* val) { - NVF_ERROR( - vals_.find(val) != vals_.end(), - "Wanted to remove a value but it doesn't exist in this container."); - auto val_in_deque = std::find_if( - vals_up_.begin(), vals_up_.end(), [val](std::unique_ptr& val_up) { - return val_up.get() == val; - }); - - NVF_ERROR( - val_in_deque != vals_up_.end(), - "Wanted to remove a value but its unique ptr is missing."); - - vals_.erase(val); - vals_up_.erase(val_in_deque); -} - -//! Register the Val with this container -void IrContainer::registerVal(Val* val) { - if (inContainer(val)) { - return; - } - - // Otherwise handle registration locally - vals_up_.emplace_back(val); - vals_.insert(val); - val->setName(IrContainerPasskey(), getValName(val->vtype())); -} - -//! Register expr with this container. -void IrContainer::registerExpr(Expr* expr) { - if (inContainer(expr)) { - return; - } - - // Otherwise handle registration locally - exprs_up_.emplace_back(expr); - exprs_.insert(expr); - expr->setName(IrContainerPasskey(), getExprName()); -} - void IrContainer::clear() noexcept { FUSER_PERF_SCOPE("IrContainer clear"); vals_.clear(); diff --git a/csrc/ir/container.h b/csrc/ir/container.h index 899f2f26439..a61f78be5f2 100644 --- a/csrc/ir/container.h +++ b/csrc/ir/container.h @@ -21,6 +21,7 @@ namespace nvfuser { // Passkey for container to register names with statements class IrContainerPasskey { friend class IrContainer; + friend class Fusion; private: explicit IrContainerPasskey() = default; @@ -92,18 +93,6 @@ class IrContainer { // Let Fusion access IrContainer::clear() friend class Fusion; - void removeExpr(Expr* expr); - - //! Completely remove val from the fusion, break all dependencies associated - //! with it - void removeVal(Val* val); - - //! Register the Val with this container - NVF_API void registerVal(Val* val); - - //! Register expr with this container. - NVF_API void registerExpr(Expr* expr); - StmtNameType getValName(ValType vtype) { if (val_type_name_map_.find(vtype) == val_type_name_map_.end()) { val_type_name_map_[vtype] = 0; @@ -117,8 +106,6 @@ class IrContainer { void clear() noexcept; - friend class StatementGuard; - // Deque of unique pointer is the memory owning data structure std::deque> vals_up_;