diff --git a/csrc/fusion.cpp b/csrc/fusion.cpp index baf1de84614..1303ef97076 100644 --- a/csrc/fusion.cpp +++ b/csrc/fusion.cpp @@ -7,6 +7,7 @@ // clang-format on #include +#include #include #include @@ -19,7 +20,9 @@ #include #include #include +#include #include +#include #include #include #include @@ -137,6 +140,16 @@ void Fusion::swap(Fusion& a, Fusion& b) noexcept { std::swap(a.outputs_, b.outputs_); std::swap(a.io_alias_, b.io_alias_); + + // Swap per-Fusion special values (Phase 2) + std::swap(a.zero_val_, b.zero_val_); + std::swap(a.one_val_, b.one_val_); + std::swap(a.true_val_, b.true_val_); + std::swap(a.false_val_, b.false_val_); + std::swap(a.magic_zero_val_, b.magic_zero_val_); + + std::swap(a.axioms_, b.axioms_); + std::swap(a.metadata_, b.metadata_); } std::unique_ptr Fusion::segment( @@ -198,6 +211,19 @@ IrCloner Fusion::copy(const Fusion* from, Fusion* to) { to->expected_dynamic_smem_bytes_ = from->expected_dynamic_smem_bytes_; + if (from->axioms_ != nullptr) { + to->axioms_ = std::make_unique>(); + to->axioms_->reserve(from->axioms_->size()); + for (auto pred : *from->axioms_) { + to->axioms_->push_back(ir_cloner.clone(pred)); + } + } + + for (auto& [key, val_expr] : from->metadata_) { + to->metadata_[ir_cloner.clone(key)] = std::make_pair( + ir_cloner.clone(val_expr.first), ir_cloner.clone(val_expr.second)); + } + if (from->all_tvs_ptr_ != nullptr) { to->all_tvs_ptr_ = std::make_unique>(); to->all_tvs_ptr_->reserve(from->all_tvs_ptr_->size()); @@ -264,6 +290,18 @@ void Fusion::clear() noexcept { managed_data_.clear(); managed_named_data_.clear(); + // Reset per-Fusion special values (they'll be recreated lazily if needed). + // These unique_ptrs own the Val objects; ir_container()->clear() above only + // removed them from vals_ (they were already absent from vals_up_). + zero_val_.reset(); + one_val_.reset(); + true_val_.reset(); + false_val_.reset(); + magic_zero_val_.reset(); + + axioms_.reset(); + metadata_.clear(); + invalidateTvsAndUses(); is_during_update_uses_ = false; @@ -297,6 +335,13 @@ void Fusion::removeExpr(Expr* expr) { void Fusion::removeVal(Val* val) { assertInContainer(val, "Cannot remove val "); + // Don't remove cached special vals — they are lazily created singletons + if (val == zero_val_.get() || val == one_val_.get() || + val == true_val_.get() || val == false_val_.get() || + val == magic_zero_val_.get()) { + return; + } + NVF_CHECK( !val->isFusionInput(), "Cannot remove val as it is an input of the fusion."); @@ -689,6 +734,122 @@ void Fusion::printTransforms() { t_exprs.handle(this); } +Val* Fusion::zeroVal() { + if (!zero_val_) { + auto val = IrBuilder::createInContainer(this, 0L, DataType::Index); + NVF_ERROR(ir_container()->vals_up_.back().get() == val); + zero_val_ = std::unique_ptr(ir_container()->vals_up_.back().release()); + ir_container()->vals_up_.pop_back(); + } + return zero_val_.get(); +} + +Val* Fusion::oneVal() { + if (!one_val_) { + auto val = IrBuilder::createInContainer(this, 1L, DataType::Index); + NVF_ERROR(ir_container()->vals_up_.back().get() == val); + one_val_ = std::unique_ptr(ir_container()->vals_up_.back().release()); + ir_container()->vals_up_.pop_back(); + } + return one_val_.get(); +} + +Val* Fusion::falseVal() { + if (!false_val_) { + auto val = IrBuilder::createInContainer(this, false, DataType::Bool); + NVF_ERROR(ir_container()->vals_up_.back().get() == val); + false_val_ = + std::unique_ptr(ir_container()->vals_up_.back().release()); + ir_container()->vals_up_.pop_back(); + } + return false_val_.get(); +} + +Val* Fusion::trueVal() { + if (!true_val_) { + auto val = IrBuilder::createInContainer(this, true, DataType::Bool); + NVF_ERROR(ir_container()->vals_up_.back().get() == val); + true_val_ = std::unique_ptr(ir_container()->vals_up_.back().release()); + ir_container()->vals_up_.pop_back(); + } + return true_val_.get(); +} + +NamedScalar* Fusion::magicZeroVal() { + if (!magic_zero_val_) { + auto val = IrBuilder::createInContainer( + this, kMagicZeroName, DataType::Index); + NVF_ERROR(ir_container()->vals_up_.back().get() == val); + magic_zero_val_ = std::unique_ptr( + ir_container()->vals_up_.back().release()->as()); + ir_container()->vals_up_.pop_back(); + } + return magic_zero_val_.get(); +} + +Val* Fusion::zeroVal(DataType dtype) { + if (dtype == DataType::Index) { + return zeroVal(); + } else if (isBooleanType(dtype)) { + return falseVal(); + } else { + // NOTE: this does not cache values + return IrBuilder::createInContainer(this, 0L, dtype); + } +} + +Val* Fusion::oneVal(DataType dtype) { + if (dtype == DataType::Index) { + return oneVal(); + } else if (isBooleanType(dtype)) { + return trueVal(); + } else { + // NOTE: this does not cache values + return IrBuilder::createInContainer(this, 1L, dtype); + } +} + +Val* Fusion::metadataOf(Val* v) { + if (metadata_.count(v) == 0) { + auto metadata_val = + IrBuilder::createInContainer(this, metaDataTypeOf(v)); + auto metadata_expr = + IrBuilder::createInContainer(this, metadata_val, v); + metadata_[v] = std::make_pair(metadata_val, metadata_expr); + } + return metadata_.at(v).first; +} + +const std::vector& Fusion::axioms() { + if (!axioms_) { + axioms_ = std::make_unique>(); + axioms_->reserve(kParallelTypeThreads.size() * 3); + auto zero = zeroVal(); + for (auto p : kParallelTypeThreads) { + auto pidx = NamedScalar::getParallelIndex(p); + auto pdim = NamedScalar::getParallelDim(p); + axioms_->push_back(SimplifyingIrBuilder::geExpr(pidx, zero)); + axioms_->push_back(SimplifyingIrBuilder::gtExpr(pdim, zero)); + axioms_->push_back(SimplifyingIrBuilder::ltExpr(pidx, pdim)); + } + } + return *axioms_; +} + +void Fusion::assumePositive(Val* val) { + NVF_ERROR(inContainer(val)); + // Lazy init axioms, then add the assumption + axioms(); + axioms_->emplace_back(IrBuilder::gtExpr(val, zeroVal())); +} + +void Fusion::assumeNonNegative(Val* val) { + NVF_ERROR(inContainer(val)); + // Lazy init axioms, then add the assumption + axioms(); + axioms_->emplace_back(IrBuilder::geExpr(val, zeroVal())); +} + void Fusion::registerVal(Val* val) { if (inContainer(val)) { return; diff --git a/csrc/fusion.h b/csrc/fusion.h index f02c1b0310d..ac3b3ae0686 100644 --- a/csrc/fusion.h +++ b/csrc/fusion.h @@ -59,6 +59,7 @@ namespace nvfuser { //! checks. class Fusion; +class NamedScalar; class TensorView; class SegmentCandidateFinder; @@ -549,55 +550,28 @@ class NVF_API Fusion : public PolymorphicBase { return ir_container()->numExprs(); } - int64_t numVals(bool include_shortcuts) const noexcept { + // When include_shortcuts is true, count cached special vals (zeroVal, etc.) + // which live outside vals_up_ but inside vals_. + int64_t numVals(bool include_shortcuts = true) const noexcept { return ir_container()->numVals(include_shortcuts); } // Shortcut values (frequently used constants) - Val* zeroVal() { - return ir_container()->zeroVal(); - } - - Val* oneVal() { - return ir_container()->oneVal(); - } + Val* zeroVal(); + Val* oneVal(); + Val* falseVal(); + Val* trueVal(); + NamedScalar* magicZeroVal(); + Val* zeroVal(DataType dtype); + Val* oneVal(DataType dtype); - Val* falseVal() { - return ir_container()->falseVal(); - } - - Val* trueVal() { - return ir_container()->trueVal(); - } - - NamedScalar* magicZeroVal() { - return ir_container()->magicZeroVal(); - } - - Val* zeroVal(DataType dtype) { - return ir_container()->zeroVal(dtype); - } - - Val* oneVal(DataType dtype) { - return ir_container()->oneVal(dtype); - } - - Val* metadataOf(Val* val) { - return ir_container()->metadataOf(val); - } + Val* metadataOf(Val* val); // Axioms (CUDA programming assumptions) - const std::vector& axioms() { - return ir_container()->axioms(); - } - - void assumePositive(Val* val) { - ir_container()->assumePositive(val); - } + const std::vector& axioms(); - void assumeNonNegative(Val* val) { - ir_container()->assumeNonNegative(val); - } + void assumePositive(Val* val); + void assumeNonNegative(Val* val); // Statement removal void removeStatementsCreatedAfter( @@ -667,6 +641,16 @@ class NVF_API Fusion : public PolymorphicBase { inline static const std::string exact_mappings_key = "exact_mappings"; std::unique_ptr ir_container_; + + std::unique_ptr zero_val_; + std::unique_ptr one_val_; + std::unique_ptr true_val_; + std::unique_ptr false_val_; + std::unique_ptr magic_zero_val_; + + std::unique_ptr> axioms_; + + std::unordered_map> metadata_; }; // Template implementations for Fusion::manage() that use IrCloner diff --git a/csrc/ir/container.cpp b/csrc/ir/container.cpp index 3c54966c87d..b50aff8a851 100644 --- a/csrc/ir/container.cpp +++ b/csrc/ir/container.cpp @@ -7,6 +7,7 @@ // clang-format on #include "ir/container.h" +#include "fusion.h" #include "instrumentation.h" #include "ir/base_nodes.h" #include "ir/builder.h" @@ -80,20 +81,12 @@ void IrContainer::swap(IrContainer& a, IrContainer& b) noexcept { std::swap(a.val_type_name_map_, b.val_type_name_map_); std::swap(a.expr_name_counter_, b.expr_name_counter_); - std::swap(a.metadata_, b.metadata_); - std::swap(a.parent_, b.parent_); - - std::swap(a.zero_val_, b.zero_val_); - std::swap(a.one_val_, b.one_val_); - std::swap(a.true_val_, b.true_val_); - std::swap(a.false_val_, b.false_val_); - std::swap(a.magic_zero_val_, b.magic_zero_val_); - std::swap(a.axioms_, b.axioms_); } IrCloner IrContainer::copy(const IrContainer* from, IrContainer* to) { to->clear(); + IrCloner ir_cloner(to->parent()); // Copy values in deterministic order @@ -115,15 +108,6 @@ IrCloner IrContainer::copy(const IrContainer* from, IrContainer* to) { to->val_type_name_map_ = from->val_type_name_map_; to->expr_name_counter_ = from->expr_name_counter_; - if (from->axioms_ != nullptr) { - to->axioms_ = std::make_unique>(); - for (auto pred : *from->axioms_) { - to->axioms_->push_back(ir_cloner.clone(pred)); - } - } - - to->metadata_ = ir_cloner.clone(from->metadata_); - return ir_cloner; } @@ -153,13 +137,6 @@ void IrContainer::removeExpr(Expr* expr) { //! Completely remove val from the fusion, break all dependencies associated //! with it void IrContainer::removeVal(Val* val) { - // Don't remove shortcuts - if (val == true_val_.get() || val == false_val_.get() || - val == one_val_.get() || val == zero_val_.get() || - val == magic_zero_val_.get()) { - return; - } - NVF_ERROR( vals_.find(val) != vals_.end(), "Wanted to remove a value but it doesn't exist in this container."); @@ -206,9 +183,7 @@ void IrContainer::clear() noexcept { vals_up_.clear(); exprs_.clear(); exprs_up_.clear(); - axioms_.reset(); val_type_name_map_.clear(); - metadata_.clear(); expr_name_counter_ = 0; } @@ -244,123 +219,6 @@ bool IrContainer::inContainer(const Statement* const_stmt) const { return true; } -// Shortcuts for frequently used vals -Val* IrContainer::zeroVal() { - if (!zero_val_) { - auto zero_val = - IrBuilder::createInContainer(this->parent(), 0L, DataType::Index); - NVF_ERROR(vals_up_.back().get() == zero_val); - zero_val_ = std::unique_ptr(vals_up_.back().release()); - vals_up_.pop_back(); - } - return zero_val_.get(); -} - -Val* IrContainer::zeroVal(DataType dtype) { - if (dtype == DataType::Index) { - return zeroVal(); - } else if (isBooleanType(dtype)) { - return falseVal(); - } else { - // NOTE: this does not cache values - return IrBuilder::createInContainer(this->parent(), 0L, dtype); - } -} - -Val* IrContainer::oneVal() { - if (!one_val_) { - auto one_val = - IrBuilder::createInContainer(this->parent(), 1L, DataType::Index); - NVF_ERROR(vals_up_.back().get() == one_val); - one_val_ = std::unique_ptr(vals_up_.back().release()); - vals_up_.pop_back(); - } - return one_val_.get(); -} - -Val* IrContainer::oneVal(DataType dtype) { - if (dtype == DataType::Index) { - return oneVal(); - } else if (isBooleanType(dtype)) { - return trueVal(); - } else { - // NOTE: this does not cache values - return IrBuilder::createInContainer(this->parent(), 1L, dtype); - } -} - -Val* IrContainer::falseVal() { - if (!false_val_) { - auto false_val = IrBuilder::createInContainer( - this->parent(), false, DataType::Bool); - NVF_ERROR(vals_up_.back().get() == false_val); - false_val_ = std::unique_ptr(vals_up_.back().release()); - vals_up_.pop_back(); - } - return false_val_.get(); -} - -Val* IrContainer::trueVal() { - if (!true_val_) { - auto true_val = - IrBuilder::createInContainer(this->parent(), true, DataType::Bool); - NVF_ERROR(vals_up_.back().get() == true_val); - true_val_ = std::unique_ptr(vals_up_.back().release()); - vals_up_.pop_back(); - } - return true_val_.get(); -} - -NamedScalar* IrContainer::magicZeroVal() { - if (!magic_zero_val_) { - auto magic_zero = - IrBuilder::create(kMagicZeroName, DataType::Index); - NVF_ERROR(vals_up_.back().get() == magic_zero); - magic_zero_val_ = std::unique_ptr( - vals_up_.back().release()->as()); - vals_up_.pop_back(); - } - return magic_zero_val_.get(); -} - -Val* IrContainer::metadataOf(Val* v) { - if (metadata_.count(v) == 0) { - auto metadata_val = - IrBuilder::createInContainer(this->parent(), metaDataTypeOf(v)); - auto metadata_expr = IrBuilder::createInContainer( - this->parent(), metadata_val, v); - metadata_[v] = std::make_pair(metadata_val, metadata_expr); - } - return metadata_.at(v).first; -} - -void IrContainer::lazyInitAxioms() { - if (!axioms_) { - axioms_ = std::make_unique>(); - axioms_->reserve(kParallelTypeThreads.size() * 3); - auto zero = zeroVal(); - for (auto p : kParallelTypeThreads) { - auto pidx = NamedScalar::getParallelIndex(p); - auto pdim = NamedScalar::getParallelDim(p); - axioms_->push_back(SimplifyingIrBuilder::geExpr(pidx, zero)); - axioms_->push_back(SimplifyingIrBuilder::gtExpr(pdim, zero)); - axioms_->push_back(SimplifyingIrBuilder::ltExpr(pidx, pdim)); - } - } -} - -void IrContainer::assumePositive(Val* val) { - NVF_ERROR(val->container() == this->parent()); - lazyInitAxioms(); - axioms_->emplace_back(IrBuilder::gtExpr(val, zeroVal())); -} - -void IrContainer::assumeNonNegative(Val* val) { - NVF_ERROR(val->container() == this->parent()); - lazyInitAxioms(); - axioms_->emplace_back(IrBuilder::geExpr(val, zeroVal())); -} - void IrContainer::removeStatementsCreatedAfter( int64_t prev_num_exprs, int64_t prev_num_vals) { diff --git a/csrc/ir/container.h b/csrc/ir/container.h index e361b8743ee..0ca291ea4af 100644 --- a/csrc/ir/container.h +++ b/csrc/ir/container.h @@ -80,30 +80,12 @@ class IrContainer { return std::ssize(exprs_); } - // When include_shortcuts is true, it will count the shortcuts like true_val_. + // When include_shortcuts is true, count cached special vals (zeroVal, etc.) + // whose ownership was transferred to Fusion but that still appear in vals_. int64_t numVals(bool include_shortcuts) const noexcept { return include_shortcuts ? std::ssize(vals_) : std::ssize(vals_up_); } - // Shortcuts for frequently used vals - NVF_API Val* zeroVal(); - NVF_API Val* oneVal(); - Val* falseVal(); - Val* trueVal(); - NamedScalar* magicZeroVal(); - NVF_API Val* zeroVal(DataType dtype); - NVF_API Val* oneVal(DataType dtype); - Val* metadataOf(Val*); - - // Axioms about CUDA programming, for example: threadIdx.x < blockDim.x - const std::vector& axioms() { - lazyInitAxioms(); - return *axioms_; - } - - void assumePositive(Val* val); - void assumeNonNegative(Val* val); - protected: static IrCloner copy(const IrContainer* from, IrContainer* to); @@ -137,8 +119,6 @@ class IrContainer { void clear() noexcept; - void lazyInitAxioms(); - friend class StatementGuard; // A simple garbage collection mechanism to remove all Exprs and Vals that @@ -171,22 +151,6 @@ class IrContainer { // Expression names counter StmtNameType expr_name_counter_ = 0; - // Manually store some persistent, frequently used nodes. It's very - // challenging to do this anything but manually as detecting when a container - // may or may not have one of these vals is tricky. Specifically because if - // the container doesn't own it, it's hard to understand from the outside if - // the node may have been removed then re-registered. It could also be tricky - // to know when we're using a different container as in FusionCopy_test - // demonstrates deleting then creating containers can result in the same - // pointer for the container. - std::unique_ptr true_val_; - std::unique_ptr false_val_; - std::unique_ptr one_val_; - std::unique_ptr zero_val_; - std::unique_ptr magic_zero_val_; - std::unique_ptr> axioms_; - std::unordered_map> metadata_; - public: Fusion* parent() const { NVF_ERROR(