diff --git a/CMakeLists.txt b/CMakeLists.txt index d007b40c90d..c82f5609467 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -912,6 +912,7 @@ list(APPEND JIT_TEST_SRCS ${NVFUSER_ROOT}/tests/cpp/test_overlap.cpp ${NVFUSER_ROOT}/tests/cpp/test_pdl.cpp ${NVFUSER_ROOT}/tests/cpp/test_persistent_buffer.cpp + ${NVFUSER_ROOT}/tests/cpp/test_phase2_container_sharing.cpp ${NVFUSER_ROOT}/tests/cpp/test_pointwise.cpp ${NVFUSER_ROOT}/tests/cpp/test_polymorphic_value.cpp ${NVFUSER_ROOT}/tests/cpp/test_predicate_elimination.cpp diff --git a/csrc/fusion.cpp b/csrc/fusion.cpp index baf1de84614..291cfbf70fb 100644 --- a/csrc/fusion.cpp +++ b/csrc/fusion.cpp @@ -7,6 +7,7 @@ // clang-format on #include +#include #include #include @@ -19,7 +20,9 @@ #include #include #include +#include #include +#include #include #include #include @@ -104,39 +107,95 @@ bool Fusion::sameDefinition(const Fusion& other) const { void Fusion::swap(Fusion& a, Fusion& b) noexcept { FUSER_PERF_SCOPE("Fusion swap"); - // We need to be careful to call IrContainer swap not unique_ptr swap, which - // will only swap the ptrs NOT the contents. - IrContainer::swap(*(a.ir_container()), *(b.ir_container())); + if (&a == &b) { + return; + } + + // Phase 2: Pointer-based swap with ownership-filtered Statement updates + // This is more efficient than content swap and correctly handles shared + // containers by only updating statements owned by the swapped Fusions. 
+ + // Step 1: Collect statements owned by each Fusion BEFORE swap + // We need to copy to vectors because we'll be modifying ownership + std::vector a_owned_vals, b_owned_vals; + std::vector a_owned_exprs, b_owned_exprs; - // Fix parent pointers after swapping containers - // After swap, each Fusion owns a different IrContainer, so we must - // update the parent backpointers in those containers to point to their new - // owners if (a.ir_container_) { - // Also update all Statement ir_container_ pointers to point to new owner - a.ir_container()->parent_ = &a; - for (auto val : a.vals()) { - val->ir_container_ = &a; - } - for (auto expr : a.deterministic_exprs()) { - expr->ir_container_ = &a; - } + const auto& a_vals = a.ir_container_->valsOwnedBy(&a); + const auto& a_exprs = a.ir_container_->exprsOwnedBy(&a); + a_owned_vals.assign(a_vals.begin(), a_vals.end()); + a_owned_exprs.assign(a_exprs.begin(), a_exprs.end()); } + if (b.ir_container_) { - // Also update all Statement ir_container_ pointers to point to new owner - b.ir_container()->parent_ = &b; - for (auto val : b.vals()) { - val->ir_container_ = &b; - } - for (auto expr : b.deterministic_exprs()) { - expr->ir_container_ = &b; - } + const auto& b_vals = b.ir_container_->valsOwnedBy(&b); + const auto& b_exprs = b.ir_container_->exprsOwnedBy(&b); + b_owned_vals.assign(b_vals.begin(), b_vals.end()); + b_owned_exprs.assign(b_exprs.begin(), b_exprs.end()); } + // Step 2: Handle registration transfers (before pointer swap) + // After swap, a should be registered with its new container (was b's) + // and b should be registered with its new container (was a's) + if (a.ir_container_ && b.ir_container_ && + a.ir_container_.get() != b.ir_container_.get()) { + // Different containers: swap registrations + // In a's container: remove a, add b (because b will own this container) + // In b's container: remove b, add a (because a will own this container) + a.ir_container_->transferFusion(&a, &b); + 
b.ir_container_->transferFusion(&b, &a); + } + + // Step 3: Swap container pointers (not content swap!) + std::swap(a.ir_container_, b.ir_container_); + + // Re-point stale parent_ back-pointers. The pre-Phase-2 swap did this + // unconditionally; without it a container keeps naming the Fusion it was + // just taken from, which dangles once that Fusion is destroyed (e.g. the + // temporary in a move assignment). A shared container is left alone + // because both Fusions still hold it. + if (a.ir_container_ != b.ir_container_) { + if (a.ir_container_ && a.ir_container_->parent_ == &b) { + a.ir_container_->parent_ = &a; + } + if (b.ir_container_ && b.ir_container_->parent_ == &a) { + b.ir_container_->parent_ = &b; + } + } + + // Step 4: Swap all Fusion-level members std::swap(a.inputs_, b.inputs_); std::swap(a.outputs_, b.outputs_); + // io_alias_ is keyed on the output/input Vals swapped above, so it has to + // travel with them. std::swap(a.io_alias_, b.io_alias_); + std::swap(a.all_tv_uses_valid_, b.all_tv_uses_valid_); + std::swap(a.is_during_update_uses_, b.is_during_update_uses_); + std::swap(a.managed_data_, b.managed_data_); + std::swap(a.managed_named_data_, b.managed_named_data_); + std::swap(a.expected_dynamic_smem_bytes_, b.expected_dynamic_smem_bytes_); + std::swap(a.all_tvs_ptr_, b.all_tvs_ptr_); + + // Swap per-Fusion special values (Phase 2) + std::swap(a.zero_val_, b.zero_val_); + std::swap(a.one_val_, b.one_val_); + std::swap(a.true_val_, b.true_val_); + std::swap(a.false_val_, b.false_val_); + std::swap(a.magic_zero_val_, b.magic_zero_val_); + + // Swap per-Fusion axioms and metadata (Phase 2) + std::swap(a.axioms_, b.axioms_); + std::swap(a.metadata_, b.metadata_); + + // Step 5: Update statement ownership + // Statements that belonged to a now belong to b (they go with their data) + // Statements that belonged to b now belong to a + for (auto* val : a_owned_vals) { + val->ir_container_ = &b; + } + for (auto* expr : a_owned_exprs) { + expr->ir_container_ = &b; + } + for (auto* val : b_owned_vals) { + val->ir_container_ = &a; + } + for (auto* expr : b_owned_exprs) { + expr->ir_container_ = &a; + } + + // Step 6: Update per-Fusion tracking in containers + // After swap: a's new container needs to track a's new statements (was b's) + // b's new container needs to track b's new statements (was a's) + if (a.ir_container_ && a.ir_container_ == b.ir_container_) { + // Shared container: two back-to-back transfers would first merge b's + // entries into a's and then hand the merged set to b, leaving a with + // nothing. Exchange the two entries through a temporary nullptr key + // instead (assumes no statements are tracked under nullptr). + a.ir_container_->transferStatementOwnership(&a, nullptr); + a.ir_container_->transferStatementOwnership(&b, &a); + a.ir_container_->transferStatementOwnership(nullptr, &b); + } else { + if (a.ir_container_) { + a.ir_container_->transferStatementOwnership(&b, &a); + } + if (b.ir_container_) { + b.ir_container_->transferStatementOwnership(&a, &b); + } + } } std::unique_ptr Fusion::segment( @@ -146,25 +205,46 @@ std::unique_ptr Fusion::segment( } IrCloner Fusion::copy(const Fusion* from, Fusion* to) { + 
FUSER_PERF_SCOPE("Fusion copy"); + + // Phase 2: Clear destination's state only (not entire container) + // to->clear() removes only 'to's statements from the shared container to->clear(); - auto ir_cloner = IrContainer::copy(from->ir_container(), to->ir_container()); + // Phase 2: Create IrCloner targeting 'to' Fusion directly + // IrCloner sets cloned nodes' ir_container_ to 'to' + // This works with shared containers - clones go into the shared container + // but are tracked as owned by 'to' via per-Fusion tracking + IrCloner ir_cloner(to); + + // Phase 2: Clone only 'from's owned vals (not all vals in shared container) + // CRITICAL: Use deterministic_vals() to get vals in insertion order. + // Using ownedVals() (unordered_set) causes non-deterministic clone order, + // which assigns different name() values to cloned vals between runs. + // This breaks code that uses tv->name() as map keys (e.g., GreedyParams). + for (auto val : from->deterministic_vals()) { + ir_cloner.clone(val); + } - for (auto val : from->vals()) { + // Update definition_ and uses_ on cloned vals + for (auto val : from->deterministic_vals()) { ir_cloner.clone(val)->setDefinition(ir_cloner.clone(val->definition_)); ir_cloner.clone(val)->setUses(ir_cloner.clone(val->uses_)); } + // Clone fusion inputs to->inputs_ = ir_cloner.clone(from->inputs_); - to->outputs_ = ir_cloner.clone(from->outputs_); for (auto inp : to->inputs_) { inp->setIsFusionInput(true); } + + // Clone fusion outputs + to->outputs_ = ir_cloner.clone(from->outputs_); for (auto out : to->outputs_) { out->setIsFusionOutput(true); } - // TODO: put this into ir_cloner instead + // Clone io_alias mappings for (Val* out : from->outputs_) { const AliasInfo& alias = from->io_alias_.get(out); if (alias.type == AllocationType::New) { @@ -176,10 +256,12 @@ IrCloner Fusion::copy(const Fusion* from, Fusion* to) { to->io_alias_.add(copied_out, copied_in, alias.type, alias.visibility); } + // Copy other Fusion-level state 
to->all_tv_uses_valid_ = from->all_tv_uses_valid_; // This should never be true on copy, but copying for completeness. to->is_during_update_uses_ = from->is_during_update_uses_; + // Clone managed data for (const auto& i : from->managed_data_) { if (i.first.has_value()) { to->managed_data_.emplace_back(i.second(ir_cloner, i.first), i.second); @@ -198,6 +280,7 @@ IrCloner Fusion::copy(const Fusion* from, Fusion* to) { to->expected_dynamic_smem_bytes_ = from->expected_dynamic_smem_bytes_; + // Clone per-Fusion axioms and metadata. IrContainer::copy used to do + // this; now that both live on the Fusion they must be cloned here, + // otherwise the copy silently drops every assumePositive / + // assumeNonNegative assumption and the metadata cache. to->clear() above + // already reset both members on the destination. + if (from->axioms_ != nullptr) { + to->axioms_ = std::make_unique<std::vector<Val*>>(); + to->axioms_->reserve(from->axioms_->size()); + for (auto pred : *from->axioms_) { + to->axioms_->push_back(ir_cloner.clone(pred)); + } + } + to->metadata_ = ir_cloner.clone(from->metadata_); + + // Clone cached TV list if present if (from->all_tvs_ptr_ != nullptr) { to->all_tvs_ptr_ = std::make_unique>(); to->all_tvs_ptr_->reserve(from->all_tvs_ptr_->size()); @@ -210,13 +293,24 @@ IrCloner Fusion::copy(const Fusion* from, Fusion* to) { } // Default constructor -Fusion::Fusion() : ir_container_(std::make_unique()) { +Fusion::Fusion() : ir_container_(std::make_shared()) { ir_container_->parent_ = this; + ir_container_->addFusion(this); // Register with container for Phase 2 } -// Copy constructor -Fusion::Fusion(const Fusion& other) : Fusion() { - FUSER_PERF_SCOPE("Fusion copy"); +// Copy constructor - Phase 2: Share container with source +Fusion::Fusion(const Fusion& other) + : ir_container_(other.ir_container_) { // Share container pointer + FUSER_PERF_SCOPE("Fusion copy ctor"); + + // Note: Special values are per-Fusion (Task 7), initialized to nullptr. + // They will be created lazily when accessed on the copy. + // Do NOT copy other.zero_val_ etc - each Fusion has its own. 
+ + // Register with shared container + ir_container_->addFusion(this); + + // Delegate to static copy method to clone nodes Fusion::copy(&other, this); } @@ -236,12 +330,18 @@ Fusion& Fusion::operator=(const Fusion& other) { Fusion& Fusion::operator=(Fusion&& other) noexcept { FUSER_PERF_SCOPE("Fusion move assign"); + if (this == &other) { + return *this; + } clear(); swap(*this, other); return *this; } Fusion::~Fusion() { + if (ir_container_) { + ir_container_->removeFusion(this); // Unregister before destruction + } clear(); } @@ -251,11 +351,17 @@ void Fusion::clear() noexcept { // constructor of Trace, which could throw an exception. // FUSER_PERF_SCOPE("Fusion clear"); - // Clear container contents instead of destroying it - // This preserves the container object so Statement pointers don't become - // dangling - ir_container()->clear(); + // Phase 2 (Task 4): Only clear THIS Fusion's statements, not the entire + // container. With shared containers, calling ir_container()->clear() would + // break other Fusions sharing the container. + // + // Phase 1: ir_container()->clear() was equivalent to this (1:1 relationship) + // Phase 2: Must filter by ownership to avoid affecting other Fusions + if (ir_container_) { + ir_container_->removeStatementsOwnedBy(this); + } + // Clear Fusion-level state (these are per-Fusion, not in the container) inputs_.clear(); outputs_.clear(); @@ -264,6 +370,18 @@ void Fusion::clear() noexcept { managed_data_.clear(); managed_named_data_.clear(); + // Reset per-Fusion special values (they'll be recreated lazily if needed) + // The actual Val objects were removed by removeStatementsOwnedBy above. 
+ zero_val_ = nullptr; + one_val_ = nullptr; + true_val_ = nullptr; + false_val_ = nullptr; + magic_zero_val_ = nullptr; + + // Reset per-Fusion axioms and metadata (Phase 2) + axioms_.reset(); + metadata_.clear(); + invalidateTvsAndUses(); is_during_update_uses_ = false; @@ -689,6 +807,117 @@ void Fusion::printTransforms() { t_exprs.handle(this); } +// ========================================================================= +// Per-Fusion Special Values (Phase 2) +// Each Fusion has its own special values for safe container sharing. +// ========================================================================= + +Val* Fusion::zeroVal() { + if (!zero_val_) { + zero_val_ = IrBuilder::createInContainer(this, 0L, DataType::Index); + } + return zero_val_; +} + +Val* Fusion::oneVal() { + if (!one_val_) { + one_val_ = IrBuilder::createInContainer(this, 1L, DataType::Index); + } + return one_val_; +} + +Val* Fusion::falseVal() { + if (!false_val_) { + false_val_ = IrBuilder::createInContainer(this, false, DataType::Bool); + } + return false_val_; +} + +Val* Fusion::trueVal() { + if (!true_val_) { + true_val_ = IrBuilder::createInContainer(this, true, DataType::Bool); + } + return true_val_; +} + +NamedScalar* Fusion::magicZeroVal() { + if (!magic_zero_val_) { + magic_zero_val_ = IrBuilder::createInContainer( + this, kMagicZeroName, DataType::Index); + } + return magic_zero_val_; +} + +Val* Fusion::zeroVal(DataType dtype) { + if (dtype == DataType::Index) { + return zeroVal(); + } else if (isBooleanType(dtype)) { + return falseVal(); + } else { + // NOTE: this does not cache values + return IrBuilder::createInContainer(this, 0L, dtype); + } +} + +Val* Fusion::oneVal(DataType dtype) { + if (dtype == DataType::Index) { + return oneVal(); + } else if (isBooleanType(dtype)) { + return trueVal(); + } else { + // NOTE: this does not cache values + return IrBuilder::createInContainer(this, 1L, dtype); + } +} + +// 
========================================================================= +// Per-Fusion Metadata and Axioms (Phase 2) +// These are per-Fusion to avoid ownership issues with shared containers. +// ========================================================================= + +Val* Fusion::metadataOf(Val* v) { + if (metadata_.count(v) == 0) { + // Create metadata val owned by the same Fusion as v + Fusion* owner = v->container(); + auto metadata_val = + IrBuilder::createInContainer(owner, metaDataTypeOf(v)); + auto metadata_expr = + IrBuilder::createInContainer(owner, metadata_val, v); + metadata_[v] = std::make_pair(metadata_val, metadata_expr); + } + return metadata_.at(v).first; +} + +const std::vector& Fusion::axioms() { + if (!axioms_) { + axioms_ = std::make_unique>(); + axioms_->reserve(kParallelTypeThreads.size() * 3); + auto zero = zeroVal(); + for (auto p : kParallelTypeThreads) { + auto pidx = NamedScalar::getParallelIndex(p); + auto pdim = NamedScalar::getParallelDim(p); + axioms_->push_back(SimplifyingIrBuilder::geExpr(pidx, zero)); + axioms_->push_back(SimplifyingIrBuilder::gtExpr(pdim, zero)); + axioms_->push_back(SimplifyingIrBuilder::ltExpr(pidx, pdim)); + } + } + return *axioms_; +} + +void Fusion::assumePositive(Val* val) { + NVF_ERROR(inContainer(val)); + // Lazy init axioms, then add the assumption + axioms(); + axioms_->emplace_back(IrBuilder::gtExpr(val, zeroVal())); +} + +void Fusion::assumeNonNegative(Val* val) { + NVF_ERROR(inContainer(val)); + // Lazy init axioms, then add the assumption + axioms(); + axioms_->emplace_back(IrBuilder::geExpr(val, zeroVal())); +} + void Fusion::registerVal(Val* val) { if (inContainer(val)) { return; diff --git a/csrc/fusion.h b/csrc/fusion.h index f02c1b0310d..886191607c6 100644 --- a/csrc/fusion.h +++ b/csrc/fusion.h @@ -8,6 +8,7 @@ #pragma once #include +#include #include #include #include @@ -59,6 +60,7 @@ namespace nvfuser { //! checks. 
class Fusion; +class NamedScalar; class TensorView; class SegmentCandidateFinder; @@ -146,8 +148,8 @@ class AliasInfoMap { class NVF_API Fusion : public PolymorphicBase { typedef std::unordered_map> PermutationMap; - protected: - // Direct access to underlying container + public: + // Direct access to underlying container (for Phase 2 shared_ptr support) IrContainer* ir_container() { NVF_ERROR( ir_container_.get() != nullptr, @@ -162,7 +164,11 @@ class NVF_API Fusion : public PolymorphicBase { return ir_container_.get(); } - public: + // Return the shared_ptr to the container (for Phase 2 container sharing) + std::shared_ptr ir_container_ptr() const { + return ir_container_; + } + // Registration (public API with passkey) virtual void registerStmt(IrBuilderPasskey, Statement* stmt) { if (stmt->isVal()) { @@ -517,94 +523,95 @@ class NVF_API Fusion : public PolymorphicBase { } // Collections access (return values in insertion order) - const std::deque deterministic_vals() const noexcept { - return ir_container()->deterministic_vals(); + // Phase 2: These return only statements owned by THIS Fusion, + // not all statements in the (possibly shared) container. 
+ std::deque deterministic_vals() const noexcept { + return ir_container()->deterministicValsOwnedBy(const_cast(this)); } - const std::deque deterministic_exprs() const noexcept { - return ir_container()->deterministic_exprs(); + std::deque deterministic_exprs() const noexcept { + return ir_container()->deterministicExprsOwnedBy(const_cast(this)); } - const std::unordered_map deterministic_vals_map() - const noexcept { - return ir_container()->deterministic_vals_map(); + std::unordered_map deterministic_vals_map() const noexcept { + return ir_container()->deterministicValsMapOwnedBy( + const_cast(this)); } - const std::unordered_map deterministic_exprs_map() - const noexcept { - return ir_container()->deterministic_exprs_map(); + std::unordered_map deterministic_exprs_map() const noexcept { + return ir_container()->deterministicExprsMapOwnedBy( + const_cast(this)); } // Collections access (unordered sets) const std::unordered_set& unordered_exprs() const noexcept { - return ir_container()->unordered_exprs(); + return ownedExprs(); } const std::unordered_set& vals() const noexcept { - return ir_container()->vals(); - } - - // Count queries - int64_t numExprs() const noexcept { - return ir_container()->numExprs(); - } - - int64_t numVals(bool include_shortcuts) const noexcept { - return ir_container()->numVals(include_shortcuts); - } - - // Shortcut values (frequently used constants) - Val* zeroVal() { - return ir_container()->zeroVal(); - } - - Val* oneVal() { - return ir_container()->oneVal(); + return ownedVals(); } - Val* falseVal() { - return ir_container()->falseVal(); - } + // Per-Fusion Statement Access (Phase 2 Task 4) + // These methods return only statements owned by THIS Fusion, + // not all statements in the shared container. + // + // Phase 1 (unique_ptr): ownedVals() == vals() (1:1 relationship) + // Phase 2 (shared_ptr): ownedVals() ⊆ vals() (filtering by ownership) - Val* trueVal() { - return ir_container()->trueVal(); + //! 
Return only Vals owned by this Fusion + //! This is the set the container tracks for this Fusion (valsOwnedBy), + //! not every val in the (possibly shared) container. vals() forwards here. + const std::unordered_set& ownedVals() const { + return ir_container()->valsOwnedBy(const_cast(this)); } - NamedScalar* magicZeroVal() { - return ir_container()->magicZeroVal(); + //! Return only Exprs owned by this Fusion + //! This is the set the container tracks for this Fusion (exprsOwnedBy), + //! not every expr in the container. unordered_exprs() forwards here. + const std::unordered_set& ownedExprs() const { + return ir_container()->exprsOwnedBy(const_cast(this)); } - Val* zeroVal(DataType dtype) { - return ir_container()->zeroVal(dtype); + // Count queries (owned statements only) + int64_t numExprs() const noexcept { + return static_cast<int64_t>(ownedExprs().size()); } - Val* oneVal(DataType dtype) { - return ir_container()->oneVal(dtype); + // include_shortcuts is kept for source compatibility but no longer + // changes the result: the count is always the size of the owned set. + int64_t numVals([[maybe_unused]] bool include_shortcuts) const noexcept { + return static_cast<int64_t>(ownedVals().size()); } - Val* metadataOf(Val* val) { - return ir_container()->metadataOf(val); - } + // Shortcut values (frequently used constants) + // Phase 2: These are now per-Fusion with lazy creation. + // Each Fusion has its own special values to avoid ownership conflicts + // when multiple Fusions share an IrContainer. + Val* zeroVal(); + Val* oneVal(); + Val* falseVal(); + Val* trueVal(); + NamedScalar* magicZeroVal(); + Val* zeroVal(DataType dtype); + Val* oneVal(DataType dtype); + + // Phase 2: Per-Fusion metadata and axioms + // These are now per-Fusion to avoid ownership issues with shared containers. 
+ Val* metadataOf(Val* val); // Axioms (CUDA programming assumptions) - const std::vector& axioms() { - return ir_container()->axioms(); - } + const std::vector& axioms(); - void assumePositive(Val* val) { - ir_container()->assumePositive(val); - } - - void assumeNonNegative(Val* val) { - ir_container()->assumeNonNegative(val); - } + void assumePositive(Val* val); + void assumeNonNegative(Val* val); // Statement removal + // Phase 2: Now takes this Fusion as parameter to properly handle + // shared containers. Only removes statements owned by this Fusion. void removeStatementsCreatedAfter( int64_t num_exprs_before, int64_t num_vals_before) { ir_container()->removeStatementsCreatedAfter( - num_exprs_before, num_vals_before); + this, num_exprs_before, num_vals_before); } protected: @@ -666,7 +673,25 @@ class NVF_API Fusion : public PolymorphicBase { std::unique_ptr> all_tvs_ptr_ = nullptr; inline static const std::string exact_mappings_key = "exact_mappings"; - std::unique_ptr ir_container_; + std::shared_ptr ir_container_; + + // Phase 2: Per-Fusion special values + // With shared containers, each Fusion needs its own special values. + // These are raw pointers - memory is owned by IrContainer's vals_up_. + // Destroying this Fusion removes these vals via removeStatementsOwnedBy(). + Val* zero_val_ = nullptr; + Val* one_val_ = nullptr; + Val* true_val_ = nullptr; + Val* false_val_ = nullptr; + NamedScalar* magic_zero_val_ = nullptr; + + // Phase 2: Per-Fusion axioms (CUDA programming assumptions) + // These are per-Fusion to avoid ownership issues with shared containers. 
+ std::unique_ptr> axioms_; + + // Phase 2: Per-Fusion metadata cache + // Maps Val* to (metadata_val, metadata_expr) pairs + std::unordered_map> metadata_; }; // Template implementations for Fusion::manage() that use IrCloner @@ -719,7 +744,13 @@ T* IrBuilder::clone(const T* src, IrCloner* ir_cloner) { dest_container->registerStmt(IrBuilderPasskey(dest_container), dest_stmt); - if (src_container != dest_container) { + // Phase 2 Task 10: For same-container cloning (shared IrContainer), + // per-Fusion name counters produce matching names naturally (both start + // at 0), so the name override below is NOT needed and is skipped. + // For cross-container cloning (different IrContainers), we still need + // to force the source name since the destination's global counter may + // have diverged. + if (src_container->ir_container() != dest_container->ir_container()) { dest_stmt->setName(IrBuilderPasskey(dest_container), src_stmt->name()); } diff --git a/csrc/ir/container.cpp b/csrc/ir/container.cpp index 3c54966c87d..2ddf9b9d2d5 100644 --- a/csrc/ir/container.cpp +++ b/csrc/ir/container.cpp @@ -7,6 +7,7 @@ // clang-format on #include "ir/container.h" +#include "fusion.h" #include "instrumentation.h" #include "ir/base_nodes.h" #include "ir/builder.h" @@ -17,6 +18,7 @@ namespace nvfuser { //! Return values in insertion order const std::deque IrContainer::deterministic_vals() const noexcept { + std::shared_lock lock(mutex_); std::deque vals_deque; std::transform( vals_up_.begin(), @@ -28,6 +30,7 @@ const std::deque IrContainer::deterministic_vals() const noexcept { //! Return expression in insertion order const std::deque IrContainer::deterministic_exprs() const noexcept { + std::shared_lock lock(mutex_); std::deque exprs_deque; std::transform( exprs_up_.begin(), @@ -40,6 +43,7 @@ const std::deque IrContainer::deterministic_exprs() const noexcept { //! 
Return mapping from value to integer id const std::unordered_map IrContainer::deterministic_vals_map() const noexcept { + std::shared_lock lock(mutex_); std::unordered_map vals_map; int64_t count = 0; std::transform( @@ -55,6 +59,7 @@ const std::unordered_map IrContainer::deterministic_vals_map() //! Return mapping from expression to integer id const std::unordered_map IrContainer::deterministic_exprs_map() const noexcept { + std::shared_lock lock(mutex_); std::unordered_map exprs_map; int64_t count = 0; std::transform( @@ -67,8 +72,272 @@ const std::unordered_map IrContainer::deterministic_exprs_map() return exprs_map; } +// ========================================================================= +// Per-Fusion Deterministic Accessors (Phase 2) +// ========================================================================= + +std::deque IrContainer::deterministicValsOwnedBy( + Fusion* fusion) const noexcept { + std::shared_lock lock(mutex_); + std::deque result; + + // Get the set of vals owned by this Fusion for O(1) lookup + auto it = per_fusion_vals_.find(fusion); + if (it == per_fusion_vals_.end()) { + return result; // Empty - no vals owned by this Fusion + } + const auto& owned_vals = it->second; + + // Iterate in insertion order, filtering to only owned vals + for (const auto& val_up : vals_up_) { + Val* val = val_up.get(); + if (owned_vals.count(val) > 0) { + result.push_back(val); + } + } + return result; +} + +std::deque IrContainer::deterministicExprsOwnedBy( + Fusion* fusion) const noexcept { + std::shared_lock lock(mutex_); + std::deque result; + + // Get the set of exprs owned by this Fusion for O(1) lookup + auto it = per_fusion_exprs_.find(fusion); + if (it == per_fusion_exprs_.end()) { + return result; // Empty - no exprs owned by this Fusion + } + const auto& owned_exprs = it->second; + + // Iterate in insertion order, filtering to only owned exprs + for (const auto& expr_up : exprs_up_) { + Expr* expr = expr_up.get(); + if 
(owned_exprs.count(expr) > 0) { + result.push_back(expr); + } + } + return result; +} + +std::unordered_map IrContainer::deterministicValsMapOwnedBy( + Fusion* fusion) const noexcept { + std::shared_lock lock(mutex_); + std::unordered_map result; + + // Get the set of vals owned by this Fusion for O(1) lookup + auto it = per_fusion_vals_.find(fusion); + if (it == per_fusion_vals_.end()) { + return result; // Empty - no vals owned by this Fusion + } + const auto& owned_vals = it->second; + + // Iterate in insertion order, assigning sequential ids to owned vals + int64_t count = 0; + for (const auto& val_up : vals_up_) { + Val* val = val_up.get(); + if (owned_vals.count(val) > 0) { + result[val] = count++; + } + } + return result; +} + +std::unordered_map IrContainer::deterministicExprsMapOwnedBy( + Fusion* fusion) const noexcept { + std::shared_lock lock(mutex_); + std::unordered_map result; + + // Get the set of exprs owned by this Fusion for O(1) lookup + auto it = per_fusion_exprs_.find(fusion); + if (it == per_fusion_exprs_.end()) { + return result; // Empty - no exprs owned by this Fusion + } + const auto& owned_exprs = it->second; + + // Iterate in insertion order, assigning sequential ids to owned exprs + int64_t count = 0; + for (const auto& expr_up : exprs_up_) { + Expr* expr = expr_up.get(); + if (owned_exprs.count(expr) > 0) { + result[expr] = count++; + } + } + return result; +} + +const std::unordered_set& IrContainer::unordered_exprs() const noexcept { + // Note: Returns reference - caller responsible for not holding across + // concurrent modifications. Lock provides snapshot consistency during call. + std::shared_lock lock(mutex_); + return exprs_; +} + +const std::unordered_set& IrContainer::vals() const noexcept { + // Note: Returns reference - caller responsible for not holding across + // concurrent modifications. Lock provides snapshot consistency during call. 
+ std::shared_lock lock(mutex_); + return vals_; +} + +int64_t IrContainer::numExprs() const noexcept { + std::shared_lock lock(mutex_); + return std::ssize(exprs_); +} + +int64_t IrContainer::numVals(bool include_shortcuts) const noexcept { + std::shared_lock lock(mutex_); + return include_shortcuts ? std::ssize(vals_) : std::ssize(vals_up_); +} + +// ========================================================================= +// Fusion tracking for shared container support (Phase 2) +// ========================================================================= + +void IrContainer::addFusion(Fusion* fusion) { + std::unique_lock lock(mutex_); + sharing_fusions_.insert(fusion); +} + +void IrContainer::removeFusion(Fusion* fusion) { + std::unique_lock lock(mutex_); + sharing_fusions_.erase(fusion); + removeStatementsOwnedByUnlocked(fusion); +} + +void IrContainer::transferFusion(Fusion* from, Fusion* to) { + std::unique_lock lock(mutex_); + sharing_fusions_.erase(from); + sharing_fusions_.insert(to); + // Note: This only updates which Fusion is registered as sharing the + // container. Statement container() pointers and the per-Fusion tracking + // maps are not touched here; Fusion::swap rewrites ir_container_ on the + // affected statements and calls transferStatementOwnership() itself. 
+} + +size_t IrContainer::sharingCount() const { + std::shared_lock lock(mutex_); + return sharing_fusions_.size(); +} + +bool IrContainer::hasMultipleFusions() const { + std::shared_lock lock(mutex_); + return sharing_fusions_.size() > 1; +} + +const std::unordered_set& IrContainer::sharingFusions() const { + std::shared_lock lock(mutex_); + return sharing_fusions_; +} + +// ========================================================================= +// Per-Fusion Statement Tracking (Phase 2 Task 4) +// ========================================================================= + +const std::unordered_set& IrContainer::valsOwnedBy(Fusion* fusion) const { + std::shared_lock lock(mutex_); + static const std::unordered_set empty; + auto it = per_fusion_vals_.find(fusion); + return it != per_fusion_vals_.end() ? it->second : empty; +} + +const std::unordered_set& IrContainer::exprsOwnedBy( + Fusion* fusion) const { + std::shared_lock lock(mutex_); + static const std::unordered_set empty; + auto it = per_fusion_exprs_.find(fusion); + return it != per_fusion_exprs_.end() ? 
it->second : empty; +} + +void IrContainer::transferStatementOwnership(Fusion* from, Fusion* to) { + std::unique_lock lock(mutex_); + + // Transfer vals ownership tracking + auto vals_it = per_fusion_vals_.find(from); + if (vals_it != per_fusion_vals_.end()) { + // Move the set to 'to', merging if 'to' already has entries + auto& to_vals = per_fusion_vals_[to]; + to_vals.insert(vals_it->second.begin(), vals_it->second.end()); + per_fusion_vals_.erase(vals_it); + } + + // Transfer exprs ownership tracking + auto exprs_it = per_fusion_exprs_.find(from); + if (exprs_it != per_fusion_exprs_.end()) { + // Move the set to 'to', merging if 'to' already has entries + auto& to_exprs = per_fusion_exprs_[to]; + to_exprs.insert(exprs_it->second.begin(), exprs_it->second.end()); + per_fusion_exprs_.erase(exprs_it); + } + + // Transfer per-Fusion name counters (Phase 2 Task 10) + auto val_names_it = per_fusion_val_name_map_.find(from); + if (val_names_it != per_fusion_val_name_map_.end()) { + // Merge counter maps: take max of each ValType counter + auto& to_map = per_fusion_val_name_map_[to]; + for (auto& [vtype, counter] : val_names_it->second) { + to_map[vtype] = std::max(to_map[vtype], counter); + } + per_fusion_val_name_map_.erase(val_names_it); + } + + auto expr_names_it = per_fusion_expr_name_counter_.find(from); + if (expr_names_it != per_fusion_expr_name_counter_.end()) { + auto& to_counter = per_fusion_expr_name_counter_[to]; + to_counter = std::max(to_counter, expr_names_it->second); + per_fusion_expr_name_counter_.erase(expr_names_it); + } +} + +void IrContainer::removeStatementsOwnedBy(Fusion* fusion) { + std::unique_lock lock(mutex_); + removeStatementsOwnedByUnlocked(fusion); +} + +void IrContainer::removeStatementsOwnedByUnlocked(Fusion* fusion) { + // Remove all Vals owned by this Fusion + for (auto it = vals_up_.begin(); it != vals_up_.end();) { + Val* val = it->get(); + // Check if this Val's container points to the Fusion being removed + if (val->container() 
== fusion) { + vals_.erase(val); + it = vals_up_.erase(it); + } else { + ++it; + } + } + + // Remove all Exprs owned by this Fusion + for (auto it = exprs_up_.begin(); it != exprs_up_.end();) { + Expr* expr = it->get(); + // Check if this Expr's container points to the Fusion being removed + if (expr->container() == fusion) { + exprs_.erase(expr); + it = exprs_up_.erase(it); + } else { + ++it; + } + } + + // Clean up per-Fusion tracking (Phase 2 Task 4) + per_fusion_vals_.erase(fusion); + per_fusion_exprs_.erase(fusion); + + // Clean up per-Fusion name counters (Phase 2 Task 10) + per_fusion_val_name_map_.erase(fusion); + per_fusion_expr_name_counter_.erase(fusion); +} + void IrContainer::swap(IrContainer& a, IrContainer& b) noexcept { - FUSER_PERF_SCOPE("Fusion swap"); + FUSER_PERF_SCOPE("IrContainer swap"); + + // NOTE: This method is deprecated in Phase 2. Fusion::swap handles + // pointer-based swapping of shared containers. This is kept for + // backward compatibility but should not be called directly. + + // Lock both containers in consistent order to avoid deadlock + std::unique_lock lock_a(a.mutex_, std::defer_lock); + std::unique_lock lock_b(b.mutex_, std::defer_lock); + std::lock(lock_a, lock_b); // Swap the content std::swap(a.vals_up_, b.vals_up_); @@ -80,34 +349,70 @@ void IrContainer::swap(IrContainer& a, IrContainer& b) noexcept { std::swap(a.val_type_name_map_, b.val_type_name_map_); std::swap(a.expr_name_counter_, b.expr_name_counter_); - std::swap(a.metadata_, b.metadata_); - std::swap(a.parent_, b.parent_); - std::swap(a.zero_val_, b.zero_val_); - std::swap(a.one_val_, b.one_val_); - std::swap(a.true_val_, b.true_val_); - std::swap(a.false_val_, b.false_val_); - std::swap(a.magic_zero_val_, b.magic_zero_val_); - std::swap(a.axioms_, b.axioms_); + // Note: Special values, axioms, and metadata are now per-Fusion, + // not per-IrContainer. They are handled by Fusion::swap. 
+ std::swap(a.sharing_fusions_, b.sharing_fusions_); + std::swap(a.per_fusion_vals_, b.per_fusion_vals_); + std::swap(a.per_fusion_exprs_, b.per_fusion_exprs_); + + // Swap per-Fusion name counters (Phase 2 Task 10) + std::swap(a.per_fusion_val_name_map_, b.per_fusion_val_name_map_); + std::swap(a.per_fusion_expr_name_counter_, b.per_fusion_expr_name_counter_); } IrCloner IrContainer::copy(const IrContainer* from, IrContainer* to) { - to->clear(); + // NOTE: This method is deprecated in Phase 2. Fusion::copy handles + // copying with shared containers. This is kept for backward compatibility + // but should not be called directly. + + // Lock both containers: shared for reading from, unique for writing to + std::shared_lock lock_from(from->mutex_); + std::unique_lock lock_to(to->mutex_); + + // Clear without calling clear() which would try to re-acquire the lock + to->vals_.clear(); + to->vals_up_.clear(); + to->exprs_.clear(); + to->exprs_up_.clear(); + to->val_type_name_map_.clear(); + to->expr_name_counter_ = 0; + to->per_fusion_vals_.clear(); + to->per_fusion_exprs_.clear(); + to->per_fusion_val_name_map_.clear(); + to->per_fusion_expr_name_counter_.clear(); + + // NOTE: In Phase 2, we can't use to->parent() here because parent_ might + // not be set correctly for shared containers. Fusion::copy handles this. + NVF_ERROR( + to->parent_ != nullptr, + "IrContainer::copy requires parent_ to be set. Use Fusion::copy " + "instead."); IrCloner ir_cloner(to->parent()); // Copy values in deterministic order - // deterministic_vals can contain special values like one_val_, zero_val_, etc - // that are not registered in the container. 
- for (auto val : from->deterministic_vals()) { - if (from->vals().count(val) > 0) { + std::deque from_vals; + std::transform( + from->vals_up_.begin(), + from->vals_up_.end(), + std::back_inserter(from_vals), + [](const std::unique_ptr& val_up) { return val_up.get(); }); + for (auto val : from_vals) { + if (from->vals_.count(val) > 0) { to->vals_.insert(ir_cloner.clone(val)); } } // Copy expressions in deterministic order - for (auto expr : from->deterministic_exprs()) { - if (from->unordered_exprs().count(expr) > 0) { + std::deque from_exprs; + std::transform( + from->exprs_up_.begin(), + from->exprs_up_.end(), + std::back_inserter(from_exprs), + [](const std::unique_ptr& expr_up) { return expr_up.get(); }); + for (auto expr : from_exprs) { + if (from->exprs_.count(expr) > 0) { to->exprs_.insert(ir_cloner.clone(expr)); } } @@ -115,14 +420,7 @@ IrCloner IrContainer::copy(const IrContainer* from, IrContainer* to) { to->val_type_name_map_ = from->val_type_name_map_; to->expr_name_counter_ = from->expr_name_counter_; - if (from->axioms_ != nullptr) { - to->axioms_ = std::make_unique>(); - for (auto pred : *from->axioms_) { - to->axioms_->push_back(ir_cloner.clone(pred)); - } - } - - to->metadata_ = ir_cloner.clone(from->metadata_); + // Note: axioms and metadata are now per-Fusion, handled by Fusion::copy return ir_cloner; } @@ -146,6 +444,14 @@ void IrContainer::removeExpr(Expr* expr) { expr_in_deque != exprs_up_.end(), "Wanted to remove an expression but its unique ptr is missing."); + // Remove from per-Fusion tracking (Phase 2 Task 4) + if (expr->container() != nullptr) { + auto it = per_fusion_exprs_.find(expr->container()); + if (it != per_fusion_exprs_.end()) { + it->second.erase(expr); + } + } + exprs_.erase(expr); exprs_up_.erase(expr_in_deque); } @@ -153,12 +459,9 @@ void IrContainer::removeExpr(Expr* expr) { //! Completely remove val from the fusion, break all dependencies associated //! 
with it void IrContainer::removeVal(Val* val) { - // Don't remove shortcuts - if (val == true_val_.get() || val == false_val_.get() || - val == one_val_.get() || val == zero_val_.get() || - val == magic_zero_val_.get()) { - return; - } + // Note: Special values (zero_val_, one_val_, etc.) are now per-Fusion, + // stored in Fusion class. They are registered as normal vals and can + // be removed like any other val. NVF_ERROR( vals_.find(val) != vals_.end(), @@ -172,6 +475,14 @@ void IrContainer::removeVal(Val* val) { val_in_deque != vals_up_.end(), "Wanted to remove a value but its unique ptr is missing."); + // Remove from per-Fusion tracking (Phase 2 Task 4) + if (val->container() != nullptr) { + auto it = per_fusion_vals_.find(val->container()); + if (it != per_fusion_vals_.end()) { + it->second.erase(val); + } + } + vals_.erase(val); vals_up_.erase(val_in_deque); } @@ -185,7 +496,17 @@ void IrContainer::registerVal(Val* val) { // Otherwise handle registration locally vals_up_.emplace_back(val); vals_.insert(val); - val->setName(IrContainerPasskey(), getValName(val->vtype())); + + // Phase 2 Task 10: Use per-Fusion counter if val has an owning Fusion. + // This ensures cloned Fusions get matching names (T0=T0, T1=T1) + // instead of incrementing global names (T0=T10, T1=T11). + Fusion* owning_fusion = val->container(); + val->setName(IrContainerPasskey(), getValName(owning_fusion, val->vtype())); + + // Track per-Fusion ownership (Phase 2 Task 4) + if (owning_fusion != nullptr) { + per_fusion_vals_[owning_fusion].insert(val); + } } //! Register expr with this container. @@ -197,7 +518,15 @@ void IrContainer::registerExpr(Expr* expr) { // Otherwise handle registration locally exprs_up_.emplace_back(expr); exprs_.insert(expr); - expr->setName(IrContainerPasskey(), getExprName()); + + // Phase 2 Task 10: Use per-Fusion counter if expr has an owning Fusion. 
+ Fusion* owning_fusion = expr->container(); + expr->setName(IrContainerPasskey(), getExprName(owning_fusion)); + + // Track per-Fusion ownership (Phase 2 Task 4) + if (owning_fusion != nullptr) { + per_fusion_exprs_[owning_fusion].insert(expr); + } } void IrContainer::clear() noexcept { @@ -206,10 +535,16 @@ void IrContainer::clear() noexcept { vals_up_.clear(); exprs_.clear(); exprs_up_.clear(); - axioms_.reset(); val_type_name_map_.clear(); - metadata_.clear(); expr_name_counter_ = 0; + + // Clear per-Fusion tracking (Phase 2 Task 4) + per_fusion_vals_.clear(); + per_fusion_exprs_.clear(); + + // Clear per-Fusion name counters (Phase 2 Task 10) + per_fusion_val_name_map_.clear(); + per_fusion_expr_name_counter_.clear(); } bool IrContainer::inContainer(const Statement* const_stmt) const { @@ -224,9 +559,15 @@ bool IrContainer::inContainer(const Statement* const_stmt) const { return false; } + // Phase 2: With shared containers, multiple Fusions can share this container. + // The statement's container() returns its owning Fusion, which should be + // one of the Fusions sharing this container. 
+ // Phase 1 (single Fusion): sharing_fusions_ == {parent_} + // Phase 2 (shared container): sharing_fusions_ contains multiple Fusions NVF_ERROR( - const_stmt->container() == this->parent(), - "Container claims to own stmt, but stmt disagrees."); + sharing_fusions_.count(const_stmt->container()) > 0, + "Container claims to own stmt, but stmt's owning Fusion is not " + "registered with this container."); // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast) auto* stmt = const_cast(const_stmt); @@ -244,154 +585,78 @@ bool IrContainer::inContainer(const Statement* const_stmt) const { return true; } -// Shortcuts for frequently used vals -Val* IrContainer::zeroVal() { - if (!zero_val_) { - auto zero_val = - IrBuilder::createInContainer(this->parent(), 0L, DataType::Index); - NVF_ERROR(vals_up_.back().get() == zero_val); - zero_val_ = std::unique_ptr(vals_up_.back().release()); - vals_up_.pop_back(); - } - return zero_val_.get(); -} +// Note: Shortcut values (zeroVal, oneVal, trueVal, falseVal, magicZeroVal), +// metadata, and axioms are now per-Fusion. Use Fusion::zeroVal(), +// Fusion::metadataOf(), Fusion::axioms(), etc. instead. +// This avoids ownership conflicts when multiple Fusions share an IrContainer. 
-Val* IrContainer::zeroVal(DataType dtype) { - if (dtype == DataType::Index) { - return zeroVal(); - } else if (isBooleanType(dtype)) { - return falseVal(); - } else { - // NOTE: this does not cache values - return IrBuilder::createInContainer(this->parent(), 0L, dtype); - } -} +void IrContainer::removeStatementsCreatedAfter( + Fusion* fusion, + int64_t prev_num_exprs, + int64_t prev_num_vals) { + std::unique_lock lock(mutex_); -Val* IrContainer::oneVal() { - if (!one_val_) { - auto one_val = - IrBuilder::createInContainer(this->parent(), 1L, DataType::Index); - NVF_ERROR(vals_up_.back().get() == one_val); - one_val_ = std::unique_ptr(vals_up_.back().release()); - vals_up_.pop_back(); - } - return one_val_.get(); -} + // Phase 2: Remove only statements owned by the specified Fusion. + // This correctly handles shared containers where multiple Fusions + // have statements interleaved in the container's deques. -Val* IrContainer::oneVal(DataType dtype) { - if (dtype == DataType::Index) { - return oneVal(); - } else if (isBooleanType(dtype)) { - return trueVal(); - } else { - // NOTE: this does not cache values - return IrBuilder::createInContainer(this->parent(), 1L, dtype); - } -} + // Get current per-Fusion counts + auto vals_it = per_fusion_vals_.find(fusion); + auto exprs_it = per_fusion_exprs_.find(fusion); -Val* IrContainer::falseVal() { - if (!false_val_) { - auto false_val = IrBuilder::createInContainer( - this->parent(), false, DataType::Bool); - NVF_ERROR(vals_up_.back().get() == false_val); - false_val_ = std::unique_ptr(vals_up_.back().release()); - vals_up_.pop_back(); - } - return false_val_.get(); -} + int64_t current_fusion_exprs = + (exprs_it != per_fusion_exprs_.end()) ? exprs_it->second.size() : 0; + int64_t current_fusion_vals = + (vals_it != per_fusion_vals_.end()) ? 
vals_it->second.size() : 0; -Val* IrContainer::trueVal() { - if (!true_val_) { - auto true_val = - IrBuilder::createInContainer(this->parent(), true, DataType::Bool); - NVF_ERROR(vals_up_.back().get() == true_val); - true_val_ = std::unique_ptr(vals_up_.back().release()); - vals_up_.pop_back(); - } - return true_val_.get(); -} + // Calculate how many statements to remove from this Fusion + int64_t exprs_to_remove = current_fusion_exprs - prev_num_exprs; + int64_t vals_to_remove = current_fusion_vals - prev_num_vals; -NamedScalar* IrContainer::magicZeroVal() { - if (!magic_zero_val_) { - auto magic_zero = - IrBuilder::create(kMagicZeroName, DataType::Index); - NVF_ERROR(vals_up_.back().get() == magic_zero); - magic_zero_val_ = std::unique_ptr( - vals_up_.back().release()->as()); - vals_up_.pop_back(); + if (exprs_to_remove <= 0 && vals_to_remove <= 0) { + return; // Nothing to remove } - return magic_zero_val_.get(); -} -Val* IrContainer::metadataOf(Val* v) { - if (metadata_.count(v) == 0) { - auto metadata_val = - IrBuilder::createInContainer(this->parent(), metaDataTypeOf(v)); - auto metadata_expr = IrBuilder::createInContainer( - this->parent(), metadata_val, v); - metadata_[v] = std::make_pair(metadata_val, metadata_expr); - } - return metadata_.at(v).first; -} - -void IrContainer::lazyInitAxioms() { - if (!axioms_) { - axioms_ = std::make_unique>(); - axioms_->reserve(kParallelTypeThreads.size() * 3); - auto zero = zeroVal(); - for (auto p : kParallelTypeThreads) { - auto pidx = NamedScalar::getParallelIndex(p); - auto pdim = NamedScalar::getParallelDim(p); - axioms_->push_back(SimplifyingIrBuilder::geExpr(pidx, zero)); - axioms_->push_back(SimplifyingIrBuilder::gtExpr(pdim, zero)); - axioms_->push_back(SimplifyingIrBuilder::ltExpr(pidx, pdim)); + // Remove expressions owned by this Fusion (from back of deque) + // We iterate backwards and remove only those owned by this Fusion + int64_t exprs_removed = 0; + // Use index-based iteration to avoid iterator 
invalidation issues + for (int64_t i = static_cast(exprs_up_.size()) - 1; + i >= 0 && exprs_removed < exprs_to_remove; + --i) { + Expr* e = exprs_up_[i].get(); + if (e->container() == fusion) { + // Clean up use-def chains + for (Val* in : e->inputs()) { + in->removeUse(e); + } + // Remove from tracking sets + exprs_.erase(e); + if (exprs_it != per_fusion_exprs_.end()) { + exprs_it->second.erase(e); + } + // Erase from deque + exprs_up_.erase(exprs_up_.begin() + i); + exprs_removed++; } } -} - -void IrContainer::assumePositive(Val* val) { - NVF_ERROR(val->container() == this->parent()); - lazyInitAxioms(); - axioms_->emplace_back(IrBuilder::gtExpr(val, zeroVal())); -} -void IrContainer::assumeNonNegative(Val* val) { - NVF_ERROR(val->container() == this->parent()); - lazyInitAxioms(); - axioms_->emplace_back(IrBuilder::geExpr(val, zeroVal())); -} - -void IrContainer::removeStatementsCreatedAfter( - int64_t prev_num_exprs, - int64_t prev_num_vals) { - NVF_ERROR( - exprs_up_.size() == exprs_.size(), - "exprs_up_ (size ", - exprs_up_.size(), - ") and exprs_ (size ", - exprs_.size(), - ") are out of sync."); - NVF_ERROR( - std::ssize(exprs_up_) >= prev_num_exprs, - "exprs_up_ size (", - std::ssize(exprs_up_), - ") is less than prev_num_exprs (", - prev_num_exprs, - ")."); - - // Remove expressions before values because we need to change Val::uses_. 
- while (std::ssize(exprs_up_) > prev_num_exprs) { - Expr* e = exprs_up_.back().get(); - for (Val* in : e->inputs()) { - in->removeUse(e); + // Remove vals owned by this Fusion (from back of deque) + int64_t vals_removed = 0; + for (int64_t i = static_cast(vals_up_.size()) - 1; + i >= 0 && vals_removed < vals_to_remove; + --i) { + Val* v = vals_up_[i].get(); + if (v->container() == fusion) { + // Remove from tracking sets + vals_.erase(v); + if (vals_it != per_fusion_vals_.end()) { + vals_it->second.erase(v); + } + // Erase from deque + vals_up_.erase(vals_up_.begin() + i); + vals_removed++; } - exprs_.erase(e); - exprs_up_.pop_back(); - } - - while (std::ssize(vals_up_) > prev_num_vals) { - vals_.erase(vals_up_.back().get()); - vals_up_.pop_back(); } } diff --git a/csrc/ir/container.h b/csrc/ir/container.h index e361b8743ee..6a5666994ea 100644 --- a/csrc/ir/container.h +++ b/csrc/ir/container.h @@ -8,6 +8,7 @@ #pragma once #include +#include #include #include @@ -30,7 +31,7 @@ class NamedScalar; class IrContainer { public: - NVF_API IrContainer(); + IrContainer(); // Copy/Move Constructors and Operators are deleted. IrContainer is managed // through a smart pointer in IrContainer. Semantic operations for Fusion @@ -64,47 +65,114 @@ class IrContainer { const std::unordered_map deterministic_exprs_map() const noexcept; + // ========================================================================= + // Per-Fusion Deterministic Accessors (Phase 2) + // These return statements in insertion order, filtered by ownership. + // ========================================================================= + + //! Return values owned by a specific Fusion in insertion order + std::deque deterministicValsOwnedBy(Fusion* fusion) const noexcept; + + //! Return expressions owned by a specific Fusion in insertion order + std::deque deterministicExprsOwnedBy(Fusion* fusion) const noexcept; + + //! Return mapping from value to integer id for values owned by a Fusion + //! 
The integer ids are local to this Fusion's values (0, 1, 2, ...) + std::unordered_map deterministicValsMapOwnedBy( + Fusion* fusion) const noexcept; + + //! Return mapping from expression to integer id for exprs owned by a Fusion + //! The integer ids are local to this Fusion's exprs (0, 1, 2, ...) + std::unordered_map deterministicExprsMapOwnedBy( + Fusion* fusion) const noexcept; + //! Return the set of Exprs registered with this fusion. Warning: This will //! return exprs outside inputs/outputs, so can be unsafe for use with //! segmented fusions. - const std::unordered_set& unordered_exprs() const noexcept { - return exprs_; - } + //! Note: Returns reference - caller must not hold across concurrent mods + const std::unordered_set& unordered_exprs() const noexcept; //! Return the set of Vals registered with this fusion - const std::unordered_set& vals() const noexcept { - return vals_; - } + //! Note: Returns reference - caller must not hold across concurrent mods + const std::unordered_set& vals() const noexcept; - int64_t numExprs() const noexcept { - return std::ssize(exprs_); - } + int64_t numExprs() const noexcept; - // When include_shortcuts is true, it will count the shortcuts like true_val_. + // Note: The include_shortcuts parameter is now deprecated. + // With Phase 2 per-Fusion special values, all vals (including special values) + // are stored in vals_up_, so both vals_ and vals_up_ have the same size. + // This parameter is kept for API compatibility but has no effect. int64_t numVals(bool include_shortcuts) const noexcept { return include_shortcuts ? 
std::ssize(vals_) : std::ssize(vals_up_); } - // Shortcuts for frequently used vals - NVF_API Val* zeroVal(); - NVF_API Val* oneVal(); - Val* falseVal(); - Val* trueVal(); - NamedScalar* magicZeroVal(); - NVF_API Val* zeroVal(DataType dtype); - NVF_API Val* oneVal(DataType dtype); - Val* metadataOf(Val*); - - // Axioms about CUDA programming, for example: threadIdx.x < blockDim.x - const std::vector& axioms() { - lazyInitAxioms(); - return *axioms_; - } + // Note: Shortcut values (zeroVal, oneVal, trueVal, falseVal, magicZeroVal), + // metadata, and axioms are now per-Fusion. Use Fusion::zeroVal(), + // Fusion::metadataOf(), Fusion::axioms(), etc. instead. + // This avoids ownership conflicts when multiple Fusions share an IrContainer. + + public: + // ========================================================================= + // Fusion tracking for shared container support (Phase 2) + // ========================================================================= + + //! Register a Fusion as sharing this container + void addFusion(Fusion* fusion); + + //! Unregister a Fusion and cleanup its owned Statements + void removeFusion(Fusion* fusion); + + //! Transfer registration from one Fusion to another (for move operations) + void transferFusion(Fusion* from, Fusion* to); + + //! Number of Fusions sharing this container + size_t sharingCount() const; - void assumePositive(Val* val); - void assumeNonNegative(Val* val); + //! Whether multiple Fusions share this container + bool hasMultipleFusions() const; + + //! Get the set of Fusions sharing this container + const std::unordered_set& sharingFusions() const; + + // ========================================================================= + // Per-Fusion Statement Tracking (Phase 2 Task 4) + // ========================================================================= + + //! Get Vals owned by a specific Fusion + //! 
Returns empty set if Fusion has no vals in this container + NVF_API const std::unordered_set& valsOwnedBy(Fusion* fusion) const; + + //! Get Exprs owned by a specific Fusion + //! Returns empty set if Fusion has no exprs in this container + const std::unordered_set& exprsOwnedBy(Fusion* fusion) const; + + //! Transfer statement ownership tracking from one Fusion to another + //! Used during move operations + void transferStatementOwnership(Fusion* from, Fusion* to); + + //! Public version of removeStatementsOwnedBy (acquires lock) + //! Removes all Statements owned by a specific Fusion + void removeStatementsOwnedBy(Fusion* fusion); protected: + // Mutex for thread-safe access when container is shared between Fusions + // mutable because we need to lock in const methods + mutable std::shared_mutex mutex_; + + //! Fusions that share this container (for Phase 2 shared_ptr ownership) + std::unordered_set sharing_fusions_; + + //! Per-Fusion statement tracking for efficient ownership queries + //! Maps each Fusion to the set of Vals it owns in this container + std::unordered_map> per_fusion_vals_; + + //! Maps each Fusion to the set of Exprs it owns in this container + std::unordered_map> per_fusion_exprs_; + + //! Remove all Statements owned by a specific Fusion (internal helper) + //! Caller must hold unique_lock on mutex_ + void removeStatementsOwnedByUnlocked(Fusion* fusion); + static IrCloner copy(const IrContainer* from, IrContainer* to); static void swap(IrContainer& a, IrContainer& b) noexcept; @@ -124,21 +192,48 @@ class IrContainer { //! Register expr with this container. NVF_API void registerExpr(Expr* expr); - StmtNameType getValName(ValType vtype) { + //! Get next val name, using per-Fusion counter if fusion is non-null, + //! falling back to global counter otherwise. + //! Per-Fusion counters ensure cloned Fusions produce matching names. 
+ StmtNameType getValName(Fusion* fusion, ValType vtype) { + if (fusion != nullptr) { + auto& name_map = per_fusion_val_name_map_[fusion]; + if (name_map.find(vtype) == name_map.end()) { + name_map[vtype] = 0; + } + // Also advance global counter to keep it >= all per-Fusion counters + // This prevents conflicts if global counter is used later + auto& global = val_type_name_map_[vtype]; + auto per_fusion_name = name_map[vtype]++; + if (global <= per_fusion_name) { + global = per_fusion_name + 1; + } + return per_fusion_name; + } + // Global fallback for non-Fusion contexts if (val_type_name_map_.find(vtype) == val_type_name_map_.end()) { val_type_name_map_[vtype] = 0; } return val_type_name_map_[vtype]++; } - StmtNameType getExprName() { + //! Get next expr name, using per-Fusion counter if fusion is non-null, + //! falling back to global counter otherwise. + StmtNameType getExprName(Fusion* fusion) { + if (fusion != nullptr) { + auto& counter = per_fusion_expr_name_counter_[fusion]; + auto per_fusion_name = counter++; + // Also advance global counter + if (expr_name_counter_ <= per_fusion_name) { + expr_name_counter_ = per_fusion_name + 1; + } + return per_fusion_name; + } return expr_name_counter_++; } void clear() noexcept; - void lazyInitAxioms(); - friend class StatementGuard; // A simple garbage collection mechanism to remove all Exprs and Vals that @@ -147,7 +242,13 @@ class IrContainer { // itself. // // Used by StatementGuard only. + // + // Phase 2 Note: This method now takes a Fusion pointer to properly handle + // shared containers. It removes only statements owned by the specified + // Fusion that were created after the snapshot point, preserving statements + // owned by other Fusions sharing the container. 
void removeStatementsCreatedAfter( + Fusion* fusion, int64_t prev_num_exprs, int64_t prev_num_vals); @@ -165,27 +266,26 @@ class IrContainer { // something like check if an Expr is in this container std::unordered_set exprs_; - // Values names counters + // Values names counters (global fallback for non-Fusion contexts) std::unordered_map val_type_name_map_; - // Expression names counter + // Expression names counter (global fallback for non-Fusion contexts) StmtNameType expr_name_counter_ = 0; - // Manually store some persistent, frequently used nodes. It's very - // challenging to do this anything but manually as detecting when a container - // may or may not have one of these vals is tricky. Specifically because if - // the container doesn't own it, it's hard to understand from the outside if - // the node may have been removed then re-registered. It could also be tricky - // to know when we're using a different container as in FusionCopy_test - // demonstrates deleting then creating containers can result in the same - // pointer for the container. - std::unique_ptr true_val_; - std::unique_ptr false_val_; - std::unique_ptr one_val_; - std::unique_ptr zero_val_; - std::unique_ptr magic_zero_val_; - std::unique_ptr> axioms_; - std::unordered_map> metadata_; + // Per-Fusion name counters (Phase 2 Task 10) + // Each Fusion gets its own counter starting at 0, so cloned Fusions + // produce matching names (T0=T0, T1=T1) instead of incrementing names. + // This is critical for GreedyParams and normalization_utils which use + // tv->name() as map keys across cloned Fusions. + std::unordered_map> + per_fusion_val_name_map_; + std::unordered_map per_fusion_expr_name_counter_; + + // Note: Special values (zero_val_, one_val_, true_val_, false_val_, + // magic_zero_val_) are now per-Fusion, stored in Fusion class. + // This avoids ownership conflicts when multiple Fusions share an IrContainer. + // See Fusion::zeroVal(), Fusion::axioms(), Fusion::metadataOf(), etc. 
+ // for the per-Fusion implementations. public: Fusion* parent() const { diff --git a/tests/cpp/test_phase2_container_sharing.cpp b/tests/cpp/test_phase2_container_sharing.cpp new file mode 100644 index 00000000000..4261ec6e501 --- /dev/null +++ b/tests/cpp/test_phase2_container_sharing.cpp @@ -0,0 +1,1819 @@ +// clang-format off +/* + * SPDX-FileCopyrightText: Copyright (c) 2024-present NVIDIA CORPORATION & AFFILIATES. + * All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + */ +// clang-format on + +#include + +#include +#include +#include + +#include "fusion.h" +#include "ir/container.h" +#include "ops/all_ops.h" +#include "statement_guard.h" +#include "tests/cpp/utils.h" + +namespace nvfuser { + +// Test class for Phase 2 container sharing tests +class Phase2ContainerTest : public NVFuserTest { + protected: + void SetUp() override { + NVFuserTest::SetUp(); + } + void TearDown() override { + NVFuserTest::TearDown(); + } +}; + +// ============================================================================= +// Task 1 Tests: Locking Infrastructure +// ============================================================================= + +TEST_F(Phase2ContainerTest, LockingBasic) { + // Verify basic operations still work with locking in place + Fusion fusion; + FusionGuard fg(&fusion); + + auto* tv0 = makeSymbolicTensor(2); + fusion.addInput(tv0); + + auto* tv1 = add(tv0, tv0); + fusion.addOutput(tv1); + + // Verify container has expected contents + // Use vals() and unordered_exprs() which return references to container data + EXPECT_GT(fusion.vals().size(), 0); + EXPECT_GT(fusion.unordered_exprs().size(), 0); +} + +TEST_F(Phase2ContainerTest, ConcurrentReads) { + // Multiple threads can read simultaneously without data races + Fusion fusion; + FusionGuard fg(&fusion); + + auto* tv0 = makeSymbolicTensor(2); + fusion.addInput(tv0); + auto* tv1 = add(tv0, tv0); + fusion.addOutput(tv1); + + std::vector threads; + std::atomic read_count{0}; + + // Spawn 
multiple reader threads + for (int i = 0; i < 4; ++i) { + threads.emplace_back([&]() { + for (int j = 0; j < 100; ++j) { + // Access vals and unordered_exprs through fusion's forwarding methods + // These return const references to the underlying container data + const auto& vals = fusion.vals(); + const auto& exprs = fusion.unordered_exprs(); + // Just access sizes to verify no crashes under concurrent access + (void)vals.size(); + (void)exprs.size(); + read_count++; + } + }); + } + + for (auto& t : threads) { + t.join(); + } + + EXPECT_EQ(read_count.load(), 400); +} + +// ============================================================================= +// Task 2 Tests: Fusion Tracking Infrastructure +// ============================================================================= + +TEST_F(Phase2ContainerTest, FusionRegistration) { + // Test that addFusion increments count, removeFusion decrements + Fusion fusion; + FusionGuard fg(&fusion); + + // Get the IrContainer through Fusion + auto& container = *fusion.ir_container(); + + // Initially no Fusions registered (Phase 1 doesn't use registration yet) + EXPECT_EQ(container.sharingCount(), 0); + + // Register the Fusion + container.addFusion(&fusion); + EXPECT_EQ(container.sharingCount(), 1); + EXPECT_FALSE(container.hasMultipleFusions()); + + // Create another Fusion and register it with the same container + // (simulating shared_ptr sharing that will happen in later tasks) + Fusion fusion2; + container.addFusion(&fusion2); + EXPECT_EQ(container.sharingCount(), 2); + EXPECT_TRUE(container.hasMultipleFusions()); + + // Remove one + container.removeFusion(&fusion2); + EXPECT_EQ(container.sharingCount(), 1); + EXPECT_FALSE(container.hasMultipleFusions()); + + // Remove the other + container.removeFusion(&fusion); + EXPECT_EQ(container.sharingCount(), 0); +} + +TEST_F(Phase2ContainerTest, FusionTransfer) { + // Test transferFusion correctly updates tracking + Fusion fusion1; + Fusion fusion2; + + auto& container = 
*fusion1.ir_container(); + + // Register fusion1 + container.addFusion(&fusion1); + EXPECT_EQ(container.sharingCount(), 1); + EXPECT_TRUE(container.sharingFusions().count(&fusion1) > 0); + EXPECT_TRUE(container.sharingFusions().count(&fusion2) == 0); + + // Transfer from fusion1 to fusion2 + container.transferFusion(&fusion1, &fusion2); + EXPECT_EQ(container.sharingCount(), 1); + EXPECT_TRUE(container.sharingFusions().count(&fusion1) == 0); + EXPECT_TRUE(container.sharingFusions().count(&fusion2) > 0); +} + +TEST_F(Phase2ContainerTest, MultipleRegistration) { + // Test multiple Fusions can register with same container + Fusion fusion1; + Fusion fusion2; + Fusion fusion3; + + auto& container = *fusion1.ir_container(); + + container.addFusion(&fusion1); + container.addFusion(&fusion2); + container.addFusion(&fusion3); + + EXPECT_EQ(container.sharingCount(), 3); + EXPECT_TRUE(container.hasMultipleFusions()); + + // Verify all are registered + const auto& fusions = container.sharingFusions(); + EXPECT_TRUE(fusions.count(&fusion1) > 0); + EXPECT_TRUE(fusions.count(&fusion2) > 0); + EXPECT_TRUE(fusions.count(&fusion3) > 0); +} + +TEST_F(Phase2ContainerTest, StatementCleanup) { + // Test that removeFusion removes only Statements owned by that Fusion + // This is tricky to test directly because Statements are tied to their + // container at construction. We test the basic mechanism works. 
+ + Fusion fusion; + FusionGuard fg(&fusion); + + // Create some IR + auto* tv0 = makeSymbolicTensor(2); + fusion.addInput(tv0); + auto* tv1 = add(tv0, tv0); + fusion.addOutput(tv1); + + auto& container = *fusion.ir_container(); + size_t initial_vals = container.vals().size(); + size_t initial_exprs = container.unordered_exprs().size(); + + EXPECT_GT(initial_vals, 0); + EXPECT_GT(initial_exprs, 0); + + // Register fusion + container.addFusion(&fusion); + + // When we remove fusion, its Statements should be cleaned up + // (all Statements in this test are owned by fusion) + container.removeFusion(&fusion); + + // After removal, the Statements owned by fusion should be removed + EXPECT_EQ(container.vals().size(), 0); + EXPECT_EQ(container.unordered_exprs().size(), 0); +} + +// ============================================================================= +// Task 4 Tests: Per-Fusion Statement Tracking +// ============================================================================= + +TEST_F(Phase2ContainerTest, PerFusionValsTracking) { + // Test that ownedVals() returns only this Fusion's vals + Fusion fusion; + FusionGuard fg(&fusion); + + // Create some IR + auto* tv0 = makeSymbolicTensor(2); + fusion.addInput(tv0); + auto* tv1 = add(tv0, tv0); + fusion.addOutput(tv1); + + // ownedVals() should return only this Fusion's vals + const auto& owned_vals = fusion.ownedVals(); + EXPECT_GT(owned_vals.size(), 0); + + // All vals in ownedVals() should have container() == &fusion + for (auto* val : owned_vals) { + EXPECT_EQ(val->container(), &fusion); + } + + // vals() and ownedVals() should be the same with a single Fusion (Phase 1 + // equivalence) + EXPECT_EQ(fusion.vals().size(), fusion.ownedVals().size()); +} + +TEST_F(Phase2ContainerTest, PerFusionExprsTracking) { + // Test that ownedExprs() returns only this Fusion's exprs + Fusion fusion; + FusionGuard fg(&fusion); + + // Create some IR + auto* tv0 = makeSymbolicTensor(2); + fusion.addInput(tv0); + auto* tv1 = 
add(tv0, tv0); + fusion.addOutput(tv1); + + // ownedExprs() should return only this Fusion's exprs + const auto& owned_exprs = fusion.ownedExprs(); + EXPECT_GT(owned_exprs.size(), 0); + + // All exprs in ownedExprs() should have container() == &fusion + for (auto* expr : owned_exprs) { + EXPECT_EQ(expr->container(), &fusion); + } + + // unordered_exprs() and ownedExprs() should be the same with a single Fusion + EXPECT_EQ(fusion.unordered_exprs().size(), fusion.ownedExprs().size()); +} + +TEST_F(Phase2ContainerTest, ValsOwnedByAPI) { + // Test IrContainer::valsOwnedBy() API directly + Fusion fusion; + FusionGuard fg(&fusion); + + auto* tv0 = makeSymbolicTensor(2); + fusion.addInput(tv0); + auto* tv1 = add(tv0, tv0); + fusion.addOutput(tv1); + + auto& container = *fusion.ir_container(); + + // valsOwnedBy should return same set as ownedVals() + const auto& vals_by_container = container.valsOwnedBy(&fusion); + const auto& vals_by_fusion = fusion.ownedVals(); + EXPECT_EQ(vals_by_container.size(), vals_by_fusion.size()); + + // valsOwnedBy for a non-registered Fusion should return empty set + Fusion other_fusion; + const auto& other_vals = container.valsOwnedBy(&other_fusion); + EXPECT_EQ(other_vals.size(), 0); +} + +TEST_F(Phase2ContainerTest, ExprsOwnedByAPI) { + // Test IrContainer::exprsOwnedBy() API directly + Fusion fusion; + FusionGuard fg(&fusion); + + auto* tv0 = makeSymbolicTensor(2); + fusion.addInput(tv0); + auto* tv1 = add(tv0, tv0); + fusion.addOutput(tv1); + + auto& container = *fusion.ir_container(); + + // exprsOwnedBy should return same set as ownedExprs() + const auto& exprs_by_container = container.exprsOwnedBy(&fusion); + const auto& exprs_by_fusion = fusion.ownedExprs(); + EXPECT_EQ(exprs_by_container.size(), exprs_by_fusion.size()); + + // exprsOwnedBy for a non-registered Fusion should return empty set + Fusion other_fusion; + const auto& other_exprs = container.exprsOwnedBy(&other_fusion); + EXPECT_EQ(other_exprs.size(), 0); +} + 
+TEST_F(Phase2ContainerTest, RegisterUpdatesPerFusionTracking) {
+  // Test that registering new vals/exprs updates per-Fusion tracking
+  Fusion fusion;
+  FusionGuard fg(&fusion);
+
+  // Initially nothing is tracked for this fusion
+  EXPECT_EQ(fusion.ownedVals().size(), 0);
+  EXPECT_EQ(fusion.ownedExprs().size(), 0);
+
+  // Add an input - this creates vals
+  auto* tv0 = makeSymbolicTensor(2);
+  fusion.addInput(tv0);
+
+  // Now we should have vals tracked for this fusion
+  size_t vals_after_input = fusion.ownedVals().size();
+  EXPECT_GT(vals_after_input, 0);
+
+  // Add an expression - this creates more vals and exprs
+  auto* tv1 = add(tv0, tv0);
+  fusion.addOutput(tv1);
+
+  // Both should have grown
+  EXPECT_GT(fusion.ownedVals().size(), vals_after_input);
+  EXPECT_GT(fusion.ownedExprs().size(), 0);
+}
+
+TEST_F(Phase2ContainerTest, TransferStatementOwnership) {
+  // Test IrContainer::transferStatementOwnership
+  auto container = std::make_shared<IrContainer>();
+
+  // Create dummy Fusions for testing
+  Fusion fusion1;
+  Fusion fusion2;
+
+  // We can't easily create vals owned by fusion1 in a standalone container,
+  // but we can test the tracking data structure directly
+  container->addFusion(&fusion1);
+  container->addFusion(&fusion2);
+
+  // Transfer ownership - should not crash even with empty tracking
+  container->transferStatementOwnership(&fusion1, &fusion2);
+
+  // Verify fusion1 no longer has tracking entries (empty case)
+  EXPECT_EQ(container->valsOwnedBy(&fusion1).size(), 0);
+  EXPECT_EQ(container->exprsOwnedBy(&fusion1).size(), 0);
+
+  // Cleanup
+  container->removeFusion(&fusion1);
+  container->removeFusion(&fusion2);
+}
+
+TEST_F(Phase2ContainerTest, ClearOnlyAffectsOwnedStatements) {
+  // Test that Fusion::clear() only clears THIS Fusion's statements
+  // This is critical for shared container correctness
+
+  Fusion fusion;
+  FusionGuard fg(&fusion);
+
+  // Create some IR
+  auto* tv0 = makeSymbolicTensor(2);
+  fusion.addInput(tv0);
+  auto* tv1 = add(tv0, tv0);
+  fusion.addOutput(tv1);
+
+  // Keep a handle to the container so it can be inspected after clear()
+  auto container_ptr = fusion.ir_container_ptr();
+
+  // Record counts before clear
+  size_t vals_before = fusion.ownedVals().size();
+  size_t exprs_before = fusion.ownedExprs().size();
+  EXPECT_GT(vals_before, 0);
+  EXPECT_GT(exprs_before, 0);
+
+  // Clear the fusion
+  fusion.clear();
+
+  // After clear, ownedVals/ownedExprs should be empty for this fusion
+  EXPECT_EQ(fusion.ownedVals().size(), 0);
+  EXPECT_EQ(fusion.ownedExprs().size(), 0);
+
+  // Container-level accessors should also reflect the removal
+  EXPECT_EQ(container_ptr->vals().size(), 0);
+  EXPECT_EQ(container_ptr->unordered_exprs().size(), 0);
+}
+
+TEST_F(Phase2ContainerTest, RemoveStatementsOwnedByAPI) {
+  // Test public IrContainer::removeStatementsOwnedBy API
+  Fusion fusion;
+  FusionGuard fg(&fusion);
+
+  // Create some IR
+  auto* tv0 = makeSymbolicTensor(2);
+  fusion.addInput(tv0);
+  auto* tv1 = add(tv0, tv0);
+  fusion.addOutput(tv1);
+
+  auto& container = *fusion.ir_container();
+
+  // Verify we have statements
+  EXPECT_GT(container.vals().size(), 0);
+  EXPECT_GT(container.unordered_exprs().size(), 0);
+  EXPECT_GT(container.valsOwnedBy(&fusion).size(), 0);
+  EXPECT_GT(container.exprsOwnedBy(&fusion).size(), 0);
+
+  // Remove the statements through the container API directly rather than
+  // through Fusion::clear(). Fusion-level state (inputs_, outputs_, etc.) is
+  // not reset by this test; in practice, Fusion::clear() does both.
+  container.removeStatementsOwnedBy(&fusion);
+
+  // After removal, tracking should be empty
+  EXPECT_EQ(container.valsOwnedBy(&fusion).size(), 0);
+  EXPECT_EQ(container.exprsOwnedBy(&fusion).size(), 0);
+
+  // Container-level sets should also be empty (single fusion case)
+  EXPECT_EQ(container.vals().size(), 0);
+  EXPECT_EQ(container.unordered_exprs().size(), 0);
+}
+
+// =============================================================================
+// Task 7 Tests: Per-Fusion Special Values
+// =============================================================================
+
+TEST_F(Phase2ContainerTest, PerFusionSpecialValuesBasic) {
+  // Test that special values are created per-Fusion
+  Fusion a;
+  FusionGuard fg_a(&a);
+  Val* zero_a = a.zeroVal();
+  Val* one_a = a.oneVal();
+
+  EXPECT_NE(zero_a, nullptr);
+  EXPECT_NE(one_a, nullptr);
+  EXPECT_EQ(zero_a->container(), &a);
+  EXPECT_EQ(one_a->container(), &a);
+}
+
+TEST_F(Phase2ContainerTest, SpecialValuesOwnedByFusion) {
+  // Test that special values are tracked in ownedVals
+  Fusion a;
+  FusionGuard fg_a(&a);
+
+  Val* zero_a = a.zeroVal();
+
+  // Special values should be in ownedVals
+  EXPECT_TRUE(a.ownedVals().count(zero_a) > 0);
+}
+
+TEST_F(Phase2ContainerTest, SeparateFusionsHaveOwnSpecialValues) {
+  // Two independent Fusions should have different special values
+  Fusion a;
+  Fusion b;
+
+  {
+    FusionGuard fg_a(&a);
+    Val* zero_a = a.zeroVal();
+    EXPECT_EQ(zero_a->container(), &a);
+  }
+
+  {
+    FusionGuard fg_b(&b);
+    Val* zero_b = b.zeroVal();
+    EXPECT_EQ(zero_b->container(), &b);
+  }
+
+  // Each has its own zero (different objects)
+  EXPECT_NE(a.zeroVal(), b.zeroVal());
+}
+
+TEST_F(Phase2ContainerTest, DestroyFusionDoesNotAffectOther) {
+  // Destroying one Fusion should not affect another's special values
+  Fusion a;
+  FusionGuard fg_a(&a);
+
+  // Create special values in a
+  Val* zero_a = a.zeroVal();
+  EXPECT_NE(zero_a, nullptr);
+
+  {
+    Fusion b;
+    FusionGuard fg_b(&b);
+    Val* zero_b = b.zeroVal();
+    EXPECT_NE(zero_b, nullptr);
+    // b destroyed here
+  }
+
+  // a should still work fine - its special values should still be valid
+  Val* zero_a_again = a.zeroVal();
+  EXPECT_EQ(zero_a_again, zero_a);
+  EXPECT_EQ(zero_a_again->container(), &a);
+}
+
+TEST_F(Phase2ContainerTest, SpecialValuesLazyCreation) {
+  // Special values should be created lazily
+  Fusion a;
+  FusionGuard fg_a(&a);
+
+  // Before calling zeroVal(), it shouldn't exist
+  // (Can't directly test this, but we can verify it works after call)
+  Val* zero1 = a.zeroVal();
+  Val* zero2 = a.zeroVal();
+
+  // Same value returned on repeated calls
+  EXPECT_EQ(zero1, zero2);
+}
+
+TEST_F(Phase2ContainerTest, AllSpecialValuesPerFusion) {
+  // Test all special value accessors
+  Fusion a;
+  FusionGuard fg_a(&a);
+
+  Val* zero = a.zeroVal();
+  Val* one = a.oneVal();
+  Val* true_val = a.trueVal();
+  Val* false_val = a.falseVal();
+  NamedScalar* magic_zero = a.magicZeroVal();
+
+  // All should be non-null
+  EXPECT_NE(zero, nullptr);
+  EXPECT_NE(one, nullptr);
+  EXPECT_NE(true_val, nullptr);
+  EXPECT_NE(false_val, nullptr);
+  EXPECT_NE(magic_zero, nullptr);
+
+  // All should have container() == &a
+  EXPECT_EQ(zero->container(), &a);
+  EXPECT_EQ(one->container(), &a);
+  EXPECT_EQ(true_val->container(), &a);
+  EXPECT_EQ(false_val->container(), &a);
+  EXPECT_EQ(magic_zero->container(), &a);
+
+  // All should be tracked in ownedVals
+  EXPECT_TRUE(a.ownedVals().count(zero) > 0);
+  EXPECT_TRUE(a.ownedVals().count(one) > 0);
+  EXPECT_TRUE(a.ownedVals().count(true_val) > 0);
+  EXPECT_TRUE(a.ownedVals().count(false_val) > 0);
+  EXPECT_TRUE(a.ownedVals().count(magic_zero) > 0);
+}
+
+TEST_F(Phase2ContainerTest, SpecialValuesClearedOnFusionClear) {
+  // Test that Fusion::clear() resets special values
+  Fusion a;
+  FusionGuard fg_a(&a);
+
+  // Create special values
+  Val* zero_before = a.zeroVal();
+  Val* one_before = a.oneVal();
+  EXPECT_NE(zero_before, nullptr);
+  EXPECT_NE(one_before, nullptr);
+
+  // Clear the fusion
+  a.clear();
+
+  // Special values should be recreated lazily (new objects)
+  Val* zero_after = a.zeroVal();
+  Val* one_after = a.oneVal();
+
+  // New objects should differ from the old ones (removed via
+  // removeStatementsOwnedBy). NOTE(review): address reuse could defeat this.
+  EXPECT_NE(zero_after, zero_before);
+  EXPECT_NE(one_after, one_before);
+
+  // New objects should be valid and owned by the fusion
+  EXPECT_EQ(zero_after->container(), &a);
+  EXPECT_EQ(one_after->container(), &a);
+}
+
+TEST_F(Phase2ContainerTest, SpecialValuesWithDtype) {
+  // Test zeroVal(dtype) and oneVal(dtype) accessors
+  Fusion a;
+  FusionGuard fg_a(&a);
+
+  // Index type should return the cached value
+  Val* zero_index = a.zeroVal(DataType::Index);
+  Val* zero_cached = a.zeroVal();
+  EXPECT_EQ(zero_index, zero_cached);
+
+  Val* one_index = a.oneVal(DataType::Index);
+  Val* one_cached = a.oneVal();
+  EXPECT_EQ(one_index, one_cached);
+
+  // Bool type should return true/false val
+  Val* zero_bool = a.zeroVal(DataType::Bool);
+  Val* false_cached = a.falseVal();
+  EXPECT_EQ(zero_bool, false_cached);
+
+  Val* one_bool = a.oneVal(DataType::Bool);
+  Val* true_cached = a.trueVal();
+  EXPECT_EQ(one_bool, true_cached);
+
+  // Other types should create new values (not cached)
+  Val* zero_float = a.zeroVal(DataType::Float);
+  Val* zero_float2 = a.zeroVal(DataType::Float);
+  // These are not cached, so they're different objects
+  EXPECT_NE(zero_float, zero_float2);
+}
+
+// =============================================================================
+// Task 5 Tests: Copy Semantics with Shared Containers
+// =============================================================================
+
+TEST_F(Phase2ContainerTest, CopySharesContainer) {
+  // After copy, both Fusions point to the same container
+  Fusion a;
+  FusionGuard fg_a(&a);
+
+  auto* tv0 = makeSymbolicTensor(2);
+  a.addInput(tv0);
+  auto* tv1 = add(tv0, tv0);
+  a.addOutput(tv1);
+
+  Fusion b(a); // Copy
+
+  // Both should share the same container
+  EXPECT_EQ(a.ir_container_ptr().get(), b.ir_container_ptr().get());
+}
+
+TEST_F(Phase2ContainerTest, CopyRegistersWithContainer) {
+  // sharingCount should increment after copy
+  Fusion a;
+  FusionGuard fg_a(&a);
+
+  auto* tv0 = makeSymbolicTensor(2);
+  a.addInput(tv0);
+
+  EXPECT_EQ(a.ir_container()->sharingCount(), 1);
+
+  Fusion b(a);
+
+  EXPECT_EQ(a.ir_container()->sharingCount(), 2);
+  EXPECT_EQ(b.ir_container()->sharingCount(), 2);
+}
+
+TEST_F(Phase2ContainerTest, CopiedNodesOwnedByNewFusion) {
+  // Cloned nodes should have container() == &b, the address of the copy
+  Fusion a;
+  FusionGuard fg_a(&a);
+
+  auto* tv0 = makeSymbolicTensor(2);
+  a.addInput(tv0);
+  auto* tv1 = add(tv0, tv0);
+  a.addOutput(tv1);
+
+  Fusion b(a);
+
+  // b should have inputs
+  EXPECT_EQ(b.inputs().size(), 1);
+
+  // b's input should be owned by b (not a)
+  EXPECT_EQ(b.inputs()[0]->container(), &b);
+
+  // b's input should be different from a's input (cloned)
+  EXPECT_NE(b.inputs()[0], a.inputs()[0]);
+}
+
+TEST_F(Phase2ContainerTest, CopyOwnedValsAreIndependent) {
+  // a's ownedVals and b's ownedVals should be disjoint
+  Fusion a;
+  FusionGuard fg_a(&a);
+
+  auto* tv0 = makeSymbolicTensor(2);
+  a.addInput(tv0);
+  auto* tv1 = add(tv0, tv0);
+  a.addOutput(tv1);
+
+  Fusion b(a);
+
+  // All of a's ownedVals should have container() == &a
+  for (auto* v : a.ownedVals()) {
+    EXPECT_EQ(v->container(), &a);
+  }
+
+  // All of b's ownedVals should have container() == &b
+  for (auto* v : b.ownedVals()) {
+    EXPECT_EQ(v->container(), &b);
+  }
+
+  // The sets should be disjoint
+  for (auto* v : a.ownedVals()) {
+    EXPECT_EQ(b.ownedVals().count(v), 0);
+  }
+}
+
+TEST_F(Phase2ContainerTest, DestructorOnlyRemovesOwnedStatements) {
+  // Destroying copy should not affect original's statements
+  Fusion a;
+  FusionGuard fg_a(&a);
+
+  auto* tv0 = makeSymbolicTensor(2);
+  a.addInput(tv0);
+  auto* tv1 = add(tv0, tv0);
+  a.addOutput(tv1);
+
+  size_t a_vals_before = a.ownedVals().size();
+
+  {
+    Fusion b(a); // Copy
+    // b gets its own cloned nodes
+    EXPECT_GT(b.ownedVals().size(), 0);
+    // b destroyed here
+  }
+
+  // a's vals should still exist and be unchanged
+  EXPECT_EQ(a.ownedVals().size(), a_vals_before);
+
+  // a's vals should still have correct container
+  for (auto* v : a.ownedVals()) {
+    EXPECT_EQ(v->container(), &a);
+  }
+}
+
+TEST_F(Phase2ContainerTest, CopyHasOwnSpecialValues) {
+  // Each Fusion (original and copy) should have its own special values
+  Fusion a;
+  FusionGuard fg_a(&a);
+  Val* zero_a = a.zeroVal();
+  Val* one_a = a.oneVal();
+
+  Fusion b(a); // Copy
+
+  // Copy should have its own special values
+  Val* zero_b = b.zeroVal();
+  Val* one_b = b.oneVal();
+
+  // Different objects
+  EXPECT_NE(zero_a, zero_b);
+  EXPECT_NE(one_a, one_b);
+
+  // Correct ownership
+  EXPECT_EQ(zero_a->container(), &a);
+  EXPECT_EQ(zero_b->container(), &b);
+}
+
+TEST_F(Phase2ContainerTest, CopySpecialValuesIndependent) {
+  // Destroying copy should not affect original's special values
+  Fusion a;
+  FusionGuard fg_a(&a);
+  Val* zero_a = a.zeroVal();
+
+  {
+    Fusion b(a); // Copy
+    Val* zero_b = b.zeroVal();
+    EXPECT_NE(zero_a, zero_b);
+    // b destroyed here
+  }
+
+  // a's special values should still be valid
+  EXPECT_EQ(a.zeroVal(), zero_a);
+  EXPECT_EQ(zero_a->container(), &a);
+}
+
+TEST_F(Phase2ContainerTest, CopySharingCountDecrementsOnDestruction) {
+  // When copy is destroyed, sharingCount should decrement
+  Fusion a;
+  FusionGuard fg_a(&a);
+
+  auto* tv0 = makeSymbolicTensor(2);
+  a.addInput(tv0);
+
+  auto container_ptr = a.ir_container_ptr();
+  EXPECT_EQ(container_ptr->sharingCount(), 1);
+
+  {
+    Fusion b(a);
+    EXPECT_EQ(container_ptr->sharingCount(), 2);
+    // b destroyed here
+  }
+
+  EXPECT_EQ(container_ptr->sharingCount(), 1);
+}
+
+TEST_F(Phase2ContainerTest, CopyReturnsIrCloner) {
+  // Fusion::copy should return IrCloner for node mapping
+  // We test this indirectly via the copy constructor which uses Fusion::copy
+  Fusion a;
+  FusionGuard fg_a(&a);
+
+  auto* tv0 = makeSymbolicTensor(2);
+  a.addInput(tv0);
+  auto* tv1 = add(tv0, tv0);
+  a.addOutput(tv1);
+
+  // Copy constructor uses Fusion::copy internally
+  Fusion b(a);
+
+  // Verify the copy worked - b has cloned inputs/outputs
+  EXPECT_EQ(b.inputs().size(), a.inputs().size());
+  EXPECT_EQ(b.outputs().size(), a.outputs().size());
+
+  // Cloned nodes should belong to b
+  EXPECT_EQ(b.inputs()[0]->container(), &b);
+  EXPECT_EQ(b.outputs()[0]->container(), &b);
+
+  // They should be different objects from a's nodes
+  EXPECT_NE(b.inputs()[0], a.inputs()[0]);
+  EXPECT_NE(b.outputs()[0], a.outputs()[0]);
+}
+
+// =============================================================================
+// Task 6 Tests: Move Semantics with Shared Containers
+// =============================================================================
+
+TEST_F(Phase2ContainerTest, MoveConstructorTransfersOwnership) {
+  // Move constructor should transfer container ownership
+  Fusion a;
+  FusionGuard fg_a(&a);
+
+  auto* tv0 = makeSymbolicTensor(2);
+  a.addInput(tv0);
+  auto* tv1 = add(tv0, tv0);
+  a.addOutput(tv1);
+
+  auto* container = a.ir_container_ptr().get();
+  size_t a_vals_count = a.ownedVals().size();
+
+  Fusion b(std::move(a));
+
+  // b should have a's old container
+  EXPECT_EQ(b.ir_container_ptr().get(), container);
+
+  // b should have a's statements
+  EXPECT_EQ(b.ownedVals().size(), a_vals_count);
+}
+
+TEST_F(Phase2ContainerTest, MoveConstructorSourceIsValid) {
+  // After move, source should be valid with new empty container
+  Fusion a;
+  FusionGuard fg_a(&a);
+
+  auto* tv0 = makeSymbolicTensor(2);
+  a.addInput(tv0);
+
+  Fusion b(std::move(a));
+
+  // Source has new empty container (not nullptr)
+  EXPECT_NE(a.ir_container_ptr().get(), nullptr);
+  EXPECT_NE(a.ir_container_ptr().get(), b.ir_container_ptr().get());
+
+  // Source is empty
+  EXPECT_EQ(a.ownedVals().size(), 0);
+  EXPECT_EQ(a.inputs().size(), 0);
+  EXPECT_EQ(a.outputs().size(), 0);
+
+  // Source can still be used safely
+  a.clear(); // Should not crash
+}
+
+TEST_F(Phase2ContainerTest, MoveUpdatesStatementOwnership) {
+  // Moved statements should have container() pointing to destination
+  Fusion a;
+  FusionGuard fg_a(&a);
+
+  auto* tv0 = makeSymbolicTensor(2);
+  a.addInput(tv0);
+  auto* tv1 = add(tv0, tv0);
+  a.addOutput(tv1);
+
+  // Capture original vals
+  std::vector<Val*> orig_vals(a.ownedVals().begin(), a.ownedVals().end());
+  EXPECT_GT(orig_vals.size(), 0);
+
+  Fusion b(std::move(a));
+
+  // All original vals now belong to b
+  for (auto* val : orig_vals) {
+    EXPECT_EQ(val->container(), &b);
+  }
+
+  // b's ownedVals should contain them
+  for (auto* val : orig_vals) {
+    EXPECT_TRUE(b.ownedVals().count(val) > 0);
+  }
+}
+
+TEST_F(Phase2ContainerTest, MoveTransfersSpecialValues) {
+  // Move should transfer special value pointers to destination
+  Fusion a;
+  FusionGuard fg_a(&a);
+  Val* zero_a = a.zeroVal();
+  Val* one_a = a.oneVal();
+
+  Fusion b(std::move(a));
+
+  // b should have a's special values
+  EXPECT_EQ(b.zeroVal(), zero_a);
+  EXPECT_EQ(b.oneVal(), one_a);
+
+  // Ownership updated to b
+  EXPECT_EQ(zero_a->container(), &b);
+  EXPECT_EQ(one_a->container(), &b);
+}
+
+TEST_F(Phase2ContainerTest, MoveSourceCanCreateNewSpecialValues) {
+  // After move, source can create new special values
+  Fusion a;
+  FusionGuard fg_a(&a);
+  Val* zero_a = a.zeroVal();
+
+  Fusion b(std::move(a));
+
+  // a is now empty but valid - can create new special values
+  Val* zero_a_new = a.zeroVal();
+
+  // Different from the moved one
+  EXPECT_NE(zero_a_new, zero_a);
+  EXPECT_EQ(zero_a_new->container(), &a);
+}
+
+TEST_F(Phase2ContainerTest, MoveAssignmentWorks) {
+  // Move assignment should transfer ownership
+  Fusion a;
+  FusionGuard fg_a(&a);
+
+  auto* tv0 = makeSymbolicTensor(2);
+  a.addInput(tv0);
+
+  auto* container = a.ir_container_ptr().get();
+
+  Fusion b;
+  b = std::move(a);
+
+  // b has a's container
+  EXPECT_EQ(b.ir_container_ptr().get(), container);
+
+  // a is valid but empty
+  EXPECT_NE(a.ir_container_ptr().get(), nullptr);
+  EXPECT_EQ(a.ownedVals().size(), 0);
+}
+
+TEST_F(Phase2ContainerTest, MoveAssignmentSelfAssignment) {
+  // Self-assignment should be a no-op
+  Fusion a;
+  FusionGuard fg_a(&a);
+
+  auto* tv0 = makeSymbolicTensor(2);
+  a.addInput(tv0);
+
+  auto* container = a.ir_container_ptr().get();
+  size_t vals_count = a.ownedVals().size();
+
+  // Use a reference to avoid -Wself-move warning
+  Fusion& a_ref = a;
+  a = std::move(a_ref);
+
+  // Should be unchanged
+  EXPECT_EQ(a.ir_container_ptr().get(), container);
+  EXPECT_EQ(a.ownedVals().size(), vals_count);
+}
+
+TEST_F(Phase2ContainerTest, SwapExchangesContainers) {
+  // Swap should exchange container pointers
+  Fusion a, b;
+
+  auto* container_a = a.ir_container_ptr().get();
+  auto* container_b = b.ir_container_ptr().get();
+
+  Fusion::swap(a, b);
+
+  EXPECT_EQ(a.ir_container_ptr().get(), container_b);
+  EXPECT_EQ(b.ir_container_ptr().get(), container_a);
+}
+
+TEST_F(Phase2ContainerTest, SwapUpdatesStatementOwnership) {
+  // Swap should exchange statement ownership
+  Fusion a, b;
+
+  {
+    FusionGuard fg_a(&a);
+    auto* tv0 = makeSymbolicTensor(2);
+    a.addInput(tv0);
+  }
+
+  {
+    FusionGuard fg_b(&b);
+    auto* tv0 = makeSymbolicTensor(3);
+    b.addInput(tv0);
+  }
+
+  // Capture original vals
+  std::vector<Val*> a_vals(a.ownedVals().begin(), a.ownedVals().end());
+  std::vector<Val*> b_vals(b.ownedVals().begin(), b.ownedVals().end());
+
+  Fusion::swap(a, b);
+
+  // a's old vals now belong to b
+  for (auto* val : a_vals) {
+    EXPECT_EQ(val->container(), &b);
+  }
+
+  // b's old vals now belong to a
+  for (auto* val : b_vals) {
+    EXPECT_EQ(val->container(), &a);
+  }
+}
+
+TEST_F(Phase2ContainerTest, SwapExchangesSpecialValues) {
+  // Swap should exchange special values
+  Fusion a, b;
+
+  Val* zero_a = nullptr;
+  Val* zero_b = nullptr;
+
+  {
+    FusionGuard fg_a(&a);
+    zero_a = a.zeroVal();
+  }
+
+  {
+    FusionGuard fg_b(&b);
+    zero_b = b.zeroVal();
+  }
+
+  Fusion::swap(a, b);
+
+  // Special values exchanged
+  EXPECT_EQ(a.zeroVal(), zero_b);
+  EXPECT_EQ(b.zeroVal(), zero_a);
+
+  // Ownership updated
+  EXPECT_EQ(zero_a->container(), &b);
+  EXPECT_EQ(zero_b->container(), &a);
+}
+
+TEST_F(Phase2ContainerTest, SwapSelfSwapIsNoop) {
+  // Swapping with self should be a no-op
+  Fusion a;
+  FusionGuard fg_a(&a);
+
+  auto* tv0 = makeSymbolicTensor(2);
+  a.addInput(tv0);
+
+  auto* container = a.ir_container_ptr().get();
+  size_t vals_count = a.ownedVals().size();
+
+  Fusion::swap(a, a);
+
+  EXPECT_EQ(a.ir_container_ptr().get(), container);
+  EXPECT_EQ(a.ownedVals().size(), vals_count);
+}
+
+TEST_F(Phase2ContainerTest, MoveFromCopyPreservesOther) {
+  // If we copy A to B (sharing container), then move A to C,
+  // B should be unaffected
+  Fusion a;
+  FusionGuard fg_a(&a);
+
+  auto* tv0 = makeSymbolicTensor(2);
+  a.addInput(tv0);
+  auto* tv1 = add(tv0, tv0);
+  a.addOutput(tv1);
+
+  Fusion b(a); // Copy - shares container
+
+  // Capture b's state
+  size_t b_vals_before = b.ownedVals().size();
+  std::vector<Val*> b_vals(b.ownedVals().begin(), b.ownedVals().end());
+
+  Fusion c(std::move(a)); // Move a to c
+
+  // b should be completely unaffected
+  EXPECT_EQ(b.ownedVals().size(), b_vals_before);
+  for (auto* val : b_vals) {
+    EXPECT_EQ(val->container(), &b);
+    EXPECT_TRUE(b.ownedVals().count(val) > 0);
+  }
+}
+
+TEST_F(Phase2ContainerTest, MoveFromCopyTransfersCorrectly) {
+  // If we copy A to B, then move A to C,
+  // C should have A's original statements
+  Fusion a;
+  FusionGuard fg_a(&a);
+
+  auto* tv0 = makeSymbolicTensor(2);
+  a.addInput(tv0);
+
+  // Capture a's vals before copy
+  std::vector<Val*> a_vals(a.ownedVals().begin(), a.ownedVals().end());
+
+  Fusion b(a); // Copy
+  Fusion c(std::move(a)); // Move a to c
+
+  // c should have a's original vals
+  for (auto* val : a_vals) {
+    EXPECT_EQ(val->container(), &c);
+    EXPECT_TRUE(c.ownedVals().count(val) > 0);
+  }
+}
+
+TEST_F(Phase2ContainerTest, MovePreservesInputsOutputs) { + // Move should transfer inputs/outputs vectors + Fusion a; + FusionGuard fg_a(&a); + + auto* tv0 = makeSymbolicTensor(2); + a.addInput(tv0); + auto* tv1 = add(tv0, tv0); + a.addOutput(tv1); + + Val* orig_input = a.inputs()[0]; + Val* orig_output = a.outputs()[0]; + + Fusion b(std::move(a)); + + // b has the inputs/outputs + EXPECT_EQ(b.inputs().size(), 1); + EXPECT_EQ(b.outputs().size(), 1); + EXPECT_EQ(b.inputs()[0], orig_input); + EXPECT_EQ(b.outputs()[0], orig_output); + + // a is empty + EXPECT_EQ(a.inputs().size(), 0); + EXPECT_EQ(a.outputs().size(), 0); +} + +// ============================================================================= +// Deterministic Accessor Tests: Per-Fusion Filtering +// ============================================================================= + +TEST_F(Phase2ContainerTest, DeterministicValsReturnsOnlyOwned) { + // With a single Fusion, deterministic_vals() should return all vals + Fusion a; + FusionGuard fg_a(&a); + + auto* tv0 = makeSymbolicTensor(2); + a.addInput(tv0); + auto* tv1 = add(tv0, tv0); + a.addOutput(tv1); + + // deterministic_vals() should return same count as ownedVals() + auto det_vals = a.deterministic_vals(); + EXPECT_EQ(det_vals.size(), a.ownedVals().size()); + + // All vals in deterministic_vals should be owned by a + for (auto* val : det_vals) { + EXPECT_EQ(val->container(), &a); + EXPECT_TRUE(a.ownedVals().count(val) > 0); + } +} + +TEST_F(Phase2ContainerTest, DeterministicExprsReturnsOnlyOwned) { + // With a single Fusion, deterministic_exprs() should return all exprs + Fusion a; + FusionGuard fg_a(&a); + + auto* tv0 = makeSymbolicTensor(2); + a.addInput(tv0); + auto* tv1 = add(tv0, tv0); + a.addOutput(tv1); + + // deterministic_exprs() should return same count as ownedExprs() + auto det_exprs = a.deterministic_exprs(); + EXPECT_EQ(det_exprs.size(), a.ownedExprs().size()); + + // All exprs in deterministic_exprs should be owned by a + for (auto* 
expr : det_exprs) { + EXPECT_EQ(expr->container(), &a); + EXPECT_TRUE(a.ownedExprs().count(expr) > 0); + } +} + +TEST_F( + Phase2ContainerTest, + DeterministicValsFiltersByOwnershipInSharedContainer) { + // After copy, each Fusion's deterministic_vals() returns only ITS vals + Fusion a; + FusionGuard fg_a(&a); + + auto* tv0 = makeSymbolicTensor(2); + a.addInput(tv0); + auto* tv1 = add(tv0, tv0); + a.addOutput(tv1); + + Fusion b(a); // Copy - shares container + + // Both share the same container + EXPECT_EQ(a.ir_container_ptr().get(), b.ir_container_ptr().get()); + + // But each has its own deterministic vals + auto a_det_vals = a.deterministic_vals(); + auto b_det_vals = b.deterministic_vals(); + + // Sizes should match ownedVals + EXPECT_EQ(a_det_vals.size(), a.ownedVals().size()); + EXPECT_EQ(b_det_vals.size(), b.ownedVals().size()); + + // a's deterministic_vals should all be owned by a + for (auto* val : a_det_vals) { + EXPECT_EQ(val->container(), &a); + } + + // b's deterministic_vals should all be owned by b + for (auto* val : b_det_vals) { + EXPECT_EQ(val->container(), &b); + } + + // The sets should be disjoint + std::unordered_set a_set(a_det_vals.begin(), a_det_vals.end()); + for (auto* val : b_det_vals) { + EXPECT_EQ(a_set.count(val), 0); + } +} + +TEST_F( + Phase2ContainerTest, + DeterministicExprsFiltersByOwnershipInSharedContainer) { + // After copy, each Fusion's deterministic_exprs() returns only ITS exprs + Fusion a; + FusionGuard fg_a(&a); + + auto* tv0 = makeSymbolicTensor(2); + a.addInput(tv0); + auto* tv1 = add(tv0, tv0); + a.addOutput(tv1); + + Fusion b(a); // Copy - shares container + + // Both share the same container + EXPECT_EQ(a.ir_container_ptr().get(), b.ir_container_ptr().get()); + + // But each has its own deterministic exprs + auto a_det_exprs = a.deterministic_exprs(); + auto b_det_exprs = b.deterministic_exprs(); + + // Sizes should match ownedExprs + EXPECT_EQ(a_det_exprs.size(), a.ownedExprs().size()); + 
EXPECT_EQ(b_det_exprs.size(), b.ownedExprs().size()); + + // a's deterministic_exprs should all be owned by a + for (auto* expr : a_det_exprs) { + EXPECT_EQ(expr->container(), &a); + } + + // b's deterministic_exprs should all be owned by b + for (auto* expr : b_det_exprs) { + EXPECT_EQ(expr->container(), &b); + } + + // The sets should be disjoint + std::unordered_set a_set(a_det_exprs.begin(), a_det_exprs.end()); + for (auto* expr : b_det_exprs) { + EXPECT_EQ(a_set.count(expr), 0); + } +} + +TEST_F(Phase2ContainerTest, DeterministicValsMapFiltersByOwnership) { + // deterministic_vals_map should only include owned vals with local indices + Fusion a; + FusionGuard fg_a(&a); + + auto* tv0 = makeSymbolicTensor(2); + a.addInput(tv0); + auto* tv1 = add(tv0, tv0); + a.addOutput(tv1); + + Fusion b(a); // Copy - shares container + + auto a_map = a.deterministic_vals_map(); + auto b_map = b.deterministic_vals_map(); + + // Maps should have same size as ownedVals + EXPECT_EQ(a_map.size(), a.ownedVals().size()); + EXPECT_EQ(b_map.size(), b.ownedVals().size()); + + // All keys in a_map should be owned by a + for (const auto& [val, idx] : a_map) { + EXPECT_EQ(val->container(), &a); + } + + // All keys in b_map should be owned by b + for (const auto& [val, idx] : b_map) { + EXPECT_EQ(val->container(), &b); + } + + // Indices should be sequential starting from 0 (local to each Fusion) + std::vector a_indices, b_indices; + for (const auto& [val, idx] : a_map) { + a_indices.push_back(idx); + } + for (const auto& [val, idx] : b_map) { + b_indices.push_back(idx); + } + + std::sort(a_indices.begin(), a_indices.end()); + std::sort(b_indices.begin(), b_indices.end()); + + // Should be 0, 1, 2, ... 
for each + for (size_t i = 0; i < a_indices.size(); ++i) { + EXPECT_EQ(a_indices[i], static_cast(i)); + } + for (size_t i = 0; i < b_indices.size(); ++i) { + EXPECT_EQ(b_indices[i], static_cast(i)); + } +} + +TEST_F(Phase2ContainerTest, DeterministicExprsMapFiltersByOwnership) { + // deterministic_exprs_map should only include owned exprs with local indices + Fusion a; + FusionGuard fg_a(&a); + + auto* tv0 = makeSymbolicTensor(2); + a.addInput(tv0); + auto* tv1 = add(tv0, tv0); + a.addOutput(tv1); + + Fusion b(a); // Copy - shares container + + auto a_map = a.deterministic_exprs_map(); + auto b_map = b.deterministic_exprs_map(); + + // Maps should have same size as ownedExprs + EXPECT_EQ(a_map.size(), a.ownedExprs().size()); + EXPECT_EQ(b_map.size(), b.ownedExprs().size()); + + // All keys in a_map should be owned by a + for (const auto& [expr, idx] : a_map) { + EXPECT_EQ(expr->container(), &a); + } + + // All keys in b_map should be owned by b + for (const auto& [expr, idx] : b_map) { + EXPECT_EQ(expr->container(), &b); + } + + // Indices should be sequential starting from 0 + std::vector a_indices, b_indices; + for (const auto& [expr, idx] : a_map) { + a_indices.push_back(idx); + } + for (const auto& [expr, idx] : b_map) { + b_indices.push_back(idx); + } + + std::sort(a_indices.begin(), a_indices.end()); + std::sort(b_indices.begin(), b_indices.end()); + + for (size_t i = 0; i < a_indices.size(); ++i) { + EXPECT_EQ(a_indices[i], static_cast(i)); + } + for (size_t i = 0; i < b_indices.size(); ++i) { + EXPECT_EQ(b_indices[i], static_cast(i)); + } +} + +TEST_F(Phase2ContainerTest, DeterministicValsMaintainsInsertionOrder) { + // deterministic_vals should maintain insertion order + Fusion a; + FusionGuard fg_a(&a); + + // Create multiple tensors in specific order + auto* tv0 = makeSymbolicTensor(1); + a.addInput(tv0); + auto* tv1 = makeSymbolicTensor(2); + a.addInput(tv1); + auto* tv2 = add(tv0, tv0); + auto* tv3 = add(tv1, tv1); + a.addOutput(tv2); + 
a.addOutput(tv3); + + auto det_vals = a.deterministic_vals(); + auto det_map = a.deterministic_vals_map(); + + // Verify deque order matches map indices + for (size_t i = 0; i < det_vals.size(); ++i) { + Val* val = det_vals[i]; + EXPECT_EQ(det_map.at(val), static_cast(i)); + } +} + +TEST_F(Phase2ContainerTest, DeterministicExprsMaintainsInsertionOrder) { + // deterministic_exprs should maintain insertion order + Fusion a; + FusionGuard fg_a(&a); + + auto* tv0 = makeSymbolicTensor(2); + a.addInput(tv0); + auto* tv1 = add(tv0, tv0); + auto* tv2 = mul(tv1, tv0); + auto* tv3 = sub(tv2, tv1); + a.addOutput(tv3); + + auto det_exprs = a.deterministic_exprs(); + auto det_map = a.deterministic_exprs_map(); + + // Verify deque order matches map indices + for (size_t i = 0; i < det_exprs.size(); ++i) { + Expr* expr = det_exprs[i]; + EXPECT_EQ(det_map.at(expr), static_cast(i)); + } +} + +TEST_F(Phase2ContainerTest, DeterministicAccessorsAfterCopyPreservesOrder) { + // After copy, deterministic order for each Fusion should be correct + Fusion a; + FusionGuard fg_a(&a); + + auto* tv0 = makeSymbolicTensor(2); + a.addInput(tv0); + auto* tv1 = add(tv0, tv0); + auto* tv2 = mul(tv1, tv1); + a.addOutput(tv2); + + // Capture a's deterministic vals + auto a_det_vals_before = a.deterministic_vals(); + auto a_det_map_before = a.deterministic_vals_map(); + + Fusion b(a); // Copy + + // a's deterministic_vals should be unchanged + auto a_det_vals_after = a.deterministic_vals(); + EXPECT_EQ(a_det_vals_before.size(), a_det_vals_after.size()); + for (size_t i = 0; i < a_det_vals_before.size(); ++i) { + EXPECT_EQ(a_det_vals_before[i], a_det_vals_after[i]); + } + + // b's deterministic_vals should have same structure (but different objects) + auto b_det_vals = b.deterministic_vals(); + EXPECT_EQ(b_det_vals.size(), a_det_vals_before.size()); +} + +TEST_F(Phase2ContainerTest, DeterministicAccessorsAfterDestroyingCopy) { + // After destroying a copy, original's deterministic accessors still work + 
Fusion a; + FusionGuard fg_a(&a); + + auto* tv0 = makeSymbolicTensor(2); + a.addInput(tv0); + auto* tv1 = add(tv0, tv0); + a.addOutput(tv1); + + auto a_det_vals_before = a.deterministic_vals(); + auto a_det_map_before = a.deterministic_vals_map(); + + { + Fusion b(a); // Copy + // b destroyed here + } + + // a's deterministic accessors should still work correctly + auto a_det_vals_after = a.deterministic_vals(); + auto a_det_map_after = a.deterministic_vals_map(); + + EXPECT_EQ(a_det_vals_before.size(), a_det_vals_after.size()); + EXPECT_EQ(a_det_map_before.size(), a_det_map_after.size()); + + // Same values, same order + for (size_t i = 0; i < a_det_vals_before.size(); ++i) { + EXPECT_EQ(a_det_vals_before[i], a_det_vals_after[i]); + } +} + +TEST_F(Phase2ContainerTest, DeterministicValsEmptyForNewFusion) { + // New empty Fusion should have empty deterministic vals + Fusion a; + + auto det_vals = a.deterministic_vals(); + auto det_exprs = a.deterministic_exprs(); + auto det_vals_map = a.deterministic_vals_map(); + auto det_exprs_map = a.deterministic_exprs_map(); + + EXPECT_EQ(det_vals.size(), 0); + EXPECT_EQ(det_exprs.size(), 0); + EXPECT_EQ(det_vals_map.size(), 0); + EXPECT_EQ(det_exprs_map.size(), 0); +} + +TEST_F(Phase2ContainerTest, DeterministicValsAfterClear) { + // After clear, deterministic vals should be empty + Fusion a; + FusionGuard fg_a(&a); + + auto* tv0 = makeSymbolicTensor(2); + a.addInput(tv0); + auto* tv1 = add(tv0, tv0); + a.addOutput(tv1); + + EXPECT_GT(a.deterministic_vals().size(), 0); + EXPECT_GT(a.deterministic_exprs().size(), 0); + + a.clear(); + + EXPECT_EQ(a.deterministic_vals().size(), 0); + EXPECT_EQ(a.deterministic_exprs().size(), 0); + EXPECT_EQ(a.deterministic_vals_map().size(), 0); + EXPECT_EQ(a.deterministic_exprs_map().size(), 0); +} + +// ============================================================================= +// StatementGuard Tests with Shared Containers +// 
============================================================================= + +TEST_F(Phase2ContainerTest, StatementGuardWithSharedContainer) { + // Test that StatementGuard works correctly with shared containers + // Bug: StatementGuard uses per-Fusion counts but removes from container + // deques + Fusion a; + FusionGuard fg_a(&a); + + auto* tv0 = makeSymbolicTensor(2); + a.addInput(tv0); + auto* tv1 = add(tv0, tv0); + a.addOutput(tv1); + + Fusion b(a); // Copy - shares container + + // Capture b's state before StatementGuard on a + size_t b_vals_before = b.ownedVals().size(); + size_t b_exprs_before = b.ownedExprs().size(); + std::vector b_vals(b.ownedVals().begin(), b.ownedVals().end()); + + { + FusionGuard fg_inner(&a); + StatementGuard sg(&a); + + // Create temporary vals in a + auto* temp = add(tv1, tv1); + (void)temp; + + // a has more vals now + EXPECT_GT(a.ownedVals().size(), b_vals_before); + } + // StatementGuard destructor should only remove a's new vals, not b's + + // b should be completely unaffected + EXPECT_EQ(b.ownedVals().size(), b_vals_before); + EXPECT_EQ(b.ownedExprs().size(), b_exprs_before); + + // b's vals should still have correct container + for (auto* val : b_vals) { + EXPECT_EQ(val->container(), &b); + EXPECT_TRUE(b.ownedVals().count(val) > 0); + } +} + +TEST_F(Phase2ContainerTest, StatementGuardDoesNotAffectOtherFusion) { + // StatementGuard on one Fusion should not affect another sharing the + // container + Fusion a; + FusionGuard fg_a(&a); + + auto* tv0 = makeSymbolicTensor(2); + a.addInput(tv0); + auto* tv1 = add(tv0, tv0); + a.addOutput(tv1); + + size_t a_vals_before_copy = a.ownedVals().size(); + + Fusion b(a); // Copy - shares container + + // Both should have same number of vals (cloned) + EXPECT_EQ(a_vals_before_copy, b.ownedVals().size()); + + size_t b_vals_at_guard_start = 0; + size_t b_vals_in_guard = 0; + + // Use StatementGuard on b to create and remove temp statements + { + FusionGuard fg_b(&b); + StatementGuard 
sg(&b); + + // Note: StatementGuard constructor calls axioms() which may create + // additional vals. The snapshot is taken AFTER axioms initialization. + b_vals_at_guard_start = b.ownedVals().size(); + + // Create temp vals in b + auto* b_input = b.inputs()[0]->as(); + auto* temp = mul(b_input, b_input); + (void)temp; + + b_vals_in_guard = b.ownedVals().size(); + + // b should have more vals now (from the temp operation) + EXPECT_GT(b_vals_in_guard, b_vals_at_guard_start); + } + + size_t a_vals_after = a.ownedVals().size(); + size_t b_vals_after = b.ownedVals().size(); + + // After guard, a should be unchanged + EXPECT_EQ(a_vals_after, a_vals_before_copy); + + // b should be back to its state at guard construction time + // (which includes axioms but not the temp vals created inside the guard) + EXPECT_EQ(b_vals_after, b_vals_at_guard_start); +} + +// ============================================================================= +// Task 10 Tests: Per-Fusion Name Counters +// ============================================================================= + +TEST_F(Phase2ContainerTest, PerFusionNameCountersBasic) { + // Each Fusion gets its own name counters starting at 0 + Fusion a; + FusionGuard fg_a(&a); + + auto* tv0 = makeSymbolicTensor(2); + a.addInput(tv0); + auto* tv1 = add(tv0, tv0); + a.addOutput(tv1); + + // First TensorView in Fusion a should have name 0 + EXPECT_EQ(tv0->name(), 0); +} + +TEST_F(Phase2ContainerTest, IndependentFusionsHaveOwnCounters) { + // Two independent Fusions both start name counters at 0 + Fusion a; + { + FusionGuard fg_a(&a); + auto* tv0_a = makeSymbolicTensor(2); + a.addInput(tv0_a); + auto* tv1_a = add(tv0_a, tv0_a); + a.addOutput(tv1_a); + // a's TensorViews should have names 0, 1 + EXPECT_EQ(tv0_a->name(), 0); + EXPECT_EQ(tv1_a->name(), 1); + } + + Fusion b; + { + FusionGuard fg_b(&b); + auto* tv0_b = makeSymbolicTensor(2); + b.addInput(tv0_b); + auto* tv1_b = add(tv0_b, tv0_b); + b.addOutput(tv1_b); + // b's TensorViews should 
ALSO have names 0, 1 (independent counter) + EXPECT_EQ(tv0_b->name(), 0); + EXPECT_EQ(tv1_b->name(), 1); + } +} + +TEST_F(Phase2ContainerTest, CopyNameCorrespondence) { + // CRITICAL: After Fusion::copy into shared container, cloned vals have + // matching names. This is required by GreedyParams and normalization_utils. + Fusion a; + FusionGuard fg_a(&a); + + auto* tv0 = makeSymbolicTensor(2); + a.addInput(tv0); + auto* tv1 = add(tv0, tv0); + auto* tv2 = mul(tv1, tv1); + a.addOutput(tv2); + + // Record a's val names + std::vector> a_val_names; + for (auto* val : a.deterministic_vals()) { + a_val_names.push_back({val, val->name()}); + } + + // Copy a -> b (shared container) + Fusion b(a); + + // b's vals should have MATCHING names (not incremented) + auto b_vals = b.deterministic_vals(); + EXPECT_EQ(b_vals.size(), a_val_names.size()); + + // Check TensorViews specifically - these are what GreedyParams uses + std::vector a_tvs, b_tvs; + for (auto* val : a.deterministic_vals()) { + if (val->isA()) { + a_tvs.push_back(val); + } + } + for (auto* val : b.deterministic_vals()) { + if (val->isA()) { + b_tvs.push_back(val); + } + } + + EXPECT_EQ(a_tvs.size(), b_tvs.size()); + for (size_t i = 0; i < a_tvs.size(); ++i) { + // Names should match across original and clone + EXPECT_EQ(a_tvs[i]->name(), b_tvs[i]->name()) + << "TV name mismatch at index " << i << ": a=" << a_tvs[i]->name() + << " b=" << b_tvs[i]->name(); + } +} + +TEST_F(Phase2ContainerTest, CopyExprNameCorrespondence) { + // After copy, cloned expressions should also have matching names + Fusion a; + FusionGuard fg_a(&a); + + auto* tv0 = makeSymbolicTensor(2); + a.addInput(tv0); + auto* tv1 = add(tv0, tv0); + auto* tv2 = mul(tv1, tv1); + a.addOutput(tv2); + + // Record a's expr names + auto a_exprs = a.deterministic_exprs(); + std::vector a_expr_names; + for (auto* expr : a_exprs) { + a_expr_names.push_back(expr->name()); + } + + // Copy + Fusion b(a); + + auto b_exprs = b.deterministic_exprs(); + 
EXPECT_EQ(b_exprs.size(), a_expr_names.size()); + + for (size_t i = 0; i < a_expr_names.size(); ++i) { + EXPECT_EQ(b_exprs[i]->name(), a_expr_names[i]) + << "Expr name mismatch at index " << i; + } +} + +TEST_F(Phase2ContainerTest, MultipleCopiesHaveMatchingNames) { + // Multiple copies from the same source should all have matching names + Fusion a; + FusionGuard fg_a(&a); + + auto* tv0 = makeSymbolicTensor(2); + a.addInput(tv0); + auto* tv1 = add(tv0, tv0); + a.addOutput(tv1); + + Fusion b(a); + Fusion c(a); + + // b and c should both have matching TV names + auto a_tvs_det = a.deterministic_vals(); + auto b_tvs_det = b.deterministic_vals(); + auto c_tvs_det = c.deterministic_vals(); + + EXPECT_EQ(a_tvs_det.size(), b_tvs_det.size()); + EXPECT_EQ(a_tvs_det.size(), c_tvs_det.size()); + + for (size_t i = 0; i < a_tvs_det.size(); ++i) { + EXPECT_EQ(a_tvs_det[i]->name(), b_tvs_det[i]->name()); + EXPECT_EQ(a_tvs_det[i]->name(), c_tvs_det[i]->name()); + } +} + +TEST_F(Phase2ContainerTest, NameCountersCleanedUpOnDestroy) { + // When a Fusion is destroyed, its per-Fusion counters should be cleaned up + Fusion a; + FusionGuard fg_a(&a); + + auto* tv0 = makeSymbolicTensor(2); + a.addInput(tv0); + auto* tv1 = add(tv0, tv0); + a.addOutput(tv1); + + auto container_ptr = a.ir_container_ptr(); + + { + Fusion b(a); // Copy - shares container, creates per-Fusion counters + EXPECT_EQ(container_ptr->sharingCount(), 2); + } + // b destroyed: per-Fusion counters for b should be cleaned up + + EXPECT_EQ(container_ptr->sharingCount(), 1); + // a should still work fine + EXPECT_GT(a.ownedVals().size(), 0); +} + +TEST_F(Phase2ContainerTest, NameCountersSurviveSwap) { + // After swap, name counters should follow the data + Fusion a; + { + FusionGuard fg_a(&a); + auto* tv0 = makeSymbolicTensor(2); + a.addInput(tv0); + auto* tv1 = add(tv0, tv0); + a.addOutput(tv1); + } + + Fusion b; + { + FusionGuard fg_b(&b); + auto* tv0 = makeSymbolicTensor(3); + b.addInput(tv0); + auto* tv1 = mul(tv0, 
tv0); + b.addOutput(tv1); + } + + // Get TV names before swap + auto a_tv0_name = a.inputs()[0]->name(); + auto b_tv0_name = b.inputs()[0]->name(); + + Fusion::swap(a, b); + + // After swap, a has b's old data and vice versa + EXPECT_EQ(a.inputs()[0]->name(), b_tv0_name); + EXPECT_EQ(b.inputs()[0]->name(), a_tv0_name); + + // Adding new TVs after swap should work with correct counters + { + FusionGuard fg_a(&a); + auto* new_tv = + add(a.inputs()[0]->as(), a.inputs()[0]->as()); + // New TV should get a valid name (not crash) + EXPECT_GE(new_tv->name(), 0); + } +} + +TEST_F(Phase2ContainerTest, NameCountersAfterClearAndRebuild) { + // After Fusion::clear(), name counters should reset so new vals start at 0 + Fusion a; + FusionGuard fg_a(&a); + + auto* tv0 = makeSymbolicTensor(2); + a.addInput(tv0); + auto* tv1 = add(tv0, tv0); + a.addOutput(tv1); + + EXPECT_EQ(tv0->name(), 0); + EXPECT_EQ(tv1->name(), 1); + + a.clear(); + + // After clear, new vals should start at 0 again + auto* tv0_new = makeSymbolicTensor(2); + a.addInput(tv0_new); + EXPECT_EQ(tv0_new->name(), 0); +} + +} // namespace nvfuser