diff --git a/csrc/fusion.cpp b/csrc/fusion.cpp index 1b9b8a8150b..1bd29d12ab0 100644 --- a/csrc/fusion.cpp +++ b/csrc/fusion.cpp @@ -116,8 +116,6 @@ void Fusion::swap(Fusion& a, Fusion& b) noexcept { // update the parent backpointers in those containers to point to their new // owners if (a.ir_container_) { - // Also update all Statement ir_container_ pointers to point to new owner - a.ir_container()->parent_ = &a; for (auto val : a.vals()) { val->ir_container_ = &a; } @@ -126,8 +124,6 @@ void Fusion::swap(Fusion& a, Fusion& b) noexcept { } } if (b.ir_container_) { - // Also update all Statement ir_container_ pointers to point to new owner - b.ir_container()->parent_ = &b; for (auto val : b.vals()) { val->ir_container_ = &b; } @@ -161,7 +157,8 @@ std::unique_ptr Fusion::segment( IrCloner Fusion::copy(const Fusion* from, Fusion* to) { to->clear(); - auto ir_cloner = IrContainer::copy(from->ir_container(), to->ir_container()); + auto ir_cloner = + IrContainer::copy(from->ir_container(), to->ir_container(), to); // Remap cached special val pointers through the cloner if (from->zero_val_) { @@ -254,8 +251,8 @@ IrCloner Fusion::copy(const Fusion* from, Fusion* to) { } // Default constructor -Fusion::Fusion() : ir_container_(std::make_unique()) { - ir_container_->parent_ = this; +Fusion::Fusion() : ir_container_(std::make_shared()) { + ir_container_->addFusion(this); } // Copy constructor @@ -287,6 +284,9 @@ Fusion& Fusion::operator=(Fusion&& other) noexcept { Fusion::~Fusion() { clear(); + if (ir_container_) { + ir_container_->removeFusion(this); + } } void Fusion::clear() noexcept { @@ -350,9 +350,7 @@ void Fusion::removeExpr(Expr* expr) { auto expr_in_deque = std::find_if( c->exprs_up_.begin(), c->exprs_up_.end(), - [expr](std::unique_ptr& expr_up) { - return expr_up.get() == expr; - }); + [expr](std::unique_ptr& expr_up) { return expr_up.get() == expr; }); NVF_ERROR( expr_in_deque != c->exprs_up_.end(), "Wanted to remove an expression but its unique ptr is missing."); diff --git a/csrc/fusion.h b/csrc/fusion.h index e77906d3643..4c9115bf996 100644 --- a/csrc/fusion.h +++ b/csrc/fusion.h @@ -148,7 +148,6 @@ class NVF_API Fusion : public PolymorphicBase { typedef std::unordered_map> PermutationMap; protected: - // Direct access to underlying container IrContainer* ir_container() { NVF_ERROR( ir_container_.get() != nullptr, @@ -163,6 +162,10 @@ class NVF_API Fusion : public PolymorphicBase { return ir_container_.get(); } + std::shared_ptr ir_container_ptr() const { + return ir_container_; + } + public: // Registration (public API with passkey) virtual void registerStmt(IrBuilderPasskey, Statement* stmt) { @@ -635,7 +638,7 @@ class NVF_API Fusion : public PolymorphicBase { std::unique_ptr> all_tvs_ptr_ = nullptr; inline static const std::string exact_mappings_key = "exact_mappings"; - std::unique_ptr ir_container_; + std::shared_ptr ir_container_; Val* zero_val_ = nullptr; Val* one_val_ = nullptr; diff --git a/csrc/ir/container.cpp b/csrc/ir/container.cpp index f487a26fbcf..de5f5ded62e 100644 --- a/csrc/ir/container.cpp +++ b/csrc/ir/container.cpp @@ -80,14 +80,15 @@ void IrContainer::swap(IrContainer& a, IrContainer& b) noexcept { std::swap(a.val_type_name_map_, b.val_type_name_map_); std::swap(a.expr_name_counter_, b.expr_name_counter_); - - std::swap(a.parent_, b.parent_); } -IrCloner IrContainer::copy(const IrContainer* from, IrContainer* to) { +IrCloner IrContainer::copy( + const IrContainer* from, + IrContainer* to, + Fusion* dest_fusion) { to->clear(); - IrCloner ir_cloner(to->parent()); + IrCloner ir_cloner(dest_fusion); // Copy values in deterministic order for (auto val : from->deterministic_vals()) { @@ -138,7 +139,7 @@ bool IrContainer::inContainer(const Statement* const_stmt) const { } NVF_ERROR( - const_stmt->container() == this->parent(), + sharing_fusions_.count(const_stmt->container()) > 0, "Container claims to own stmt, but stmt disagrees."); // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast) @@ -157,4 +158,29 @@ bool IrContainer::inContainer(const Statement* const_stmt) const { return true; } +void IrContainer::addFusion(Fusion* fusion) { + sharing_fusions_.insert(fusion); +} + +void IrContainer::removeFusion(Fusion* fusion) { + sharing_fusions_.erase(fusion); +} + +void IrContainer::transferFusion(Fusion* from, Fusion* to) { + sharing_fusions_.erase(from); + sharing_fusions_.insert(to); +} + +size_t IrContainer::sharingCount() const { + return sharing_fusions_.size(); +} + +bool IrContainer::hasMultipleFusions() const { + return sharing_fusions_.size() > 1; +} + +const std::unordered_set& IrContainer::sharingFusions() const { + return sharing_fusions_; +} + } // namespace nvfuser diff --git a/csrc/ir/container.h b/csrc/ir/container.h index a61f78be5f2..e3738be7349 100644 --- a/csrc/ir/container.h +++ b/csrc/ir/container.h @@ -86,7 +86,10 @@ class IrContainer { } protected: - static IrCloner copy(const IrContainer* from, IrContainer* to); + static IrCloner copy( + const IrContainer* from, + IrContainer* to, + Fusion* dest_fusion); static void swap(IrContainer& a, IrContainer& b) noexcept; @@ -127,16 +130,15 @@ class IrContainer { StmtNameType expr_name_counter_ = 0; public: - Fusion* parent() const { - NVF_ERROR( - parent_ != nullptr, "Call to IrContainer::parent() holds nullptr.") - return parent_; - } + void addFusion(Fusion* fusion); + void removeFusion(Fusion* fusion); + void transferFusion(Fusion* from, Fusion* to); + size_t sharingCount() const; + bool hasMultipleFusions() const; + const std::unordered_set& sharingFusions() const; private: - // Parent Fusion that owns this container (for pure composition pattern) - // Used by Statement::fusion() to navigate back to owning Fusion - Fusion* parent_ = nullptr; + std::unordered_set sharing_fusions_; }; } // namespace nvfuser diff --git a/csrc/runtime/fusion_kernel_runtime.cpp b/csrc/runtime/fusion_kernel_runtime.cpp index 6bb73ba9aad..24ecd966cea 100644 --- a/csrc/runtime/fusion_kernel_runtime.cpp +++ b/csrc/runtime/fusion_kernel_runtime.cpp @@ -28,6 +28,9 @@ namespace nvfuser { +// TODO: Remove when std::shared_mutex is added to IrContainer. +constexpr bool kPhase2DisableParallelCompile = true; + namespace { // Replace CUDA tensor with Meta tensor because storing tensors can cause // out-of-memory issues. Other arguments are returned as-is. @@ -454,7 +457,8 @@ void FusionKernelRuntime::compileFusionParallel(KernelArgumentHolder args) { try { for (const auto& [group_to_run, group_runtime_inputs] : zip(runtime_workspace_.group_run_order, all_runtime_inputs)) { - if (num_groups == 1 || isOptionDisabled(DisableOption::ParallelCompile)) { + if (num_groups == 1 || kPhase2DisableParallelCompile || + isOptionDisabled(DisableOption::ParallelCompile)) { compileKernel(group_runtime_inputs, group_to_run); } else { // launch compileKernel thread here @@ -488,7 +492,8 @@ void FusionKernelRuntime::compileFusionParallel(KernelArgumentHolder args) { throw; } - if (num_groups != 1 && !isOptionDisabled(DisableOption::ParallelCompile)) { + if (num_groups != 1 && !kPhase2DisableParallelCompile && + !isOptionDisabled(DisableOption::ParallelCompile)) { // Wait until all segments finish compiling getThreadPool()->waitWorkComplete(); NVF_ERROR(