Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 31 additions & 4 deletions csrc/fusion.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -346,7 +346,18 @@ void Fusion::removeExpr(Expr* expr) {
}
}

ir_container()->removeExpr(expr);
auto* c = ir_container();
auto expr_in_deque = std::find_if(
c->exprs_up_.begin(),
c->exprs_up_.end(),
[expr](std::unique_ptr<Expr>& expr_up) {
return expr_up.get() == expr;
});
NVF_ERROR(
expr_in_deque != c->exprs_up_.end(),
"Wanted to remove an expression but its unique ptr is missing.");
c->exprs_.erase(expr);
c->exprs_up_.erase(expr_in_deque);
}

void Fusion::removeVal(Val* val) {
Expand Down Expand Up @@ -396,7 +407,17 @@ void Fusion::removeVal(Val* val) {
for (auto e : exprs_to_remove) {
removeExpr(e);
}
ir_container()->removeVal(val);

auto* c = ir_container();
auto val_in_deque = std::find_if(
c->vals_up_.begin(),
c->vals_up_.end(),
[val](std::unique_ptr<Val>& val_up) { return val_up.get() == val; });
NVF_ERROR(
val_in_deque != c->vals_up_.end(),
"Wanted to remove a value but its unique ptr is missing.");
c->vals_.erase(val);
c->vals_up_.erase(val_in_deque);

invalidateTvsAndUses();
}
Expand Down Expand Up @@ -910,7 +931,10 @@ void Fusion::registerVal(Val* val) {
val->fusion() == this, val, " was not found in the active fusion.");
}

ir_container()->registerVal(val);
auto* c = ir_container();
c->vals_up_.emplace_back(val);
c->vals_.insert(val);
val->setName(IrContainerPasskey(), c->getValName(val->vtype()));
}

void Fusion::registerExpr(Expr* expr) {
Expand All @@ -923,7 +947,10 @@ void Fusion::registerExpr(Expr* expr) {
expr->fusion() == this, expr, " was not found in the active fusion.");
}

ir_container()->registerExpr(expr);
auto* c = ir_container();
c->exprs_up_.emplace_back(expr);
c->exprs_.insert(expr);
expr->setName(IrContainerPasskey(), c->getExprName());

for (Val* input : expr->inputs()) {
assertInContainer(input, "Input to expr is invalid, ");
Expand Down
60 changes: 0 additions & 60 deletions csrc/ir/container.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -115,66 +115,6 @@ IrContainer::~IrContainer() {
clear();
}

void IrContainer::removeExpr(Expr* expr) {
NVF_ERROR(
exprs_.find(expr) != exprs_.end(),
"Wanted to remove an expression but it doesn't exist in this container.");
auto expr_in_deque = std::find_if(
exprs_up_.begin(),
exprs_up_.end(),
[expr](std::unique_ptr<Expr>& expr_up) { return expr_up.get() == expr; });

NVF_ERROR(
expr_in_deque != exprs_up_.end(),
"Wanted to remove an expression but its unique ptr is missing.");

exprs_.erase(expr);
exprs_up_.erase(expr_in_deque);
}

//! Completely remove val from the fusion, break all dependencies associated
//! with it
void IrContainer::removeVal(Val* val) {
NVF_ERROR(
vals_.find(val) != vals_.end(),
"Wanted to remove a value but it doesn't exist in this container.");
auto val_in_deque = std::find_if(
vals_up_.begin(), vals_up_.end(), [val](std::unique_ptr<Val>& val_up) {
return val_up.get() == val;
});

NVF_ERROR(
val_in_deque != vals_up_.end(),
"Wanted to remove a value but its unique ptr is missing.");

vals_.erase(val);
vals_up_.erase(val_in_deque);
}

//! Register the Val with this container
void IrContainer::registerVal(Val* val) {
if (inContainer(val)) {
return;
}

// Otherwise handle registration locally
vals_up_.emplace_back(val);
vals_.insert(val);
val->setName(IrContainerPasskey(), getValName(val->vtype()));
}

//! Register expr with this container.
void IrContainer::registerExpr(Expr* expr) {
if (inContainer(expr)) {
return;
}

// Otherwise handle registration locally
exprs_up_.emplace_back(expr);
exprs_.insert(expr);
expr->setName(IrContainerPasskey(), getExprName());
}

void IrContainer::clear() noexcept {
FUSER_PERF_SCOPE("IrContainer clear");
vals_.clear();
Expand Down
15 changes: 1 addition & 14 deletions csrc/ir/container.h
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ namespace nvfuser {
// Passkey for container to register names with statements
class IrContainerPasskey {
friend class IrContainer;
friend class Fusion;

private:
explicit IrContainerPasskey() = default;
Expand Down Expand Up @@ -92,18 +93,6 @@ class IrContainer {
// Let Fusion access IrContainer::clear()
friend class Fusion;

void removeExpr(Expr* expr);

//! Completely remove val from the fusion, break all dependencies associated
//! with it
void removeVal(Val* val);

//! Register the Val with this container
NVF_API void registerVal(Val* val);

//! Register expr with this container.
NVF_API void registerExpr(Expr* expr);

StmtNameType getValName(ValType vtype) {
if (val_type_name_map_.find(vtype) == val_type_name_map_.end()) {
val_type_name_map_[vtype] = 0;
Expand All @@ -117,8 +106,6 @@ class IrContainer {

void clear() noexcept;

friend class StatementGuard;

// Deque of unique pointer is the memory owning data structure
std::deque<std::unique_ptr<Val>> vals_up_;

Expand Down
Loading