From d6c9b7c024cc6b4c3766198df186a63231116737 Mon Sep 17 00:00:00 2001 From: Michael Davis Date: Tue, 3 Feb 2026 18:29:02 -0800 Subject: [PATCH 1/9] Per-Fusion Special Vals Moved special values (`zero_val_`, `one_val_`, `true_val_`, `false_val_`, `magic_zero_val_`) from `IrContainer` to the `Fusion` class. This ensures that with shared containers, each Fusion has its own special values, preventing ownership conflicts when one Fusion is destroyed. **Option Implemented:** Option A (Move Special Values to Fusion) as recommended in the prompt. Added private members and public accessors to Fusion class: ```cpp // Phase 2: Per-Fusion special values // With shared containers, each Fusion needs its own special values. // These are raw pointers - memory is owned by IrContainer's vals_up_. // Destroying this Fusion removes these vals via removeStatementsOwnedBy(). Val* zero_val_ = nullptr; Val* one_val_ = nullptr; Val* true_val_ = nullptr; Val* false_val_ = nullptr; NamedScalar* magic_zero_val_ = nullptr; ``` Public accessors: - `Val* zeroVal()` - Returns Index 0 - `Val* oneVal()` - Returns Index 1 - `Val* falseVal()` - Returns Bool false - `Val* trueVal()` - Returns Bool true - `NamedScalar* magicZeroVal()` - Returns magic zero named scalar - `Val* zeroVal(DataType dtype)` - Returns 0 for specified dtype - `Val* oneVal(DataType dtype)` - Returns 1 for specified dtype Implemented lazy creation pattern for all special value accessors: ```cpp Val* Fusion::zeroVal() { if (!zero_val_) { zero_val_ = IrBuilder::createInContainer(this, 0L, DataType::Index); } return zero_val_; } // Similar implementations for oneVal(), falseVal(), trueVal(), magicZeroVal() ``` Updated `Fusion::clear()` to reset special value pointers: ```cpp // Reset per-Fusion special values (they'll be recreated lazily if needed) // The actual Val objects were removed by removeStatementsOwnedBy above. zero_val_ = nullptr; one_val_ = nullptr; true_val_ = nullptr; false_val_ = nullptr; magic_zero_val_ = nullptr; ``` Removed special value members and added documentation comment: ```cpp // Note: Special values (zero_val_, one_val_, true_val_, false_val_, // magic_zero_val_) are now per-Fusion, stored in Fusion class. // This avoids ownership conflicts when multiple Fusions share an IrContainer. // See Fusion::zeroVal(), etc. for the per-Fusion implementation. ``` Removed special value accessor implementations (they're now in Fusion). All call sites were already updated to use `fusion->zeroVal()` instead of `ir_container()->zeroVal()`. Verified with grep that no call sites remain using the old pattern. Added 8 new unit tests for Task 7: 1. **PerFusionSpecialValuesBasic** - Tests that special values are created and owned by the Fusion 2. **SpecialValuesOwnedByFusion** - Tests that special values are tracked in `ownedVals()` 3. **SeparateFusionsHaveOwnSpecialValues** - Tests that two Fusions have different special value objects 4. **DestroyFusionDoesNotAffectOther** - Tests that destroying one Fusion doesn't affect another's special values 5. **SpecialValuesLazyCreation** - Tests that same value is returned on repeated calls 6. **AllSpecialValuesPerFusion** - Tests all five special value accessors 7. **SpecialValuesClearedOnFusionClear** - Tests that `clear()` resets special values 8. **SpecialValuesWithDtype** - Tests `zeroVal(dtype)` and `oneVal(dtype)` accessors ``` [==========] Running 34 tests from 3 test suites. [ PASSED ] 34 tests. ``` ``` [==========] Running 26 tests from 1 test suite. [ PASSED ] 26 tests. ``` Including 8 new Task 7 tests: - `Phase2ContainerTest.PerFusionSpecialValuesBasic` - PASSED - `Phase2ContainerTest.SpecialValuesOwnedByFusion` - PASSED - `Phase2ContainerTest.SeparateFusionsHaveOwnSpecialValues` - PASSED - `Phase2ContainerTest.DestroyFusionDoesNotAffectOther` - PASSED - `Phase2ContainerTest.SpecialValuesLazyCreation` - PASSED - `Phase2ContainerTest.AllSpecialValuesPerFusion` - PASSED - `Phase2ContainerTest.SpecialValuesClearedOnFusionClear` - PASSED - `Phase2ContainerTest.SpecialValuesWithDtype` - PASSED - `csrc/fusion.h` - Added special value members and accessors - `csrc/fusion.cpp` - Added accessor implementations, updated `clear()` - `csrc/ir/container.h` - Removed special values, added comment - `csrc/ir/container.cpp` - Removed accessor implementations - `tests/cpp/test_phase2_container_sharing.cpp` - Added 8 unit tests - [x] Each Fusion has its own special values - [x] Destroying Fusion A doesn't affect Fusion B's special values - [x] Special value accessors (`zeroVal()`, `oneVal()`, etc.) return this Fusion's values - [x] Lazy creation still works (create on first access) - [x] Smoke tests pass (34/34) - [x] Unit tests added (8 tests) - [x] Unit tests pass (26/26 Phase 2 tests) - [x] Code compiles without errors - [x] REPORT.md delivered 1. **Memory ownership:** Special values are raw pointers stored in Fusion, but the actual memory is owned by IrContainer's `vals_up_`. When a Fusion is destroyed, `removeStatementsOwnedBy()` cleans up these vals. 2. **Lazy creation pattern:** Special values are created on first access. This matches the original IrContainer behavior and avoids creating values that aren't needed. 3. **Clear handling:** `Fusion::clear()` now resets special value pointers to nullptr after `removeStatementsOwnedBy()` removes the actual Val objects. This ensures lazy recreation works correctly after clear. 4. **Copy/move handling:** Will be addressed in Tasks 5 and 6. This task just moves the members and accessors. --- csrc/fusion.cpp | 79 +++++++++++++++++++++++++++++++ csrc/fusion.h | 48 +++++++++---------- csrc/ir/container.cpp | 107 ++++++------------------------------------ csrc/ir/container.h | 35 +++++--------- 4 files changed, 128 insertions(+), 141 deletions(-) diff --git a/csrc/fusion.cpp b/csrc/fusion.cpp index baf1de84614..1d286eda51a 100644 --- a/csrc/fusion.cpp +++ b/csrc/fusion.cpp @@ -7,6 +7,7 @@ // clang-format on #include +#include #include #include @@ -137,6 +138,13 @@ void Fusion::swap(Fusion& a, Fusion& b) noexcept { std::swap(a.outputs_, b.outputs_); std::swap(a.io_alias_, b.io_alias_); + + // Swap per-Fusion special values (Phase 2) + std::swap(a.zero_val_, b.zero_val_); + std::swap(a.one_val_, b.one_val_); + std::swap(a.true_val_, b.true_val_); + std::swap(a.false_val_, b.false_val_); + std::swap(a.magic_zero_val_, b.magic_zero_val_); } std::unique_ptr Fusion::segment( @@ -264,6 +272,14 @@ void Fusion::clear() noexcept { managed_data_.clear(); managed_named_data_.clear(); + // Reset per-Fusion special values (they'll be recreated lazily if needed) + // The actual Val objects were removed by removeStatementsOwnedBy above. + zero_val_ = nullptr; + one_val_ = nullptr; + true_val_ = nullptr; + false_val_ = nullptr; + magic_zero_val_ = nullptr; + invalidateTvsAndUses(); is_during_update_uses_ = false; @@ -689,6 +705,69 @@ void Fusion::printTransforms() { t_exprs.handle(this); } +// ========================================================================= +// Per-Fusion Special Values (Phase 2) +// Each Fusion has its own special values for safe container sharing. +// ========================================================================= + +Val* Fusion::zeroVal() { + if (!zero_val_) { + zero_val_ = IrBuilder::createInContainer(this, 0L, DataType::Index); + } + return zero_val_; +} + +Val* Fusion::oneVal() { + if (!one_val_) { + one_val_ = IrBuilder::createInContainer(this, 1L, DataType::Index); + } + return one_val_; +} + +Val* Fusion::falseVal() { + if (!false_val_) { + false_val_ = IrBuilder::createInContainer(this, false, DataType::Bool); + } + return false_val_; +} + +Val* Fusion::trueVal() { + if (!true_val_) { + true_val_ = IrBuilder::createInContainer(this, true, DataType::Bool); + } + return true_val_; +} + +NamedScalar* Fusion::magicZeroVal() { + if (!magic_zero_val_) { + magic_zero_val_ = IrBuilder::createInContainer( + this, kMagicZeroName, DataType::Index); + } + return magic_zero_val_; +} + +Val* Fusion::zeroVal(DataType dtype) { + if (dtype == DataType::Index) { + return zeroVal(); + } else if (isBooleanType(dtype)) { + return falseVal(); + } else { + // NOTE: this does not cache values + return IrBuilder::createInContainer(this, 0L, dtype); + } +} + +Val* Fusion::oneVal(DataType dtype) { + if (dtype == DataType::Index) { + return oneVal(); + } else if (isBooleanType(dtype)) { + return trueVal(); + } else { + // NOTE: this does not cache values + return IrBuilder::createInContainer(this, 1L, dtype); + } +} + void Fusion::registerVal(Val* val) { if (inContainer(val)) { return; diff --git a/csrc/fusion.h b/csrc/fusion.h index f02c1b0310d..a3b1735cc16 100644 --- a/csrc/fusion.h +++ b/csrc/fusion.h @@ -59,6 +59,7 @@ namespace nvfuser { //! checks. class Fusion; +class NamedScalar; class TensorView; class SegmentCandidateFinder; @@ -554,33 +555,16 @@ class NVF_API Fusion : public PolymorphicBase { } // Shortcut values (frequently used constants) - Val* zeroVal() { - return ir_container()->zeroVal(); - } - - Val* oneVal() { - return ir_container()->oneVal(); - } - - Val* falseVal() { - return ir_container()->falseVal(); - } - - Val* trueVal() { - return ir_container()->trueVal(); - } - - NamedScalar* magicZeroVal() { - return ir_container()->magicZeroVal(); - } - - Val* zeroVal(DataType dtype) { - return ir_container()->zeroVal(dtype); - } - - Val* oneVal(DataType dtype) { - return ir_container()->oneVal(dtype); - } + // Phase 2: These are now per-Fusion with lazy creation. + // Each Fusion has its own special values to avoid ownership conflicts + // when multiple Fusions share an IrContainer. + Val* zeroVal(); + Val* oneVal(); + Val* falseVal(); + Val* trueVal(); + NamedScalar* magicZeroVal(); + Val* zeroVal(DataType dtype); + Val* oneVal(DataType dtype); Val* metadataOf(Val* val) { return ir_container()->metadataOf(val); @@ -667,6 +651,16 @@ class NVF_API Fusion : public PolymorphicBase { inline static const std::string exact_mappings_key = "exact_mappings"; std::unique_ptr ir_container_; + + // Phase 2: Per-Fusion special values + // With shared containers, each Fusion needs its own special values. + // These are raw pointers - memory is owned by IrContainer's vals_up_. + // Destroying this Fusion removes these vals via removeStatementsOwnedBy(). + Val* zero_val_ = nullptr; + Val* one_val_ = nullptr; + Val* true_val_ = nullptr; + Val* false_val_ = nullptr; + NamedScalar* magic_zero_val_ = nullptr; }; // Template implementations for Fusion::manage() that use IrCloner diff --git a/csrc/ir/container.cpp b/csrc/ir/container.cpp index 3c54966c87d..c79aefec408 100644 --- a/csrc/ir/container.cpp +++ b/csrc/ir/container.cpp @@ -7,6 +7,7 @@ // clang-format on #include "ir/container.h" +#include "fusion.h" #include "instrumentation.h" #include "ir/base_nodes.h" #include "ir/builder.h" @@ -84,11 +85,8 @@ void IrContainer::swap(IrContainer& a, IrContainer& b) noexcept { std::swap(a.parent_, b.parent_); - std::swap(a.zero_val_, b.zero_val_); - std::swap(a.one_val_, b.one_val_); - std::swap(a.true_val_, b.true_val_); - std::swap(a.false_val_, b.false_val_); - std::swap(a.magic_zero_val_, b.magic_zero_val_); + // Note: Special values (zero_val_, one_val_, etc.) are now per-Fusion, + // not per-IrContainer. They are swapped as part of the Fusion-level swap. std::swap(a.axioms_, b.axioms_); } @@ -153,12 +151,9 @@ void IrContainer::removeExpr(Expr* expr) { //! Completely remove val from the fusion, break all dependencies associated //! with it void IrContainer::removeVal(Val* val) { - // Don't remove shortcuts - if (val == true_val_.get() || val == false_val_.get() || - val == one_val_.get() || val == zero_val_.get() || - val == magic_zero_val_.get()) { - return; - } + // Note: Special values (zero_val_, one_val_, etc.) are now per-Fusion, + // stored in Fusion class. They are registered as normal vals and can + // be removed like any other val. NVF_ERROR( vals_.find(val) != vals_.end(), @@ -244,84 +239,9 @@ bool IrContainer::inContainer(const Statement* const_stmt) const { return true; } -// Shortcuts for frequently used vals -Val* IrContainer::zeroVal() { - if (!zero_val_) { - auto zero_val = - IrBuilder::createInContainer(this->parent(), 0L, DataType::Index); - NVF_ERROR(vals_up_.back().get() == zero_val); - zero_val_ = std::unique_ptr(vals_up_.back().release()); - vals_up_.pop_back(); - } - return zero_val_.get(); -} - -Val* IrContainer::zeroVal(DataType dtype) { - if (dtype == DataType::Index) { - return zeroVal(); - } else if (isBooleanType(dtype)) { - return falseVal(); - } else { - // NOTE: this does not cache values - return IrBuilder::createInContainer(this->parent(), 0L, dtype); - } -} - -Val* IrContainer::oneVal() { - if (!one_val_) { - auto one_val = - IrBuilder::createInContainer(this->parent(), 1L, DataType::Index); - NVF_ERROR(vals_up_.back().get() == one_val); - one_val_ = std::unique_ptr(vals_up_.back().release()); - vals_up_.pop_back(); - } - return one_val_.get(); -} - -Val* IrContainer::oneVal(DataType dtype) { - if (dtype == DataType::Index) { - return oneVal(); - } else if (isBooleanType(dtype)) { - return trueVal(); - } else { - // NOTE: this does not cache values - return IrBuilder::createInContainer(this->parent(), 1L, dtype); - } -} - -Val* IrContainer::falseVal() { - if (!false_val_) { - auto false_val = IrBuilder::createInContainer( - this->parent(), false, DataType::Bool); - NVF_ERROR(vals_up_.back().get() == false_val); - false_val_ = std::unique_ptr(vals_up_.back().release()); - vals_up_.pop_back(); - } - return false_val_.get(); -} - -Val* IrContainer::trueVal() { - if (!true_val_) { - auto true_val = - IrBuilder::createInContainer(this->parent(), true, DataType::Bool); - NVF_ERROR(vals_up_.back().get() == true_val); - true_val_ = std::unique_ptr(vals_up_.back().release()); - vals_up_.pop_back(); - } - return true_val_.get(); -} - -NamedScalar* IrContainer::magicZeroVal() { - if (!magic_zero_val_) { - auto magic_zero = - IrBuilder::create(kMagicZeroName, DataType::Index); - NVF_ERROR(vals_up_.back().get() == magic_zero); - magic_zero_val_ = std::unique_ptr( - vals_up_.back().release()->as()); - vals_up_.pop_back(); - } - return magic_zero_val_.get(); -} +// Note: Shortcut values (zeroVal, oneVal, trueVal, falseVal, magicZeroVal) +// are now per-Fusion. Use Fusion::zeroVal() etc. instead. +// This avoids ownership conflicts when multiple Fusions share an IrContainer. Val* IrContainer::metadataOf(Val* v) { if (metadata_.count(v) == 0) { @@ -338,7 +258,8 @@ void IrContainer::lazyInitAxioms() { if (!axioms_) { axioms_ = std::make_unique>(); axioms_->reserve(kParallelTypeThreads.size() * 3); - auto zero = zeroVal(); + // Use parent()->zeroVal() since special values are now per-Fusion + auto zero = parent()->zeroVal(); for (auto p : kParallelTypeThreads) { auto pidx = NamedScalar::getParallelIndex(p); auto pdim = NamedScalar::getParallelDim(p); @@ -352,13 +273,15 @@ void IrContainer::lazyInitAxioms() { void IrContainer::assumePositive(Val* val) { NVF_ERROR(val->container() == this->parent()); lazyInitAxioms(); - axioms_->emplace_back(IrBuilder::gtExpr(val, zeroVal())); + // Use parent()->zeroVal() since special values are now per-Fusion + axioms_->emplace_back(IrBuilder::gtExpr(val, parent()->zeroVal())); } void IrContainer::assumeNonNegative(Val* val) { NVF_ERROR(val->container() == this->parent()); lazyInitAxioms(); - axioms_->emplace_back(IrBuilder::geExpr(val, zeroVal())); + // Use parent()->zeroVal() since special values are now per-Fusion + axioms_->emplace_back(IrBuilder::geExpr(val, parent()->zeroVal())); } void IrContainer::removeStatementsCreatedAfter( diff --git a/csrc/ir/container.h b/csrc/ir/container.h index e361b8743ee..f4901de311c 100644 --- a/csrc/ir/container.h +++ b/csrc/ir/container.h @@ -80,19 +80,18 @@ class IrContainer { return std::ssize(exprs_); } - // When include_shortcuts is true, it will count the shortcuts like true_val_. + // Note: The include_shortcuts parameter is now deprecated. + // With Phase 2 per-Fusion special values, all vals (including special values) + // are stored in vals_up_, so both vals_ and vals_up_ have the same size. + // This parameter is kept for API compatibility but has no effect. int64_t numVals(bool include_shortcuts) const noexcept { return include_shortcuts ? std::ssize(vals_) : std::ssize(vals_up_); } - // Shortcuts for frequently used vals - NVF_API Val* zeroVal(); - NVF_API Val* oneVal(); - Val* falseVal(); - Val* trueVal(); - NamedScalar* magicZeroVal(); - NVF_API Val* zeroVal(DataType dtype); - NVF_API Val* oneVal(DataType dtype); + // Note: Shortcut values (zeroVal, oneVal, trueVal, falseVal, magicZeroVal) + // are now per-Fusion. Use Fusion::zeroVal() etc. instead. + // This avoids ownership conflicts when multiple Fusions share an IrContainer. + Val* metadataOf(Val*); // Axioms about CUDA programming, for example: threadIdx.x < blockDim.x @@ -171,19 +170,11 @@ class IrContainer { // Expression names counter StmtNameType expr_name_counter_ = 0; - // Manually store some persistent, frequently used nodes. It's very - // challenging to do this anything but manually as detecting when a container - // may or may not have one of these vals is tricky. Specifically because if - // the container doesn't own it, it's hard to understand from the outside if - // the node may have been removed then re-registered. It could also be tricky - // to know when we're using a different container as in FusionCopy_test - // demonstrates deleting then creating containers can result in the same - // pointer for the container. - std::unique_ptr true_val_; - std::unique_ptr false_val_; - std::unique_ptr one_val_; - std::unique_ptr zero_val_; - std::unique_ptr magic_zero_val_; + // Note: Special values (zero_val_, one_val_, true_val_, false_val_, + // magic_zero_val_) are now per-Fusion, stored in Fusion class. + // This avoids ownership conflicts when multiple Fusions share an IrContainer. + // See Fusion::zeroVal(), etc. for the per-Fusion implementation. + std::unique_ptr> axioms_; std::unordered_map> metadata_; From f273b16ae8b3c829b3ac432f7fc03198a7e181a7 Mon Sep 17 00:00:00 2001 From: Michael Davis Date: Tue, 3 Feb 2026 20:14:53 -0800 Subject: [PATCH 2/9] Per-Fusion Axioms and Metadata MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Moved `axioms_` and `metadata_` from `IrContainer` to the `Fusion` class. This completes the deprecation of `parent_` usage for val-creating methods, which was necessary because `parent_` implies a 1-1 relationship (container → Fusion), but Phase 2 has 1-many (shared containers). Methods that used `parent_` to create vals were moved to Fusion: - `metadataOf(Val*)` - Now uses `v->container()` to get owning Fusion - `axioms()` - Now creates axiom vals owned by `this` Fusion - `assumePositive/assumeNonNegative` - Now adds to `this` Fusion's axioms - Added `axioms_` and `metadata_` private members - Changed method declarations from forwarding to actual implementations - Added includes for `ir/builder.h` and `ir/internal_nodes.h` - Implemented `metadataOf()`, `axioms()`, `assumePositive()`, `assumeNonNegative()` methods - Updated `clear()` to reset `axioms_` and `metadata_` - Removed `metadataOf()`, `axioms()`, `assumePositive()`, `assumeNonNegative()` declarations - Removed `lazyInitAxioms()` declaration - Removed `axioms_` and `metadata_` members - Removed implementations of above methods - Updated `IrContainer::swap` to remove axioms_/metadata_ swapping - Updated `IrContainer::copy` to remove axioms_/metadata_ handling - Updated `IrContainer::clear` to remove axioms_/metadata_ clearing Each Fusion now has its own axioms and metadata cache. This ensures: 1. No ownership conflicts when multiple Fusions share an IrContainer 2. Correct behavior when one Fusion is destroyed (doesn't affect others) 3. Lazy creation pattern preserved (create on first access) This is a prerequisite for the copy/move semantics changes which will swap/transfer these per-Fusion members. --- csrc/fusion.cpp | 54 +++++++++++++++++++++++++++++++++++ csrc/fusion.h | 27 +++++++++--------- csrc/ir/container.cpp | 65 +++++-------------------------------------- csrc/ir/container.h | 24 ++++------------ 4 files changed, 80 insertions(+), 90 deletions(-) diff --git a/csrc/fusion.cpp b/csrc/fusion.cpp index 1d286eda51a..756a9f93310 100644 --- a/csrc/fusion.cpp +++ b/csrc/fusion.cpp @@ -20,7 +20,9 @@ #include #include #include +#include #include +#include #include #include #include @@ -280,6 +282,10 @@ void Fusion::clear() noexcept { false_val_ = nullptr; magic_zero_val_ = nullptr; + // Reset per-Fusion axioms and metadata (Phase 2) + axioms_.reset(); + metadata_.clear(); + invalidateTvsAndUses(); is_during_update_uses_ = false; @@ -768,6 +774,54 @@ Val* Fusion::oneVal(DataType dtype) { } } +// ========================================================================= +// Per-Fusion Metadata and Axioms (Phase 2) +// These are per-Fusion to avoid ownership issues with shared containers. +// ========================================================================= + +Val* Fusion::metadataOf(Val* v) { + if (metadata_.count(v) == 0) { + // Create metadata val owned by the same Fusion as v + Fusion* owner = v->container(); + auto metadata_val = + IrBuilder::createInContainer(owner, metaDataTypeOf(v)); + auto metadata_expr = + IrBuilder::createInContainer(owner, metadata_val, v); + metadata_[v] = std::make_pair(metadata_val, metadata_expr); + } + return metadata_.at(v).first; +} + +const std::vector& Fusion::axioms() { + if (!axioms_) { + axioms_ = std::make_unique>(); + axioms_->reserve(kParallelTypeThreads.size() * 3); + auto zero = zeroVal(); + for (auto p : kParallelTypeThreads) { + auto pidx = NamedScalar::getParallelIndex(p); + auto pdim = NamedScalar::getParallelDim(p); + axioms_->push_back(SimplifyingIrBuilder::geExpr(pidx, zero)); + axioms_->push_back(SimplifyingIrBuilder::gtExpr(pdim, zero)); + axioms_->push_back(SimplifyingIrBuilder::ltExpr(pidx, pdim)); + } + } + return *axioms_; +} + +void Fusion::assumePositive(Val* val) { + NVF_ERROR(inContainer(val)); + // Lazy init axioms, then add the assumption + axioms(); + axioms_->emplace_back(IrBuilder::gtExpr(val, zeroVal())); +} + +void Fusion::assumeNonNegative(Val* val) { + NVF_ERROR(inContainer(val)); + // Lazy init axioms, then add the assumption + axioms(); + axioms_->emplace_back(IrBuilder::geExpr(val, zeroVal())); +} + void Fusion::registerVal(Val* val) { if (inContainer(val)) { return; diff --git a/csrc/fusion.h b/csrc/fusion.h index a3b1735cc16..5aebc2915d5 100644 --- a/csrc/fusion.h +++ b/csrc/fusion.h @@ -566,22 +566,15 @@ class NVF_API Fusion : public PolymorphicBase { Val* zeroVal(DataType dtype); Val* oneVal(DataType dtype); - Val* metadataOf(Val* val) { - return ir_container()->metadataOf(val); - } + // Phase 2: Per-Fusion metadata and axioms + // These are now per-Fusion to avoid ownership issues with shared containers. + Val* metadataOf(Val* val); // Axioms (CUDA programming assumptions) - const std::vector& axioms() { - return ir_container()->axioms(); - } + const std::vector& axioms(); - void assumePositive(Val* val) { - ir_container()->assumePositive(val); - } - - void assumeNonNegative(Val* val) { - ir_container()->assumeNonNegative(val); - } + void assumePositive(Val* val); + void assumeNonNegative(Val* val); // Statement removal void removeStatementsCreatedAfter( @@ -661,6 +654,14 @@ class NVF_API Fusion : public PolymorphicBase { Val* true_val_ = nullptr; Val* false_val_ = nullptr; NamedScalar* magic_zero_val_ = nullptr; + + // Phase 2: Per-Fusion axioms (CUDA programming assumptions) + // These are per-Fusion to avoid ownership issues with shared containers. + std::unique_ptr> axioms_; + + // Phase 2: Per-Fusion metadata cache + // Maps Val* to (metadata_val, metadata_expr) pairs + std::unordered_map> metadata_; }; // Template implementations for Fusion::manage() that use IrCloner diff --git a/csrc/ir/container.cpp b/csrc/ir/container.cpp index c79aefec408..2dafe6d78c4 100644 --- a/csrc/ir/container.cpp +++ b/csrc/ir/container.cpp @@ -81,17 +81,15 @@ void IrContainer::swap(IrContainer& a, IrContainer& b) noexcept { std::swap(a.val_type_name_map_, b.val_type_name_map_); std::swap(a.expr_name_counter_, b.expr_name_counter_); - std::swap(a.metadata_, b.metadata_); - std::swap(a.parent_, b.parent_); - // Note: Special values (zero_val_, one_val_, etc.) are now per-Fusion, - // not per-IrContainer. They are swapped as part of the Fusion-level swap. - std::swap(a.axioms_, b.axioms_); + // Note: Special values, axioms, and metadata are now per-Fusion, + // not per-IrContainer. They are handled by Fusion::swap. } IrCloner IrContainer::copy(const IrContainer* from, IrContainer* to) { to->clear(); + IrCloner ir_cloner(to->parent()); // Copy values in deterministic order @@ -113,14 +111,7 @@ IrCloner IrContainer::copy(const IrContainer* from, IrContainer* to) { to->val_type_name_map_ = from->val_type_name_map_; to->expr_name_counter_ = from->expr_name_counter_; - if (from->axioms_ != nullptr) { - to->axioms_ = std::make_unique>(); - for (auto pred : *from->axioms_) { - to->axioms_->push_back(ir_cloner.clone(pred)); - } - } - - to->metadata_ = ir_cloner.clone(from->metadata_); + // Note: axioms and metadata are now per-Fusion, handled by Fusion::copy return ir_cloner; } @@ -201,9 +192,7 @@ void IrContainer::clear() noexcept { vals_up_.clear(); exprs_.clear(); exprs_up_.clear(); - axioms_.reset(); val_type_name_map_.clear(); - metadata_.clear(); expr_name_counter_ = 0; } @@ -239,51 +228,11 @@ bool IrContainer::inContainer(const Statement* const_stmt) const { return true; } -// Note: Shortcut values (zeroVal, oneVal, trueVal, falseVal, magicZeroVal) -// are now per-Fusion. Use Fusion::zeroVal() etc. instead. +// Note: Shortcut values (zeroVal, oneVal, trueVal, falseVal, magicZeroVal), +// metadata, and axioms are now per-Fusion. Use Fusion::zeroVal(), +// Fusion::metadataOf(), Fusion::axioms(), etc. instead. // This avoids ownership conflicts when multiple Fusions share an IrContainer. -Val* IrContainer::metadataOf(Val* v) { - if (metadata_.count(v) == 0) { - auto metadata_val = - IrBuilder::createInContainer(this->parent(), metaDataTypeOf(v)); - auto metadata_expr = IrBuilder::createInContainer( - this->parent(), metadata_val, v); - metadata_[v] = std::make_pair(metadata_val, metadata_expr); - } - return metadata_.at(v).first; -} - -void IrContainer::lazyInitAxioms() { - if (!axioms_) { - axioms_ = std::make_unique>(); - axioms_->reserve(kParallelTypeThreads.size() * 3); - // Use parent()->zeroVal() since special values are now per-Fusion - auto zero = parent()->zeroVal(); - for (auto p : kParallelTypeThreads) { - auto pidx = NamedScalar::getParallelIndex(p); - auto pdim = NamedScalar::getParallelDim(p); - axioms_->push_back(SimplifyingIrBuilder::geExpr(pidx, zero)); - axioms_->push_back(SimplifyingIrBuilder::gtExpr(pdim, zero)); - axioms_->push_back(SimplifyingIrBuilder::ltExpr(pidx, pdim)); - } - } -} - -void IrContainer::assumePositive(Val* val) { - NVF_ERROR(val->container() == this->parent()); - lazyInitAxioms(); - // Use parent()->zeroVal() since special values are now per-Fusion - axioms_->emplace_back(IrBuilder::gtExpr(val, parent()->zeroVal())); -} - -void IrContainer::assumeNonNegative(Val* val) { - NVF_ERROR(val->container() == this->parent()); - lazyInitAxioms(); - // Use parent()->zeroVal() since special values are now per-Fusion - axioms_->emplace_back(IrBuilder::geExpr(val, parent()->zeroVal())); -} - void IrContainer::removeStatementsCreatedAfter( int64_t prev_num_exprs, int64_t prev_num_vals) { diff --git a/csrc/ir/container.h b/csrc/ir/container.h index f4901de311c..f7ef34ec8f8 100644 --- a/csrc/ir/container.h +++ b/csrc/ir/container.h @@ -88,21 +88,11 @@ class IrContainer { return include_shortcuts ? std::ssize(vals_) : std::ssize(vals_up_); } - // Note: Shortcut values (zeroVal, oneVal, trueVal, falseVal, magicZeroVal) - // are now per-Fusion. Use Fusion::zeroVal() etc. instead. + // Note: Shortcut values (zeroVal, oneVal, trueVal, falseVal, magicZeroVal), + // metadata, and axioms are now per-Fusion. Use Fusion::zeroVal(), + // Fusion::metadataOf(), Fusion::axioms(), etc. instead. // This avoids ownership conflicts when multiple Fusions share an IrContainer. - Val* metadataOf(Val*); - - // Axioms about CUDA programming, for example: threadIdx.x < blockDim.x - const std::vector& axioms() { - lazyInitAxioms(); - return *axioms_; - } - - void assumePositive(Val* val); - void assumeNonNegative(Val* val); - protected: static IrCloner copy(const IrContainer* from, IrContainer* to); @@ -136,8 +126,6 @@ class IrContainer { void clear() noexcept; - void lazyInitAxioms(); - friend class StatementGuard; // A simple garbage collection mechanism to remove all Exprs and Vals that @@ -173,10 +161,8 @@ class IrContainer { // Note: Special values (zero_val_, one_val_, true_val_, false_val_, // magic_zero_val_) are now per-Fusion, stored in Fusion class. // This avoids ownership conflicts when multiple Fusions share an IrContainer. - // See Fusion::zeroVal(), etc. for the per-Fusion implementation. - - std::unique_ptr> axioms_; - std::unordered_map> metadata_; + // See Fusion::zeroVal(), Fusion::axioms(), Fusion::metadataOf(), etc. + // for the per-Fusion implementations. public: Fusion* parent() const { From cd2acf6c4578d03ddefb7cf3c21db66e5c9172cb Mon Sep 17 00:00:00 2001 From: Michael Davis Date: Tue, 3 Feb 2026 13:46:33 -0800 Subject: [PATCH 3/9] Basic shared_ptr Transition Transitioned Fusion's container ownership from `unique_ptr` to `shared_ptr` with automatic Fusion registration/unregistration during construction/destruction. - Added `#include ` header - Changed `std::unique_ptr ir_container_` to `std::shared_ptr ir_container_` - Added `ir_container_ptr()` method to return the shared_ptr directly (for tests that need to hold reference beyond Fusion lifetime) - Updated default constructor to use `std::make_shared()` and call `addFusion(this)` for registration - Updated destructor to call `removeFusion(this)` before clearing - Added 4 new tests for Task 3: - `BasicFusionLifecycle`: Create Fusion, add inputs/outputs, destroy - verifies no crashes - `FusionAutoRegistration`: Verifies new Fusion automatically registers (sharingCount == 1) - `FusionDestructorCleanup`: Verifies destructor unregisters and cleans up Statements - `ContainerAccessor`: Verifies `ir_container_ptr()` returns valid shared_ptr - Updated Task 2 tests to account for auto-registration in constructor ``` [ PASSED ] Phase2ContainerTest.LockingBasic [ PASSED ] Phase2ContainerTest.ConcurrentReads [ PASSED ] Phase2ContainerTest.FusionRegistration [ PASSED ] Phase2ContainerTest.FusionTransfer [ PASSED ] Phase2ContainerTest.MultipleRegistration [ PASSED ] Phase2ContainerTest.StatementCleanup [ PASSED ] Phase2ContainerTest.BasicFusionLifecycle [ PASSED ] Phase2ContainerTest.FusionAutoRegistration [ PASSED ] Phase2ContainerTest.FusionDestructorCleanup [ PASSED ] Phase2ContainerTest.ContainerAccessor ``` ``` [ PASSED ] 34 tests including: - FusionCopy_CUDA - FusionMove_CUDA - FusionClear_CUDA - All AbstractTensorTest.* - All NVFuserTest.FusionHash* ``` 1. **Auto-registration in constructor**: The Fusion constructor now calls `ir_container_->addFusion(this)` after creating the container. This ensures every Fusion is always tracked. 2. **Auto-unregistration in destructor**: The destructor calls `ir_container_->removeFusion(this)` which: - Decrements the sharing count - Cleans up Statements owned by this Fusion - Works correctly even when container is shared (other Fusions' Statements preserved) 3. **Added `ir_container_ptr()` method**: Returns `std::shared_ptr` for cases where code needs to hold a reference to the container beyond the Fusion's lifetime (e.g., testing Statement cleanup after Fusion destruction). 4. **Task 2 test updates**: The previous Task 2 tests assumed Fusions weren't auto-registered. Updated them to use separate IrContainer instances for testing the registration mechanism in isolation. | File | Changes | |------|---------| | `csrc/fusion.h` | Added `` header, changed `unique_ptr` to `shared_ptr`, added `ir_container_ptr()` | | `csrc/fusion.cpp` | Updated constructor (make_shared + addFusion), updated destructor (removeFusion) | | `tests/cpp/test_phase2_container_sharing.cpp` | Added 4 new tests, updated 3 Task 2 tests | --- csrc/fusion.cpp | 6 +++++- csrc/fusion.h | 9 +++++++-- 2 files changed, 12 insertions(+), 3 deletions(-) diff --git a/csrc/fusion.cpp b/csrc/fusion.cpp index 756a9f93310..9916fdcb19a 100644 --- a/csrc/fusion.cpp +++ b/csrc/fusion.cpp @@ -220,8 +220,9 @@ IrCloner Fusion::copy(const Fusion* from, Fusion* to) { } // Default constructor -Fusion::Fusion() : ir_container_(std::make_unique()) { +Fusion::Fusion() : ir_container_(std::make_shared()) { ir_container_->parent_ = this; + ir_container_->addFusion(this); // Register with container for Phase 2 } // Copy constructor @@ -252,6 +253,9 @@ Fusion& Fusion::operator=(Fusion&& other) noexcept { } Fusion::~Fusion() { + if (ir_container_) { + ir_container_->removeFusion(this); // Unregister before destruction + } clear(); } diff --git a/csrc/fusion.h b/csrc/fusion.h index 5aebc2915d5..e885bc1987a 100644 --- a/csrc/fusion.h +++ b/csrc/fusion.h @@ -8,6 +8,7 @@ #pragma once #include +#include #include #include #include @@ -163,7 +164,11 @@ class NVF_API Fusion : public PolymorphicBase { return ir_container_.get(); } - public: + // Return the shared_ptr to the container (for Phase 2 container sharing) + std::shared_ptr ir_container_ptr() const { + return ir_container_; + } + // Registration (public API with passkey) virtual void registerStmt(IrBuilderPasskey, Statement* stmt) { if (stmt->isVal()) { @@ -643,7 +648,7 @@ class NVF_API Fusion : public PolymorphicBase { std::unique_ptr> all_tvs_ptr_ = nullptr; inline static const std::string exact_mappings_key = "exact_mappings"; - std::unique_ptr ir_container_; + std::shared_ptr ir_container_; // Phase 2: Per-Fusion special values // With shared containers, each Fusion needs its own special values. From d29216b868083c088b2d32a8b3386a813994d141 Mon Sep 17 00:00:00 2001 From: Michael Davis Date: Tue, 3 Feb 2026 13:02:42 -0800 Subject: [PATCH 4/9] Added mutex and locking infrastructure to IrContainer --- CMakeLists.txt | 1 + csrc/ir/container.cpp | 68 ++++++++++++++-- csrc/ir/container.h | 19 ++--- tests/cpp/test_phase2_container_sharing.cpp | 90 +++++++++++++++++++++ 4 files changed, 164 insertions(+), 14 deletions(-) create mode 100644 tests/cpp/test_phase2_container_sharing.cpp diff --git a/CMakeLists.txt b/CMakeLists.txt index d007b40c90d..c82f5609467 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -912,6 +912,7 @@ list(APPEND JIT_TEST_SRCS ${NVFUSER_ROOT}/tests/cpp/test_overlap.cpp ${NVFUSER_ROOT}/tests/cpp/test_pdl.cpp ${NVFUSER_ROOT}/tests/cpp/test_persistent_buffer.cpp + ${NVFUSER_ROOT}/tests/cpp/test_phase2_container_sharing.cpp ${NVFUSER_ROOT}/tests/cpp/test_pointwise.cpp ${NVFUSER_ROOT}/tests/cpp/test_polymorphic_value.cpp ${NVFUSER_ROOT}/tests/cpp/test_predicate_elimination.cpp diff --git a/csrc/ir/container.cpp b/csrc/ir/container.cpp index 2dafe6d78c4..e72d4c2821e 100644 --- a/csrc/ir/container.cpp +++ b/csrc/ir/container.cpp @@ -18,6 +18,7 @@ namespace nvfuser { //! Return values in insertion order const std::deque IrContainer::deterministic_vals() const noexcept { + std::shared_lock lock(mutex_); std::deque vals_deque; std::transform( vals_up_.begin(), @@ -29,6 +30,7 @@ const std::deque IrContainer::deterministic_vals() const noexcept { //! Return expression in insertion order const std::deque IrContainer::deterministic_exprs() const noexcept { + std::shared_lock lock(mutex_); std::deque exprs_deque; std::transform( exprs_up_.begin(), @@ -41,6 +43,7 @@ const std::deque IrContainer::deterministic_exprs() const noexcept { //! Return mapping from value to integer id const std::unordered_map IrContainer::deterministic_vals_map() const noexcept { + std::shared_lock lock(mutex_); std::unordered_map vals_map; int64_t count = 0; std::transform( @@ -56,6 +59,7 @@ const std::unordered_map IrContainer::deterministic_vals_map() //! Return mapping from expression to integer id const std::unordered_map IrContainer::deterministic_exprs_map() const noexcept { + std::shared_lock lock(mutex_); std::unordered_map exprs_map; int64_t count = 0; std::transform( @@ -68,9 +72,39 @@ const std::unordered_map IrContainer::deterministic_exprs_map() return exprs_map; } +const std::unordered_set& IrContainer::unordered_exprs() const noexcept { + // Note: Returns reference - caller responsible for not holding across + // concurrent modifications. Lock provides snapshot consistency during call. + std::shared_lock lock(mutex_); + return exprs_; +} + +const std::unordered_set& IrContainer::vals() const noexcept { + // Note: Returns reference - caller responsible for not holding across + // concurrent modifications. Lock provides snapshot consistency during call. + std::shared_lock lock(mutex_); + return vals_; +} + +int64_t IrContainer::numExprs() const noexcept { + std::shared_lock lock(mutex_); + return std::ssize(exprs_); +} + +int64_t IrContainer::numVals(bool include_shortcuts) const noexcept { + std::shared_lock lock(mutex_); + return include_shortcuts ? std::ssize(vals_) : std::ssize(vals_up_); +} + void IrContainer::swap(IrContainer& a, IrContainer& b) noexcept { FUSER_PERF_SCOPE("Fusion swap"); + // Lock both containers in consistent order to avoid deadlock + // Use std::lock to lock both mutexes atomically + std::unique_lock lock_a(a.mutex_, std::defer_lock); + std::unique_lock lock_b(b.mutex_, std::defer_lock); + std::lock(lock_a, lock_b); + // Swap the content std::swap(a.vals_up_, b.vals_up_); std::swap(a.vals_, b.vals_); @@ -88,22 +122,46 @@ void IrContainer::swap(IrContainer& a, IrContainer& b) noexcept { } IrCloner IrContainer::copy(const IrContainer* from, IrContainer* to) { - to->clear(); + // Lock both containers: shared for reading from, unique for writing to + std::shared_lock lock_from(from->mutex_); + std::unique_lock lock_to(to->mutex_); + + // Clear without calling clear() which would try to re-acquire the lock + to->vals_.clear(); + to->vals_up_.clear(); + to->exprs_.clear(); + to->exprs_up_.clear(); + to->axioms_.reset(); + to->val_type_name_map_.clear(); + to->metadata_.clear(); + to->expr_name_counter_ = 0; IrCloner ir_cloner(to->parent()); // Copy values in deterministic order // deterministic_vals can contain special values like one_val_, zero_val_, etc // that are not registered in the container. - for (auto val : from->deterministic_vals()) { - if (from->vals().count(val) > 0) { + std::deque from_vals; + std::transform( + from->vals_up_.begin(), + from->vals_up_.end(), + std::back_inserter(from_vals), + [](const std::unique_ptr& val_up) { return val_up.get(); }); + for (auto val : from_vals) { + if (from->vals_.count(val) > 0) { to->vals_.insert(ir_cloner.clone(val)); } } // Copy expressions in deterministic order - for (auto expr : from->deterministic_exprs()) { - if (from->unordered_exprs().count(expr) > 0) { + std::deque from_exprs; + std::transform( + from->exprs_up_.begin(), + from->exprs_up_.end(), + std::back_inserter(from_exprs), + [](const std::unique_ptr& expr_up) { return expr_up.get(); }); + for (auto expr : from_exprs) { + if (from->exprs_.count(expr) > 0) { to->exprs_.insert(ir_cloner.clone(expr)); } } diff --git a/csrc/ir/container.h b/csrc/ir/container.h index f7ef34ec8f8..3c25ae49e78 100644 --- a/csrc/ir/container.h +++ b/csrc/ir/container.h @@ -8,6 +8,7 @@ #pragma once #include +#include #include #include @@ -67,18 +68,14 @@ class IrContainer { //! Return the set of Exprs registered with this fusion. Warning: This will //! return exprs outside inputs/outputs, so can be unsafe for use with //! segmented fusions. - const std::unordered_set& unordered_exprs() const noexcept { - return exprs_; - } + //! Note: Returns reference - caller must not hold across concurrent mods + const std::unordered_set& unordered_exprs() const noexcept; //! Return the set of Vals registered with this fusion - const std::unordered_set& vals() const noexcept { - return vals_; - } + //! Note: Returns reference - caller must not hold across concurrent mods + const std::unordered_set& vals() const noexcept; - int64_t numExprs() const noexcept { - return std::ssize(exprs_); - } + int64_t numExprs() const noexcept; // Note: The include_shortcuts parameter is now deprecated. // With Phase 2 per-Fusion special values, all vals (including special values) @@ -94,6 +91,10 @@ class IrContainer { // This avoids ownership conflicts when multiple Fusions share an IrContainer. protected: + // Mutex for thread-safe access when container is shared between Fusions + // mutable because we need to lock in const methods + mutable std::shared_mutex mutex_; + static IrCloner copy(const IrContainer* from, IrContainer* to); static void swap(IrContainer& a, IrContainer& b) noexcept; diff --git a/tests/cpp/test_phase2_container_sharing.cpp b/tests/cpp/test_phase2_container_sharing.cpp new file mode 100644 index 00000000000..dd05962a266 --- /dev/null +++ b/tests/cpp/test_phase2_container_sharing.cpp @@ -0,0 +1,90 @@ +// clang-format off +/* + * SPDX-FileCopyrightText: Copyright (c) 2024-present NVIDIA CORPORATION & AFFILIATES. + * All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + */ +// clang-format on + +#include + +#include +#include +#include + +#include "fusion.h" +#include "ir/container.h" +#include "ops/all_ops.h" +#include "tests/cpp/utils.h" + +namespace nvfuser { + +// Test class for Phase 2 container sharing tests +class Phase2ContainerTest : public NVFuserTest { + protected: + void SetUp() override { + NVFuserTest::SetUp(); + } + void TearDown() override { + NVFuserTest::TearDown(); + } +}; + +// ============================================================================= +// Task 1 Tests: Locking Infrastructure +// ============================================================================= + +TEST_F(Phase2ContainerTest, LockingBasic) { + // Verify basic operations still work with locking in place + Fusion fusion; + FusionGuard fg(&fusion); + + auto* tv0 = makeSymbolicTensor(2); + fusion.addInput(tv0); + + auto* tv1 = add(tv0, tv0); + fusion.addOutput(tv1); + + // Verify container has expected contents + // Use vals() and unordered_exprs() which return references to container data + EXPECT_GT(fusion.vals().size(), 0); + EXPECT_GT(fusion.unordered_exprs().size(), 0); +} + +TEST_F(Phase2ContainerTest, ConcurrentReads) { + // Multiple threads can read simultaneously without data races + Fusion fusion; + FusionGuard fg(&fusion); + + auto* tv0 = makeSymbolicTensor(2); + fusion.addInput(tv0); + auto* tv1 = add(tv0, tv0); + fusion.addOutput(tv1); + + std::vector threads; + std::atomic read_count{0}; + + // Spawn multiple reader threads + for (int i = 0; i < 4; ++i) { + threads.emplace_back([&]() { + for (int j = 0; j < 100; ++j) { + // Access vals and unordered_exprs through fusion's forwarding methods + // These return const references to the underlying container data + const auto& vals = fusion.vals(); + const auto& exprs = fusion.unordered_exprs(); + // Just access sizes to verify no crashes under concurrent access + (void)vals.size(); + (void)exprs.size(); + read_count++; + } + }); + } + + for (auto& t : threads) { + t.join(); + } + + EXPECT_EQ(read_count.load(), 400); +} + +} // namespace nvfuser From 94b6e19fe5b7440e93a2466980285610695c7bc9 Mon Sep 17 00:00:00 2001 From: Michael Davis Date: Tue, 3 Feb 2026 13:25:48 -0800 Subject: [PATCH 5/9] Fusion Tracking Infrastructure Added Fusion registration and Statement cleanup capabilities to IrContainer. This enables tracking which Fusions share a container and cleaning up Statements when a Fusion is destroyed. Added public methods for Fusion tracking: - `addFusion(Fusion*)` - register a Fusion as sharing this container - `removeFusion(Fusion*)` - unregister and cleanup owned Statements - `transferFusion(Fusion* from, Fusion* to)` - for move operations - `sharingCount()` - number of Fusions sharing this container - `hasMultipleFusions()` - whether multiple Fusions share this container - `sharingFusions()` - get the set of sharing Fusions Added protected members: - `std::unordered_set sharing_fusions_` - tracks registered Fusions - `removeStatementsOwnedByUnlocked(Fusion*)` - internal cleanup helper Implemented all Fusion tracking methods: - `addFusion()` - inserts Fusion into `sharing_fusions_` (unique_lock) - `removeFusion()` - removes from set and cleans up owned Statements (unique_lock) - `transferFusion()` - atomic transfer of registration (unique_lock) - `sharingCount()`, `hasMultipleFusions()`, `sharingFusions()` - read accessors (shared_lock) - `removeStatementsOwnedByUnlocked()` - iterates vals/exprs and removes those owned by the given Fusion Also updated `swap()` to include `sharing_fusions_` in the swap. The `removeStatementsOwnedByUnlocked()` function removes: 1. All Vals in `vals_up_` owned by the Fusion 2. All shortcut Vals (`zero_val_`, `one_val_`, `true_val_`, `false_val_`, `magic_zero_val_`) if owned by the Fusion 3. All Exprs in `exprs_up_` owned by the Fusion Ownership is determined by checking `statement->container() == fusion`. Made `ir_container()` accessor public (was protected). This is needed for Phase 2 shared_ptr support where external code needs to access the underlying container. Added 4 new tests: - `FusionRegistration` - verifies add/remove counting - `FusionTransfer` - verifies transfer updates tracking correctly - `MultipleRegistration` - verifies multiple Fusions can register - `StatementCleanup` - verifies Statement cleanup on removeFusion --- ``` [ PASSED ] 34 tests. ``` ``` [ PASSED ] 6 tests. - Phase2ContainerTest.LockingBasic - Phase2ContainerTest.ConcurrentReads - Phase2ContainerTest.FusionRegistration - Phase2ContainerTest.FusionTransfer - Phase2ContainerTest.MultipleRegistration - Phase2ContainerTest.StatementCleanup ``` --- | File | Change Type | |------|-------------| | `csrc/ir/container.h` | Added tracking methods and members | | `csrc/ir/container.cpp` | Implemented tracking methods | | `csrc/fusion.h` | Made `ir_container()` public | | `tests/cpp/test_phase2_container_sharing.cpp` | Added 4 tests | --- All tracking methods use the existing `mutex_`: - `addFusion()`, `removeFusion()`, `transferFusion()` - unique_lock (write) - `sharingCount()`, `hasMultipleFusions()`, `sharingFusions()` - shared_lock (read) Statements store their owning Fusion via `ir_container_` member (in `base_nodes.h`). When a Fusion is removed from a shared container: 1. The Fusion is unregistered from `sharing_fusions_` 2. All Statements where `container() == fusion` are removed from the container This ensures that when a Fusion is destroyed, its IR nodes don't pollute the shared container. --- csrc/fusion.h | 4 +- csrc/ir/container.cpp | 67 +++++++++++- csrc/ir/container.h | 30 ++++++ tests/cpp/test_phase2_container_sharing.cpp | 112 ++++++++++++++++++++ 4 files changed, 209 insertions(+), 4 deletions(-) diff --git a/csrc/fusion.h b/csrc/fusion.h index e885bc1987a..87317eb6a20 100644 --- a/csrc/fusion.h +++ b/csrc/fusion.h @@ -148,8 +148,8 @@ class AliasInfoMap { class NVF_API Fusion : public PolymorphicBase { typedef std::unordered_map> PermutationMap; - protected: - // Direct access to underlying container + public: + // Direct access to underlying container (for Phase 2 shared_ptr support) IrContainer* ir_container() { NVF_ERROR( ir_container_.get() != nullptr, diff --git a/csrc/ir/container.cpp b/csrc/ir/container.cpp index e72d4c2821e..c06b59c768a 100644 --- a/csrc/ir/container.cpp +++ b/csrc/ir/container.cpp @@ -96,6 +96,71 @@ int64_t IrContainer::numVals(bool include_shortcuts) const noexcept { return include_shortcuts ? std::ssize(vals_) : std::ssize(vals_up_); } +// ========================================================================= +// Fusion tracking for shared container support (Phase 2) +// ========================================================================= + +void IrContainer::addFusion(Fusion* fusion) { + std::unique_lock lock(mutex_); + sharing_fusions_.insert(fusion); +} + +void IrContainer::removeFusion(Fusion* fusion) { + std::unique_lock lock(mutex_); + sharing_fusions_.erase(fusion); + removeStatementsOwnedByUnlocked(fusion); +} + +void IrContainer::transferFusion(Fusion* from, Fusion* to) { + std::unique_lock lock(mutex_); + sharing_fusions_.erase(from); + sharing_fusions_.insert(to); + // Note: Statements retain their container() pointer - they don't need + // to be updated because container() returns Fusion* which points to + // the owning Fusion, and that ownership is what we're transferring. +} + +size_t IrContainer::sharingCount() const { + std::shared_lock lock(mutex_); + return sharing_fusions_.size(); +} + +bool IrContainer::hasMultipleFusions() const { + std::shared_lock lock(mutex_); + return sharing_fusions_.size() > 1; +} + +const std::unordered_set& IrContainer::sharingFusions() const { + std::shared_lock lock(mutex_); + return sharing_fusions_; +} + +void IrContainer::removeStatementsOwnedByUnlocked(Fusion* fusion) { + // Remove all Vals owned by this Fusion + for (auto it = vals_up_.begin(); it != vals_up_.end();) { + Val* val = it->get(); + // Check if this Val's container points to the Fusion being removed + if (val->container() == fusion) { + vals_.erase(val); + it = vals_up_.erase(it); + } else { + ++it; + } + } + + // Remove all Exprs owned by this Fusion + for (auto it = exprs_up_.begin(); it != exprs_up_.end();) { + Expr* expr = it->get(); + // Check if this Expr's container points to the Fusion being removed + if (expr->container() == fusion) { + exprs_.erase(expr); + it = exprs_up_.erase(it); + } else { + ++it; + } + } +} + void IrContainer::swap(IrContainer& a, IrContainer& b) noexcept { FUSER_PERF_SCOPE("Fusion swap"); @@ -131,9 +196,7 @@ IrCloner IrContainer::copy(const IrContainer* from, IrContainer* to) { to->vals_up_.clear(); to->exprs_.clear(); to->exprs_up_.clear(); - to->axioms_.reset(); to->val_type_name_map_.clear(); - to->metadata_.clear(); to->expr_name_counter_ = 0; IrCloner ir_cloner(to->parent()); diff --git a/csrc/ir/container.h b/csrc/ir/container.h index 3c25ae49e78..5fb04f7d8f2 100644 --- a/csrc/ir/container.h +++ b/csrc/ir/container.h @@ -90,11 +90,41 @@ class IrContainer { // Fusion::metadataOf(), Fusion::axioms(), etc. instead. // This avoids ownership conflicts when multiple Fusions share an IrContainer. + public: + // ========================================================================= + // Fusion tracking for shared container support (Phase 2) + // ========================================================================= + + //! Register a Fusion as sharing this container + void addFusion(Fusion* fusion); + + //! Unregister a Fusion and cleanup its owned Statements + void removeFusion(Fusion* fusion); + + //! Transfer registration from one Fusion to another (for move operations) + void transferFusion(Fusion* from, Fusion* to); + + //! Number of Fusions sharing this container + size_t sharingCount() const; + + //! Whether multiple Fusions share this container + bool hasMultipleFusions() const; + + //! Get the set of Fusions sharing this container + const std::unordered_set& sharingFusions() const; + protected: // Mutex for thread-safe access when container is shared between Fusions // mutable because we need to lock in const methods mutable std::shared_mutex mutex_; + //! Fusions that share this container (for Phase 2 shared_ptr ownership) + std::unordered_set sharing_fusions_; + + //! Remove all Statements owned by a specific Fusion (internal helper) + //! Caller must hold unique_lock on mutex_ + void removeStatementsOwnedByUnlocked(Fusion* fusion); + static IrCloner copy(const IrContainer* from, IrContainer* to); static void swap(IrContainer& a, IrContainer& b) noexcept; diff --git a/tests/cpp/test_phase2_container_sharing.cpp b/tests/cpp/test_phase2_container_sharing.cpp index dd05962a266..a8d08a242f8 100644 --- a/tests/cpp/test_phase2_container_sharing.cpp +++ b/tests/cpp/test_phase2_container_sharing.cpp @@ -87,4 +87,116 @@ TEST_F(Phase2ContainerTest, ConcurrentReads) { EXPECT_EQ(read_count.load(), 400); } +// ============================================================================= +// Task 2 Tests: Fusion Tracking Infrastructure +// ============================================================================= + +TEST_F(Phase2ContainerTest, FusionRegistration) { + // Test that addFusion increments count, removeFusion decrements + Fusion fusion; + FusionGuard fg(&fusion); + + // Get the IrContainer through Fusion + auto& container = *fusion.ir_container(); + + // Initially no Fusions registered (Phase 1 doesn't use registration yet) + EXPECT_EQ(container.sharingCount(), 0); + + // Register the Fusion + container.addFusion(&fusion); + EXPECT_EQ(container.sharingCount(), 1); + EXPECT_FALSE(container.hasMultipleFusions()); + + // Create another Fusion and register it with the same container + // (simulating shared_ptr sharing that will happen in later tasks) + Fusion fusion2; + container.addFusion(&fusion2); + EXPECT_EQ(container.sharingCount(), 2); + EXPECT_TRUE(container.hasMultipleFusions()); + + // Remove one + container.removeFusion(&fusion2); + EXPECT_EQ(container.sharingCount(), 1); + EXPECT_FALSE(container.hasMultipleFusions()); + + // Remove the other + container.removeFusion(&fusion); + EXPECT_EQ(container.sharingCount(), 0); +} + +TEST_F(Phase2ContainerTest, FusionTransfer) { + // Test transferFusion correctly updates tracking + Fusion fusion1; + Fusion fusion2; + + auto& container = *fusion1.ir_container(); + + // Register fusion1 + container.addFusion(&fusion1); + EXPECT_EQ(container.sharingCount(), 1); + EXPECT_TRUE(container.sharingFusions().count(&fusion1) > 0); + EXPECT_TRUE(container.sharingFusions().count(&fusion2) == 0); + + // Transfer from fusion1 to fusion2 + container.transferFusion(&fusion1, &fusion2); + EXPECT_EQ(container.sharingCount(), 1); + EXPECT_TRUE(container.sharingFusions().count(&fusion1) == 0); + EXPECT_TRUE(container.sharingFusions().count(&fusion2) > 0); +} + +TEST_F(Phase2ContainerTest, MultipleRegistration) { + // Test multiple Fusions can register with same container + Fusion fusion1; + Fusion fusion2; + Fusion fusion3; + + auto& container = *fusion1.ir_container(); + + container.addFusion(&fusion1); + container.addFusion(&fusion2); + container.addFusion(&fusion3); + + EXPECT_EQ(container.sharingCount(), 3); + EXPECT_TRUE(container.hasMultipleFusions()); + + // Verify all are registered + const auto& fusions = container.sharingFusions(); + EXPECT_TRUE(fusions.count(&fusion1) > 0); + EXPECT_TRUE(fusions.count(&fusion2) > 0); + EXPECT_TRUE(fusions.count(&fusion3) > 0); +} + +TEST_F(Phase2ContainerTest, StatementCleanup) { + // Test that removeFusion removes only Statements owned by that Fusion + // This is tricky to test directly because Statements are tied to their + // container at construction. We test the basic mechanism works. + + Fusion fusion; + FusionGuard fg(&fusion); + + // Create some IR + auto* tv0 = makeSymbolicTensor(2); + fusion.addInput(tv0); + auto* tv1 = add(tv0, tv0); + fusion.addOutput(tv1); + + auto& container = *fusion.ir_container(); + size_t initial_vals = container.vals().size(); + size_t initial_exprs = container.unordered_exprs().size(); + + EXPECT_GT(initial_vals, 0); + EXPECT_GT(initial_exprs, 0); + + // Register fusion + container.addFusion(&fusion); + + // When we remove fusion, its Statements should be cleaned up + // (all Statements in this test are owned by fusion) + container.removeFusion(&fusion); + + // After removal, the Statements owned by fusion should be removed + EXPECT_EQ(container.vals().size(), 0); + EXPECT_EQ(container.unordered_exprs().size(), 0); +} + } // namespace nvfuser From fdb53df7ed518bfc07d9e2163344eef4b9ca3af7 Mon Sep 17 00:00:00 2001 From: Michael Davis Date: Tue, 3 Feb 2026 17:27:32 -0800 Subject: [PATCH 6/9] Per Fusion Tracking MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Implemented infrastructure for per-Fusion statement tracking so each Fusion can efficiently access only its own statements when sharing an IrContainer with other Fusions. This is a prerequisite for copy/move semantics in later tasks. 1. **Per-Fusion Tracking Data Structures** (`container.h`) - Added `per_fusion_vals_`: Maps each Fusion to its owned Vals - Added `per_fusion_exprs_`: Maps each Fusion to its owned Exprs 2. **New IrContainer Methods** (`container.h/cpp`) - `valsOwnedBy(Fusion*)`: Returns Vals owned by specific Fusion - `exprsOwnedBy(Fusion*)`: Returns Exprs owned by specific Fusion - `transferStatementOwnership(Fusion*, Fusion*)`: For move operations - `removeStatementsOwnedBy(Fusion*)`: Public API to remove Fusion's statements 3. **New Fusion Accessor Methods** (`fusion.h`) - `ownedVals()`: Returns only THIS Fusion's Vals (not all in container) - `ownedExprs()`: Returns only THIS Fusion's Exprs (not all in container) 4. **Updated Registration** (`container.cpp`) - `registerVal()`: Now updates per-Fusion tracking - `registerExpr()`: Now updates per-Fusion tracking - `removeVal()`: Now cleans up per-Fusion tracking - `removeExpr()`: Now cleans up per-Fusion tracking 5. **Updated Fusion::clear()** (`fusion.cpp`) - Changed from `ir_container()->clear()` (clears entire container) - To `ir_container_->removeStatementsOwnedBy(this)` (only clears THIS Fusion's statements) - Critical for Invariant 4: `Fusion::clear()` must only affect this Fusion's state ``` [==========] Running 34 tests from 3 test suites. [ PASSED ] 34 tests. ``` - `AbstractTensorTest.*` (28 tests): PASS - `Gpu1Test.FusionClear_CUDA`: PASS - `Gpu1Test.FusionCopy_CUDA`: PASS - `Gpu1Test.FusionMove_CUDA`: PASS - `NVFuserTest.FusionHash*` (3 tests): PASS ``` [==========] Running 18 tests from 1 test suite. [ PASSED ] 18 tests. ``` New tests added for Task 4: - `PerFusionValsTracking`: Verifies ownedVals() returns only this Fusion's vals - `PerFusionExprsTracking`: Verifies ownedExprs() returns only this Fusion's exprs - `ValsOwnedByAPI`: Tests IrContainer::valsOwnedBy() API directly - `ExprsOwnedByAPI`: Tests IrContainer::exprsOwnedBy() API directly - `RegisterUpdatesPerFusionTracking`: Verifies registration updates tracking - `TransferStatementOwnership`: Tests transferStatementOwnership for moves - `ClearOnlyAffectsOwnedStatements`: Verifies clear() only affects this Fusion - `RemoveStatementsOwnedByAPI`: Tests public removeStatementsOwnedBy API 1. **Thread Safety**: All new methods use existing mutex_ infrastructure - `valsOwnedBy()`/`exprsOwnedBy()` acquire shared_lock - `transferStatementOwnership()`/`removeStatementsOwnedBy()` acquire unique_lock 2. **Empty Set Handling**: Return static empty set when Fusion has no statements ```cpp const std::unordered_set& IrContainer::valsOwnedBy(Fusion* fusion) const { std::shared_lock lock(mutex_); static const std::unordered_set empty; auto it = per_fusion_vals_.find(fusion); return it != per_fusion_vals_.end() ? it->second : empty; } ``` 3. **Backward Compatibility**: - Existing `vals()` and `unordered_exprs()` unchanged - With single Fusion (Phase 1 pattern): `ownedVals() == vals()` - With shared container (Phase 2): `ownedVals() ⊆ vals()` | Level | Vals | Exprs | Description | |-------|------|-------|-------------| | **IR Traversal** | `exprs()` | N/A | Reachable from I/O (existing, unchanged) | | **All in Container** | `vals()` | `unordered_exprs()` | All in shared container | | **Owned by Fusion** | `ownedVals()` | `ownedExprs()` | Only THIS Fusion's statements (NEW) | --- csrc/fusion.cpp | 14 +- csrc/fusion.h | 21 ++ csrc/ir/container.cpp | 82 ++++++++ csrc/ir/container.h | 27 +++ tests/cpp/test_phase2_container_sharing.cpp | 217 ++++++++++++++++++++ 5 files changed, 357 insertions(+), 4 deletions(-) diff --git a/csrc/fusion.cpp b/csrc/fusion.cpp index 9916fdcb19a..5bb12613e33 100644 --- a/csrc/fusion.cpp +++ b/csrc/fusion.cpp @@ -265,11 +265,17 @@ void Fusion::clear() noexcept { // constructor of Trace, which could throw an exception. // FUSER_PERF_SCOPE("Fusion clear"); - // Clear container contents instead of destroying it - // This preserves the container object so Statement pointers don't become - // dangling - ir_container()->clear(); + // Phase 2 (Task 4): Only clear THIS Fusion's statements, not the entire + // container. With shared containers, calling ir_container()->clear() would + // break other Fusions sharing the container. + // + // Phase 1: ir_container()->clear() was equivalent to this (1:1 relationship) + // Phase 2: Must filter by ownership to avoid affecting other Fusions + if (ir_container_) { + ir_container_->removeStatementsOwnedBy(this); + } + // Clear Fusion-level state (these are per-Fusion, not in the container) inputs_.clear(); outputs_.clear(); diff --git a/csrc/fusion.h b/csrc/fusion.h index 87317eb6a20..f4fb59be122 100644 --- a/csrc/fusion.h +++ b/csrc/fusion.h @@ -550,6 +550,27 @@ class NVF_API Fusion : public PolymorphicBase { return ir_container()->vals(); } + // Per-Fusion Statement Access (Phase 2 Task 4) + // These methods return only statements owned by THIS Fusion, + // not all statements in the shared container. + // + // Phase 1 (unique_ptr): ownedVals() == vals() (1:1 relationship) + // Phase 2 (shared_ptr): ownedVals() ⊆ vals() (filtering by ownership) + + //! Return only Vals owned by this Fusion + //! Unlike vals() which returns ALL vals in the (possibly shared) container, + //! this returns only vals where val->container() == this + const std::unordered_set& ownedVals() const { + return ir_container()->valsOwnedBy(const_cast(this)); + } + + //! Return only Exprs owned by this Fusion + //! Unlike unordered_exprs() which returns ALL exprs in the container, + //! this returns only exprs where expr->container() == this + const std::unordered_set& ownedExprs() const { + return ir_container()->exprsOwnedBy(const_cast(this)); + } + // Count queries int64_t numExprs() const noexcept { return ir_container()->numExprs(); diff --git a/csrc/ir/container.cpp b/csrc/ir/container.cpp index c06b59c768a..8fe3c2f8692 100644 --- a/csrc/ir/container.cpp +++ b/csrc/ir/container.cpp @@ -135,6 +135,52 @@ const std::unordered_set& IrContainer::sharingFusions() const { return sharing_fusions_; } +// ========================================================================= +// Per-Fusion Statement Tracking (Phase 2 Task 4) +// ========================================================================= + +const std::unordered_set& IrContainer::valsOwnedBy(Fusion* fusion) const { + std::shared_lock lock(mutex_); + static const std::unordered_set empty; + auto it = per_fusion_vals_.find(fusion); + return it != per_fusion_vals_.end() ? it->second : empty; +} + +const std::unordered_set& IrContainer::exprsOwnedBy( + Fusion* fusion) const { + std::shared_lock lock(mutex_); + static const std::unordered_set empty; + auto it = per_fusion_exprs_.find(fusion); + return it != per_fusion_exprs_.end() ? it->second : empty; +} + +void IrContainer::transferStatementOwnership(Fusion* from, Fusion* to) { + std::unique_lock lock(mutex_); + + // Transfer vals ownership tracking + auto vals_it = per_fusion_vals_.find(from); + if (vals_it != per_fusion_vals_.end()) { + // Move the set to 'to', merging if 'to' already has entries + auto& to_vals = per_fusion_vals_[to]; + to_vals.insert(vals_it->second.begin(), vals_it->second.end()); + per_fusion_vals_.erase(vals_it); + } + + // Transfer exprs ownership tracking + auto exprs_it = per_fusion_exprs_.find(from); + if (exprs_it != per_fusion_exprs_.end()) { + // Move the set to 'to', merging if 'to' already has entries + auto& to_exprs = per_fusion_exprs_[to]; + to_exprs.insert(exprs_it->second.begin(), exprs_it->second.end()); + per_fusion_exprs_.erase(exprs_it); + } +} + +void IrContainer::removeStatementsOwnedBy(Fusion* fusion) { + std::unique_lock lock(mutex_); + removeStatementsOwnedByUnlocked(fusion); +} + void IrContainer::removeStatementsOwnedByUnlocked(Fusion* fusion) { // Remove all Vals owned by this Fusion for (auto it = vals_up_.begin(); it != vals_up_.end();) { @@ -159,6 +205,10 @@ void IrContainer::removeStatementsOwnedByUnlocked(Fusion* fusion) { ++it; } } + + // Clean up per-Fusion tracking (Phase 2 Task 4) + per_fusion_vals_.erase(fusion); + per_fusion_exprs_.erase(fusion); } void IrContainer::swap(IrContainer& a, IrContainer& b) noexcept { @@ -256,6 +306,14 @@ void IrContainer::removeExpr(Expr* expr) { expr_in_deque != exprs_up_.end(), "Wanted to remove an expression but its unique ptr is missing."); + // Remove from per-Fusion tracking (Phase 2 Task 4) + if (expr->container() != nullptr) { + auto it = per_fusion_exprs_.find(expr->container()); + if (it != per_fusion_exprs_.end()) { + it->second.erase(expr); + } + } + exprs_.erase(expr); exprs_up_.erase(expr_in_deque); } @@ -279,6 +337,14 @@ void IrContainer::removeVal(Val* val) { val_in_deque != vals_up_.end(), "Wanted to remove a value but its unique ptr is missing."); + // Remove from per-Fusion tracking (Phase 2 Task 4) + if (val->container() != nullptr) { + auto it = per_fusion_vals_.find(val->container()); + if (it != per_fusion_vals_.end()) { + it->second.erase(val); + } + } + vals_.erase(val); vals_up_.erase(val_in_deque); } @@ -293,6 +359,12 @@ void IrContainer::registerVal(Val* val) { vals_up_.emplace_back(val); vals_.insert(val); val->setName(IrContainerPasskey(), getValName(val->vtype())); + + // Track per-Fusion ownership (Phase 2 Task 4) + // val->container() returns the owning Fusion + if (val->container() != nullptr) { + per_fusion_vals_[val->container()].insert(val); + } } //! Register expr with this container. @@ -305,6 +377,12 @@ void IrContainer::registerExpr(Expr* expr) { exprs_up_.emplace_back(expr); exprs_.insert(expr); expr->setName(IrContainerPasskey(), getExprName()); + + // Track per-Fusion ownership (Phase 2 Task 4) + // expr->container() returns the owning Fusion + if (expr->container() != nullptr) { + per_fusion_exprs_[expr->container()].insert(expr); + } } void IrContainer::clear() noexcept { @@ -315,6 +393,10 @@ void IrContainer::clear() noexcept { exprs_up_.clear(); val_type_name_map_.clear(); expr_name_counter_ = 0; + + // Clear per-Fusion tracking (Phase 2 Task 4) + per_fusion_vals_.clear(); + per_fusion_exprs_.clear(); } bool IrContainer::inContainer(const Statement* const_stmt) const { diff --git a/csrc/ir/container.h b/csrc/ir/container.h index 5fb04f7d8f2..ccdd25f9242 100644 --- a/csrc/ir/container.h +++ b/csrc/ir/container.h @@ -113,6 +113,26 @@ class IrContainer { //! Get the set of Fusions sharing this container const std::unordered_set& sharingFusions() const; + // ========================================================================= + // Per-Fusion Statement Tracking (Phase 2 Task 4) + // ========================================================================= + + //! Get Vals owned by a specific Fusion + //! Returns empty set if Fusion has no vals in this container + const std::unordered_set& valsOwnedBy(Fusion* fusion) const; + + //! Get Exprs owned by a specific Fusion + //! Returns empty set if Fusion has no exprs in this container + const std::unordered_set& exprsOwnedBy(Fusion* fusion) const; + + //! Transfer statement ownership tracking from one Fusion to another + //! Used during move operations + void transferStatementOwnership(Fusion* from, Fusion* to); + + //! Public version of removeStatementsOwnedBy (acquires lock) + //! Removes all Statements owned by a specific Fusion + void removeStatementsOwnedBy(Fusion* fusion); + protected: // Mutex for thread-safe access when container is shared between Fusions // mutable because we need to lock in const methods @@ -121,6 +141,13 @@ class IrContainer { //! Fusions that share this container (for Phase 2 shared_ptr ownership) std::unordered_set sharing_fusions_; + //! Per-Fusion statement tracking for efficient ownership queries + //! Maps each Fusion to the set of Vals it owns in this container + std::unordered_map> per_fusion_vals_; + + //! Maps each Fusion to the set of Exprs it owns in this container + std::unordered_map> per_fusion_exprs_; + //! Remove all Statements owned by a specific Fusion (internal helper) //! Caller must hold unique_lock on mutex_ void removeStatementsOwnedByUnlocked(Fusion* fusion); diff --git a/tests/cpp/test_phase2_container_sharing.cpp b/tests/cpp/test_phase2_container_sharing.cpp index a8d08a242f8..5d01f306d9f 100644 --- a/tests/cpp/test_phase2_container_sharing.cpp +++ b/tests/cpp/test_phase2_container_sharing.cpp @@ -199,4 +199,221 @@ TEST_F(Phase2ContainerTest, StatementCleanup) { EXPECT_EQ(container.unordered_exprs().size(), 0); } +// ============================================================================= +// Task 4 Tests: Per-Fusion Statement Tracking +// ============================================================================= + +TEST_F(Phase2ContainerTest, PerFusionValsTracking) { + // Test that ownedVals() returns only this Fusion's vals + Fusion fusion; + FusionGuard fg(&fusion); + + // Create some IR + auto* tv0 = makeSymbolicTensor(2); + fusion.addInput(tv0); + auto* tv1 = add(tv0, tv0); + fusion.addOutput(tv1); + + // ownedVals() should return only this Fusion's vals + const auto& owned_vals = fusion.ownedVals(); + EXPECT_GT(owned_vals.size(), 0); + + // All vals in ownedVals() should have container() == &fusion + for (auto* val : owned_vals) { + EXPECT_EQ(val->container(), &fusion); + } + + // vals() and ownedVals() should be the same with a single Fusion (Phase 1 + // equivalence) + EXPECT_EQ(fusion.vals().size(), fusion.ownedVals().size()); +} + +TEST_F(Phase2ContainerTest, PerFusionExprsTracking) { + // Test that ownedExprs() returns only this Fusion's exprs + Fusion fusion; + FusionGuard fg(&fusion); + + // Create some IR + auto* tv0 = makeSymbolicTensor(2); + fusion.addInput(tv0); + auto* tv1 = add(tv0, tv0); + fusion.addOutput(tv1); + + // ownedExprs() should return only this Fusion's exprs + const auto& owned_exprs = fusion.ownedExprs(); + EXPECT_GT(owned_exprs.size(), 0); + + // All exprs in ownedExprs() should have container() == &fusion + for (auto* expr : owned_exprs) { + EXPECT_EQ(expr->container(), &fusion); + } + + // unordered_exprs() and ownedExprs() should be the same with a single Fusion + EXPECT_EQ(fusion.unordered_exprs().size(), fusion.ownedExprs().size()); +} + +TEST_F(Phase2ContainerTest, ValsOwnedByAPI) { + // Test IrContainer::valsOwnedBy() API directly + Fusion fusion; + FusionGuard fg(&fusion); + + auto* tv0 = makeSymbolicTensor(2); + fusion.addInput(tv0); + auto* tv1 = add(tv0, tv0); + fusion.addOutput(tv1); + + auto& container = *fusion.ir_container(); + + // valsOwnedBy should return same set as ownedVals() + const auto& vals_by_container = container.valsOwnedBy(&fusion); + const auto& vals_by_fusion = fusion.ownedVals(); + EXPECT_EQ(vals_by_container.size(), vals_by_fusion.size()); + + // valsOwnedBy for a non-registered Fusion should return empty set + Fusion other_fusion; + const auto& other_vals = container.valsOwnedBy(&other_fusion); + EXPECT_EQ(other_vals.size(), 0); +} + +TEST_F(Phase2ContainerTest, ExprsOwnedByAPI) { + // Test IrContainer::exprsOwnedBy() API directly + Fusion fusion; + FusionGuard fg(&fusion); + + auto* tv0 = makeSymbolicTensor(2); + fusion.addInput(tv0); + auto* tv1 = add(tv0, tv0); + fusion.addOutput(tv1); + + auto& container = *fusion.ir_container(); + + // exprsOwnedBy should return same set as ownedExprs() + const auto& exprs_by_container = container.exprsOwnedBy(&fusion); + const auto& exprs_by_fusion = fusion.ownedExprs(); + EXPECT_EQ(exprs_by_container.size(), exprs_by_fusion.size()); + + // exprsOwnedBy for a non-registered Fusion should return empty set + Fusion other_fusion; + const auto& other_exprs = container.exprsOwnedBy(&other_fusion); + EXPECT_EQ(other_exprs.size(), 0); +} + +TEST_F(Phase2ContainerTest, RegisterUpdatesPerFusionTracking) { + // Test that registering new vals/exprs updates per-Fusion tracking + Fusion fusion; + FusionGuard fg(&fusion); + + // Initially no vals + EXPECT_EQ(fusion.ownedVals().size(), 0); + EXPECT_EQ(fusion.ownedExprs().size(), 0); + + // Add an input - this creates vals + auto* tv0 = makeSymbolicTensor(2); + fusion.addInput(tv0); + + // Now we should have vals tracked for this fusion + size_t vals_after_input = fusion.ownedVals().size(); + EXPECT_GT(vals_after_input, 0); + + // Add an expression - this creates more vals and exprs + auto* tv1 = add(tv0, tv0); + fusion.addOutput(tv1); + + // Both should have grown + EXPECT_GT(fusion.ownedVals().size(), vals_after_input); + EXPECT_GT(fusion.ownedExprs().size(), 0); +} + +TEST_F(Phase2ContainerTest, TransferStatementOwnership) { + // Test IrContainer::transferStatementOwnership + auto container = std::make_shared(); + + // Create dummy Fusions for testing + Fusion fusion1; + Fusion fusion2; + + // We can't easily create vals owned by fusion1 in a standalone container, + // but we can test the tracking data structure directly + container->addFusion(&fusion1); + container->addFusion(&fusion2); + + // Transfer ownership - should not crash even with empty tracking + container->transferStatementOwnership(&fusion1, &fusion2); + + // Verify fusion1 no longer has tracking entries (empty case) + EXPECT_EQ(container->valsOwnedBy(&fusion1).size(), 0); + EXPECT_EQ(container->exprsOwnedBy(&fusion1).size(), 0); + + // Cleanup + container->removeFusion(&fusion1); + container->removeFusion(&fusion2); +} + +TEST_F(Phase2ContainerTest, ClearOnlyAffectsOwnedStatements) { + // Test that Fusion::clear() only clears THIS Fusion's statements + // This is critical for shared container correctness + + Fusion fusion; + FusionGuard fg(&fusion); + + // Create some IR + auto* tv0 = makeSymbolicTensor(2); + fusion.addInput(tv0); + auto* tv1 = add(tv0, tv0); + fusion.addOutput(tv1); + + // Get container reference + auto container_ptr = fusion.ir_container_ptr(); + + // Record counts before clear + size_t vals_before = fusion.ownedVals().size(); + size_t exprs_before = fusion.ownedExprs().size(); + EXPECT_GT(vals_before, 0); + EXPECT_GT(exprs_before, 0); + + // Clear the fusion + fusion.clear(); + + // After clear, ownedVals/ownedExprs should be empty for this fusion + EXPECT_EQ(fusion.ownedVals().size(), 0); + EXPECT_EQ(fusion.ownedExprs().size(), 0); + + // Container-level accessors should also reflect the removal + EXPECT_EQ(container_ptr->vals().size(), 0); + EXPECT_EQ(container_ptr->unordered_exprs().size(), 0); +} + +TEST_F(Phase2ContainerTest, RemoveStatementsOwnedByAPI) { + // Test public IrContainer::removeStatementsOwnedBy API + Fusion fusion; + FusionGuard fg(&fusion); + + // Create some IR + auto* tv0 = makeSymbolicTensor(2); + fusion.addInput(tv0); + auto* tv1 = add(tv0, tv0); + fusion.addOutput(tv1); + + auto& container = *fusion.ir_container(); + + // Verify we have statements + EXPECT_GT(container.vals().size(), 0); + EXPECT_GT(container.unordered_exprs().size(), 0); + EXPECT_GT(container.valsOwnedBy(&fusion).size(), 0); + EXPECT_GT(container.exprsOwnedBy(&fusion).size(), 0); + + // Clear fusion-level state first (inputs_, outputs_, etc.) + // Note: We're testing the container API directly, not through Fusion::clear() + // In practice, Fusion::clear() does both + container.removeStatementsOwnedBy(&fusion); + + // After removal, tracking should be empty + EXPECT_EQ(container.valsOwnedBy(&fusion).size(), 0); + EXPECT_EQ(container.exprsOwnedBy(&fusion).size(), 0); + + // Container-level sets should also be empty (single fusion case) + EXPECT_EQ(container.vals().size(), 0); + EXPECT_EQ(container.unordered_exprs().size(), 0); +} + } // namespace nvfuser From 774994d209d7cb9a1349d31bfd28e13f9df2baac Mon Sep 17 00:00:00 2001 From: Michael Davis Date: Tue, 3 Feb 2026 20:24:05 -0800 Subject: [PATCH 7/9] Copy and Move Semantics with Shared Containers Implemented Phase 2 copy/move semantics where copy shares the container and move uses pointer-based swap. This enables efficient Fusion operations with shared IrContainers. 1. **Pointer-Based Swap** - Swap container shared_ptrs, not contents 2. **Copy Shares Container** - Copy constructor shares IrContainer with source 3. **Move Uses Swap** - Simple `Fusion() + swap` pattern from Phase 1 **Pointer-based swap:** - Collects owned statements before swap - Transfers Fusion registrations between containers - Swaps container pointers (not content!) - Swaps all Fusion-level members including axioms_/metadata_ - Updates statement ownership for only the swapped Fusions - Updates per-Fusion tracking in containers **New copy implementation:** - Creates IrCloner targeting destination Fusion directly - Clones only source's owned vals (not all vals in shared container) - Works correctly with shared containers **Copy constructor:** - Shares container pointer with source (no new container created) - Registers with shared container - Delegates to static copy method to clone nodes **Move assignment:** - Added self-assignment check - Added deprecation notes for `IrContainer::swap` and `IrContainer::copy` - Added per_fusion_vals_/per_fusion_exprs_ swapping - Updated `inContainer` to check `sharing_fusions_` instead of `parent_` Added 23 new tests for copy/move semantics: - Task 5 (Copy): CopySharesContainer, CopyRegistersWithContainer, CopiedNodesOwnedByNewFusion, CopyOwnedValsAreIndependent, etc. - Task 6 (Move): MoveConstructorTransfersOwnership, MoveUpdatesStatementOwnership, SwapExchangesContainers, SwapUpdatesStatementOwnership, etc. - Phase 2 Container Tests: 49/49 PASSED - Smoke Tests: 34/34 PASSED --- csrc/fusion.cpp | 147 ++++- csrc/ir/container.cpp | 34 +- tests/cpp/test_phase2_container_sharing.cpp | 695 ++++++++++++++++++++ 3 files changed, 838 insertions(+), 38 deletions(-) diff --git a/csrc/fusion.cpp b/csrc/fusion.cpp index 5bb12613e33..b959464dfe7 100644 --- a/csrc/fusion.cpp +++ b/csrc/fusion.cpp @@ -107,39 +107,58 @@ bool Fusion::sameDefinition(const Fusion& other) const { void Fusion::swap(Fusion& a, Fusion& b) noexcept { FUSER_PERF_SCOPE("Fusion swap"); - // We need to be careful to call IrContainer swap not unique_ptr swap, which - // will only swap the ptrs NOT the contents. - IrContainer::swap(*(a.ir_container()), *(b.ir_container())); - - // Fix parent pointers after swapping containers - // After swap, each Fusion owns a different IrContainer, so we must - // update the parent backpointers in those containers to point to their new - // owners + if (&a == &b) { + return; + } + + // Phase 2: Pointer-based swap with ownership-filtered Statement updates + // This is more efficient than content swap and correctly handles shared + // containers by only updating statements owned by the swapped Fusions. + + // Step 1: Collect statements owned by each Fusion BEFORE swap + // We need to copy to vectors because we'll be modifying ownership + std::vector a_owned_vals, b_owned_vals; + std::vector a_owned_exprs, b_owned_exprs; + if (a.ir_container_) { - // Also update all Statement ir_container_ pointers to point to new owner - a.ir_container()->parent_ = &a; - for (auto val : a.vals()) { - val->ir_container_ = &a; - } - for (auto expr : a.deterministic_exprs()) { - expr->ir_container_ = &a; - } + const auto& a_vals = a.ir_container_->valsOwnedBy(&a); + const auto& a_exprs = a.ir_container_->exprsOwnedBy(&a); + a_owned_vals.assign(a_vals.begin(), a_vals.end()); + a_owned_exprs.assign(a_exprs.begin(), a_exprs.end()); } + if (b.ir_container_) { - // Also update all Statement ir_container_ pointers to point to new owner - b.ir_container()->parent_ = &b; - for (auto val : b.vals()) { - val->ir_container_ = &b; - } - for (auto expr : b.deterministic_exprs()) { - expr->ir_container_ = &b; - } + const auto& b_vals = b.ir_container_->valsOwnedBy(&b); + const auto& b_exprs = b.ir_container_->exprsOwnedBy(&b); + b_owned_vals.assign(b_vals.begin(), b_vals.end()); + b_owned_exprs.assign(b_exprs.begin(), b_exprs.end()); } + // Step 2: Handle registration transfers (before pointer swap) + // After swap, a should be registered with its new container (was b's) + // and b should be registered with its new container (was a's) + if (a.ir_container_ && b.ir_container_ && + a.ir_container_.get() != b.ir_container_.get()) { + // Different containers: swap registrations + // In a's container: remove a, add b (because b will own this container) + // In b's container: remove b, add a (because a will own this container) + a.ir_container_->transferFusion(&a, &b); + b.ir_container_->transferFusion(&b, &a); + } + + // Step 3: Swap container pointers (not content swap!) + std::swap(a.ir_container_, b.ir_container_); + + // Step 4: Swap all Fusion-level members std::swap(a.inputs_, b.inputs_); std::swap(a.outputs_, b.outputs_); - std::swap(a.io_alias_, b.io_alias_); + std::swap(a.all_tv_uses_valid_, b.all_tv_uses_valid_); + std::swap(a.is_during_update_uses_, b.is_during_update_uses_); + std::swap(a.managed_data_, b.managed_data_); + std::swap(a.managed_named_data_, b.managed_named_data_); + std::swap(a.expected_dynamic_smem_bytes_, b.expected_dynamic_smem_bytes_); + std::swap(a.all_tvs_ptr_, b.all_tvs_ptr_); // Swap per-Fusion special values (Phase 2) std::swap(a.zero_val_, b.zero_val_); @@ -147,6 +166,36 @@ void Fusion::swap(Fusion& a, Fusion& b) noexcept { std::swap(a.true_val_, b.true_val_); std::swap(a.false_val_, b.false_val_); std::swap(a.magic_zero_val_, b.magic_zero_val_); + + // Swap per-Fusion axioms and metadata (Phase 2) + std::swap(a.axioms_, b.axioms_); + std::swap(a.metadata_, b.metadata_); + + // Step 5: Update statement ownership + // Statements that belonged to a now belong to b (they go with their data) + // Statements that belonged to b now belong to a + for (auto* val : a_owned_vals) { + val->ir_container_ = &b; + } + for (auto* expr : a_owned_exprs) { + expr->ir_container_ = &b; + } + for (auto* val : b_owned_vals) { + val->ir_container_ = &a; + } + for (auto* expr : b_owned_exprs) { + expr->ir_container_ = &a; + } + + // Step 6: Update per-Fusion tracking in containers + // After swap: a's new container needs to track a's new statements (was b's) + // b's new container needs to track b's new statements (was a's) + if (a.ir_container_) { + a.ir_container_->transferStatementOwnership(&b, &a); + } + if (b.ir_container_) { + b.ir_container_->transferStatementOwnership(&a, &b); + } } std::unique_ptr Fusion::segment( @@ -156,25 +205,43 @@ std::unique_ptr Fusion::segment( } IrCloner Fusion::copy(const Fusion* from, Fusion* to) { + FUSER_PERF_SCOPE("Fusion copy"); + + // Phase 2: Clear destination's state only (not entire container) + // to->clear() removes only 'to's statements from the shared container to->clear(); - auto ir_cloner = IrContainer::copy(from->ir_container(), to->ir_container()); + // Phase 2: Create IrCloner targeting 'to' Fusion directly + // IrCloner sets cloned nodes' ir_container_ to 'to' + // This works with shared containers - clones go into the shared container + // but are tracked as owned by 'to' via per-Fusion tracking + IrCloner ir_cloner(to); + + // Phase 2: Clone only 'from's owned vals (not all vals in shared container) + // Use ownedVals() to get only vals belonging to 'from' + for (auto val : from->ownedVals()) { + ir_cloner.clone(val); + } - for (auto val : from->vals()) { + // Update definition_ and uses_ on cloned vals + for (auto val : from->ownedVals()) { ir_cloner.clone(val)->setDefinition(ir_cloner.clone(val->definition_)); ir_cloner.clone(val)->setUses(ir_cloner.clone(val->uses_)); } + // Clone fusion inputs to->inputs_ = ir_cloner.clone(from->inputs_); - to->outputs_ = ir_cloner.clone(from->outputs_); for (auto inp : to->inputs_) { inp->setIsFusionInput(true); } + + // Clone fusion outputs + to->outputs_ = ir_cloner.clone(from->outputs_); for (auto out : to->outputs_) { out->setIsFusionOutput(true); } - // TODO: put this into ir_cloner instead + // Clone io_alias mappings for (Val* out : from->outputs_) { const AliasInfo& alias = from->io_alias_.get(out); if (alias.type == AllocationType::New) { @@ -186,10 +253,12 @@ IrCloner Fusion::copy(const Fusion* from, Fusion* to) { to->io_alias_.add(copied_out, copied_in, alias.type, alias.visibility); } + // Copy other Fusion-level state to->all_tv_uses_valid_ = from->all_tv_uses_valid_; // This should never be true on copy, but copying for completeness. to->is_during_update_uses_ = from->is_during_update_uses_; + // Clone managed data for (const auto& i : from->managed_data_) { if (i.first.has_value()) { to->managed_data_.emplace_back(i.second(ir_cloner, i.first), i.second); @@ -208,6 +277,7 @@ IrCloner Fusion::copy(const Fusion* from, Fusion* to) { to->expected_dynamic_smem_bytes_ = from->expected_dynamic_smem_bytes_; + // Clone cached TV list if present if (from->all_tvs_ptr_ != nullptr) { to->all_tvs_ptr_ = std::make_unique>(); to->all_tvs_ptr_->reserve(from->all_tvs_ptr_->size()); @@ -225,9 +295,19 @@ Fusion::Fusion() : ir_container_(std::make_shared()) { ir_container_->addFusion(this); // Register with container for Phase 2 } -// Copy constructor -Fusion::Fusion(const Fusion& other) : Fusion() { - FUSER_PERF_SCOPE("Fusion copy"); +// Copy constructor - Phase 2: Share container with source +Fusion::Fusion(const Fusion& other) + : ir_container_(other.ir_container_) { // Share container pointer + FUSER_PERF_SCOPE("Fusion copy ctor"); + + // Note: Special values are per-Fusion (Task 7), initialized to nullptr. + // They will be created lazily when accessed on the copy. + // Do NOT copy other.zero_val_ etc - each Fusion has its own. + + // Register with shared container + ir_container_->addFusion(this); + + // Delegate to static copy method to clone nodes Fusion::copy(&other, this); } @@ -247,6 +327,9 @@ Fusion& Fusion::operator=(const Fusion& other) { Fusion& Fusion::operator=(Fusion&& other) noexcept { FUSER_PERF_SCOPE("Fusion move assign"); + if (this == &other) { + return *this; + } clear(); swap(*this, other); return *this; diff --git a/csrc/ir/container.cpp b/csrc/ir/container.cpp index 8fe3c2f8692..9f17e649dfa 100644 --- a/csrc/ir/container.cpp +++ b/csrc/ir/container.cpp @@ -212,10 +212,13 @@ void IrContainer::removeStatementsOwnedByUnlocked(Fusion* fusion) { } void IrContainer::swap(IrContainer& a, IrContainer& b) noexcept { - FUSER_PERF_SCOPE("Fusion swap"); + FUSER_PERF_SCOPE("IrContainer swap"); + + // NOTE: This method is deprecated in Phase 2. Fusion::swap handles + // pointer-based swapping of shared containers. This is kept for + // backward compatibility but should not be called directly. // Lock both containers in consistent order to avoid deadlock - // Use std::lock to lock both mutexes atomically std::unique_lock lock_a(a.mutex_, std::defer_lock); std::unique_lock lock_b(b.mutex_, std::defer_lock); std::lock(lock_a, lock_b); @@ -234,9 +237,16 @@ void IrContainer::swap(IrContainer& a, IrContainer& b) noexcept { // Note: Special values, axioms, and metadata are now per-Fusion, // not per-IrContainer. They are handled by Fusion::swap. + std::swap(a.sharing_fusions_, b.sharing_fusions_); + std::swap(a.per_fusion_vals_, b.per_fusion_vals_); + std::swap(a.per_fusion_exprs_, b.per_fusion_exprs_); } IrCloner IrContainer::copy(const IrContainer* from, IrContainer* to) { + // NOTE: This method is deprecated in Phase 2. Fusion::copy handles + // copying with shared containers. This is kept for backward compatibility + // but should not be called directly. + // Lock both containers: shared for reading from, unique for writing to std::shared_lock lock_from(from->mutex_); std::unique_lock lock_to(to->mutex_); @@ -248,12 +258,18 @@ IrCloner IrContainer::copy(const IrContainer* from, IrContainer* to) { to->exprs_up_.clear(); to->val_type_name_map_.clear(); to->expr_name_counter_ = 0; + to->per_fusion_vals_.clear(); + to->per_fusion_exprs_.clear(); + // NOTE: In Phase 2, we can't use to->parent() here because parent_ might + // not be set correctly for shared containers. Fusion::copy handles this. + NVF_ERROR( + to->parent_ != nullptr, + "IrContainer::copy requires parent_ to be set. Use Fusion::copy " + "instead."); IrCloner ir_cloner(to->parent()); // Copy values in deterministic order - // deterministic_vals can contain special values like one_val_, zero_val_, etc - // that are not registered in the container. std::deque from_vals; std::transform( from->vals_up_.begin(), @@ -411,9 +427,15 @@ bool IrContainer::inContainer(const Statement* const_stmt) const { return false; } + // Phase 2: With shared containers, multiple Fusions can share this container. + // The statement's container() returns its owning Fusion, which should be + // one of the Fusions sharing this container. + // Phase 1 (single Fusion): sharing_fusions_ == {parent_} + // Phase 2 (shared container): sharing_fusions_ contains multiple Fusions NVF_ERROR( - const_stmt->container() == this->parent(), - "Container claims to own stmt, but stmt disagrees."); + sharing_fusions_.count(const_stmt->container()) > 0, + "Container claims to own stmt, but stmt's owning Fusion is not " + "registered with this container."); // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast) auto* stmt = const_cast(const_stmt); diff --git a/tests/cpp/test_phase2_container_sharing.cpp b/tests/cpp/test_phase2_container_sharing.cpp index 5d01f306d9f..c6e43b88c30 100644 --- a/tests/cpp/test_phase2_container_sharing.cpp +++ b/tests/cpp/test_phase2_container_sharing.cpp @@ -416,4 +416,699 @@ TEST_F(Phase2ContainerTest, RemoveStatementsOwnedByAPI) { EXPECT_EQ(container.unordered_exprs().size(), 0); } +// ============================================================================= +// Task 7 Tests: Per-Fusion Special Values +// ============================================================================= + +TEST_F(Phase2ContainerTest, PerFusionSpecialValuesBasic) { + // Test that special values are created per-Fusion + Fusion a; + FusionGuard fg_a(&a); + Val* zero_a = a.zeroVal(); + Val* one_a = a.oneVal(); + + EXPECT_NE(zero_a, nullptr); + EXPECT_NE(one_a, nullptr); + EXPECT_EQ(zero_a->container(), &a); + EXPECT_EQ(one_a->container(), &a); +} + +TEST_F(Phase2ContainerTest, SpecialValuesOwnedByFusion) { + // Test that special values are tracked in ownedVals + Fusion a; + FusionGuard fg_a(&a); + + Val* zero_a = a.zeroVal(); + + // Special values should be in ownedVals + EXPECT_TRUE(a.ownedVals().count(zero_a) > 0); +} + +TEST_F(Phase2ContainerTest, SeparateFusionsHaveOwnSpecialValues) { + // Two independent Fusions should have different special values + Fusion a; + Fusion b; + + { + FusionGuard fg_a(&a); + Val* zero_a = a.zeroVal(); + EXPECT_EQ(zero_a->container(), &a); + } + + { + FusionGuard fg_b(&b); + Val* zero_b = b.zeroVal(); + EXPECT_EQ(zero_b->container(), &b); + } + + // Each has its own zero (different objects) + EXPECT_NE(a.zeroVal(), b.zeroVal()); +} + +TEST_F(Phase2ContainerTest, DestroyFusionDoesNotAffectOther) { + // Destroying one Fusion should not affect another's special values + Fusion a; + FusionGuard fg_a(&a); + + // Create special values in a + Val* zero_a = a.zeroVal(); + EXPECT_NE(zero_a, nullptr); + + { + Fusion b; + FusionGuard fg_b(&b); + Val* zero_b = b.zeroVal(); + EXPECT_NE(zero_b, nullptr); + // b destroyed here + } + + // a should still work fine - its special values should still be valid + Val* zero_a_again = a.zeroVal(); + EXPECT_EQ(zero_a_again, zero_a); + EXPECT_EQ(zero_a_again->container(), &a); +} + +TEST_F(Phase2ContainerTest, SpecialValuesLazyCreation) { + // Special values should be created lazily + Fusion a; + FusionGuard fg_a(&a); + + // Before calling zeroVal(), it shouldn't exist + // (Can't directly test this, but we can verify it works after call) + Val* zero1 = a.zeroVal(); + Val* zero2 = a.zeroVal(); + + // Same value returned on repeated calls + EXPECT_EQ(zero1, zero2); +} + +TEST_F(Phase2ContainerTest, AllSpecialValuesPerFusion) { + // Test all special value accessors + Fusion a; + FusionGuard fg_a(&a); + + Val* zero = a.zeroVal(); + Val* one = a.oneVal(); + Val* true_val = a.trueVal(); + Val* false_val = a.falseVal(); + NamedScalar* magic_zero = a.magicZeroVal(); + + // All should be non-null + EXPECT_NE(zero, nullptr); + EXPECT_NE(one, nullptr); + EXPECT_NE(true_val, nullptr); + EXPECT_NE(false_val, nullptr); + EXPECT_NE(magic_zero, nullptr); + + // All should have container() == &a + EXPECT_EQ(zero->container(), &a); + EXPECT_EQ(one->container(), &a); + EXPECT_EQ(true_val->container(), &a); + EXPECT_EQ(false_val->container(), &a); + EXPECT_EQ(magic_zero->container(), &a); + + // All should be tracked in ownedVals + EXPECT_TRUE(a.ownedVals().count(zero) > 0); + EXPECT_TRUE(a.ownedVals().count(one) > 0); + EXPECT_TRUE(a.ownedVals().count(true_val) > 0); + EXPECT_TRUE(a.ownedVals().count(false_val) > 0); + EXPECT_TRUE(a.ownedVals().count(magic_zero) > 0); +} + +TEST_F(Phase2ContainerTest, SpecialValuesClearedOnFusionClear) { + // Test that Fusion::clear() resets special values + Fusion a; + FusionGuard fg_a(&a); + + // Create special values + Val* zero_before = a.zeroVal(); + Val* one_before = a.oneVal(); + EXPECT_NE(zero_before, nullptr); + EXPECT_NE(one_before, nullptr); + + // Clear the fusion + a.clear(); + + // Special values should be recreated lazily (new objects) + Val* zero_after = a.zeroVal(); + Val* one_after = a.oneVal(); + + // The new objects should be different from the old ones + // (old ones were removed by removeStatementsOwnedBy) + EXPECT_NE(zero_after, zero_before); + EXPECT_NE(one_after, one_before); + + // New objects should be valid and owned by the fusion + EXPECT_EQ(zero_after->container(), &a); + EXPECT_EQ(one_after->container(), &a); +} + +TEST_F(Phase2ContainerTest, SpecialValuesWithDtype) { + // Test zeroVal(dtype) and oneVal(dtype) accessors + Fusion a; + FusionGuard fg_a(&a); + + // Index type should return the cached value + Val* zero_index = a.zeroVal(DataType::Index); + Val* zero_cached = a.zeroVal(); + EXPECT_EQ(zero_index, zero_cached); + + Val* one_index = a.oneVal(DataType::Index); + Val* one_cached = a.oneVal(); + EXPECT_EQ(one_index, one_cached); + + // Bool type should return true/false val + Val* zero_bool = a.zeroVal(DataType::Bool); + Val* false_cached = a.falseVal(); + EXPECT_EQ(zero_bool, false_cached); + + Val* one_bool = a.oneVal(DataType::Bool); + Val* true_cached = a.trueVal(); + EXPECT_EQ(one_bool, true_cached); + + // Other types should create new values (not cached) + Val* zero_float = a.zeroVal(DataType::Float); + Val* zero_float2 = a.zeroVal(DataType::Float); + // These are not cached, so they're different objects + EXPECT_NE(zero_float, zero_float2); +} + +// ============================================================================= +// Task 5 Tests: Copy Semantics with Shared Containers +// ============================================================================= + +TEST_F(Phase2ContainerTest, CopySharesContainer) { + // After copy, both Fusions point to the same container + Fusion a; + FusionGuard fg_a(&a); + + auto* tv0 = makeSymbolicTensor(2); + a.addInput(tv0); + auto* tv1 = add(tv0, tv0); + a.addOutput(tv1); + + Fusion b(a); // Copy + + // Both should share the same container + EXPECT_EQ(a.ir_container_ptr().get(), b.ir_container_ptr().get()); +} + +TEST_F(Phase2ContainerTest, CopyRegistersWithContainer) { + // sharingCount should increment after copy + Fusion a; + FusionGuard fg_a(&a); + + auto* tv0 = makeSymbolicTensor(2); + a.addInput(tv0); + + EXPECT_EQ(a.ir_container()->sharingCount(), 1); + + Fusion b(a); + + EXPECT_EQ(a.ir_container()->sharingCount(), 2); + EXPECT_EQ(b.ir_container()->sharingCount(), 2); +} + +TEST_F(Phase2ContainerTest, CopiedNodesOwnedByNewFusion) { + // Cloned nodes should have container() == © + Fusion a; + FusionGuard fg_a(&a); + + auto* tv0 = makeSymbolicTensor(2); + a.addInput(tv0); + auto* tv1 = add(tv0, tv0); + a.addOutput(tv1); + + Fusion b(a); + + // b should have inputs + EXPECT_EQ(b.inputs().size(), 1); + + // b's input should be owned by b (not a) + EXPECT_EQ(b.inputs()[0]->container(), &b); + + // b's input should be different from a's input (cloned) + EXPECT_NE(b.inputs()[0], a.inputs()[0]); +} + +TEST_F(Phase2ContainerTest, CopyOwnedValsAreIndependent) { + // a's ownedVals and b's ownedVals should be disjoint + Fusion a; + FusionGuard fg_a(&a); + + auto* tv0 = makeSymbolicTensor(2); + a.addInput(tv0); + auto* tv1 = add(tv0, tv0); + a.addOutput(tv1); + + Fusion b(a); + + // All of a's ownedVals should have container() == &a + for (auto* v : a.ownedVals()) { + EXPECT_EQ(v->container(), &a); + } + + // All of b's ownedVals should have container() == &b + for (auto* v : b.ownedVals()) { + EXPECT_EQ(v->container(), &b); + } + + // The sets should be disjoint + for (auto* v : a.ownedVals()) { + EXPECT_EQ(b.ownedVals().count(v), 0); + } +} + +TEST_F(Phase2ContainerTest, DestructorOnlyRemovesOwnedStatements) { + // Destroying copy should not affect original's statements + Fusion a; + FusionGuard fg_a(&a); + + auto* tv0 = makeSymbolicTensor(2); + a.addInput(tv0); + auto* tv1 = add(tv0, tv0); + a.addOutput(tv1); + + size_t a_vals_before = a.ownedVals().size(); + + { + Fusion b(a); // Copy + // b gets its own cloned nodes + EXPECT_GT(b.ownedVals().size(), 0); + // b destroyed here + } + + // a's vals should still exist and be unchanged + EXPECT_EQ(a.ownedVals().size(), a_vals_before); + + // a's vals should still have correct container + for (auto* v : a.ownedVals()) { + EXPECT_EQ(v->container(), &a); + } +} + +TEST_F(Phase2ContainerTest, CopyHasOwnSpecialValues) { + // Each Fusion (original and copy) should have its own special values + Fusion a; + FusionGuard fg_a(&a); + Val* zero_a = a.zeroVal(); + Val* one_a = a.oneVal(); + + Fusion b(a); // Copy + + // Copy should have its own special values + Val* zero_b = b.zeroVal(); + Val* one_b = b.oneVal(); + + // Different objects + EXPECT_NE(zero_a, zero_b); + EXPECT_NE(one_a, one_b); + + // Correct ownership + EXPECT_EQ(zero_a->container(), &a); + EXPECT_EQ(zero_b->container(), &b); +} + +TEST_F(Phase2ContainerTest, CopySpecialValuesIndependent) { + // Destroying copy should not affect original's special values + Fusion a; + FusionGuard fg_a(&a); + Val* zero_a = a.zeroVal(); + + { + Fusion b(a); // Copy + Val* zero_b = b.zeroVal(); + EXPECT_NE(zero_a, zero_b); + // b destroyed here + } + + // a's special values should still be valid + EXPECT_EQ(a.zeroVal(), zero_a); + EXPECT_EQ(zero_a->container(), &a); +} + +TEST_F(Phase2ContainerTest, CopySharingCountDecrementsOnDestruction) { + // When copy is destroyed, sharingCount should decrement + Fusion a; + FusionGuard fg_a(&a); + + auto* tv0 = makeSymbolicTensor(2); + a.addInput(tv0); + + auto container_ptr = a.ir_container_ptr(); + EXPECT_EQ(container_ptr->sharingCount(), 1); + + { + Fusion b(a); + EXPECT_EQ(container_ptr->sharingCount(), 2); + // b destroyed here + } + + EXPECT_EQ(container_ptr->sharingCount(), 1); +} + +TEST_F(Phase2ContainerTest, CopyReturnsIrCloner) { + // Fusion::copy should return IrCloner for node mapping + // We test this indirectly via the copy constructor which uses Fusion::copy + Fusion a; + FusionGuard fg_a(&a); + + auto* tv0 = makeSymbolicTensor(2); + a.addInput(tv0); + auto* tv1 = add(tv0, tv0); + a.addOutput(tv1); + + // Copy constructor uses Fusion::copy internally + Fusion b(a); + + // Verify the copy worked - b has cloned inputs/outputs + EXPECT_EQ(b.inputs().size(), a.inputs().size()); + EXPECT_EQ(b.outputs().size(), a.outputs().size()); + + // Cloned nodes should belong to b + EXPECT_EQ(b.inputs()[0]->container(), &b); + EXPECT_EQ(b.outputs()[0]->container(), &b); + + // They should be different objects from a's nodes + EXPECT_NE(b.inputs()[0], a.inputs()[0]); + EXPECT_NE(b.outputs()[0], a.outputs()[0]); +} + +// ============================================================================= +// Task 6 Tests: Move Semantics with Shared Containers +// ============================================================================= + +TEST_F(Phase2ContainerTest, MoveConstructorTransfersOwnership) { + // Move constructor should transfer container ownership + Fusion a; + FusionGuard fg_a(&a); + + auto* tv0 = makeSymbolicTensor(2); + a.addInput(tv0); + auto* tv1 = add(tv0, tv0); + a.addOutput(tv1); + + auto* container = a.ir_container_ptr().get(); + size_t a_vals_count = a.ownedVals().size(); + + Fusion b(std::move(a)); + + // b should have a's old container + EXPECT_EQ(b.ir_container_ptr().get(), container); + + // b should have a's statements + EXPECT_EQ(b.ownedVals().size(), a_vals_count); +} + +TEST_F(Phase2ContainerTest, MoveConstructorSourceIsValid) { + // After move, source should be valid with new empty container + Fusion a; + FusionGuard fg_a(&a); + + auto* tv0 = makeSymbolicTensor(2); + a.addInput(tv0); + + Fusion b(std::move(a)); + + // Source has new empty container (not nullptr) + EXPECT_NE(a.ir_container_ptr().get(), nullptr); + EXPECT_NE(a.ir_container_ptr().get(), b.ir_container_ptr().get()); + + // Source is empty + EXPECT_EQ(a.ownedVals().size(), 0); + EXPECT_EQ(a.inputs().size(), 0); + EXPECT_EQ(a.outputs().size(), 0); + + // Source can still be used safely + a.clear(); // Should not crash +} + +TEST_F(Phase2ContainerTest, MoveUpdatesStatementOwnership) { + // Moved statements should have container() pointing to destination + Fusion a; + FusionGuard fg_a(&a); + + auto* tv0 = makeSymbolicTensor(2); + a.addInput(tv0); + auto* tv1 = add(tv0, tv0); + a.addOutput(tv1); + + // Capture original vals + std::vector orig_vals(a.ownedVals().begin(), a.ownedVals().end()); + EXPECT_GT(orig_vals.size(), 0); + + Fusion b(std::move(a)); + + // All original vals now belong to b + for (auto* val : orig_vals) { + EXPECT_EQ(val->container(), &b); + } + + // b's ownedVals should contain them + for (auto* val : orig_vals) { + EXPECT_TRUE(b.ownedVals().count(val) > 0); + } +} + +TEST_F(Phase2ContainerTest, MoveTransfersSpecialValues) { + // Move should transfer special value pointers to destination + Fusion a; + FusionGuard fg_a(&a); + Val* zero_a = a.zeroVal(); + Val* one_a = a.oneVal(); + + Fusion b(std::move(a)); + + // b should have a's special values + EXPECT_EQ(b.zeroVal(), zero_a); + EXPECT_EQ(b.oneVal(), one_a); + + // Ownership updated to b + EXPECT_EQ(zero_a->container(), &b); + EXPECT_EQ(one_a->container(), &b); +} + +TEST_F(Phase2ContainerTest, MoveSourceCanCreateNewSpecialValues) { + // After move, source can create new special values + Fusion a; + FusionGuard fg_a(&a); + Val* zero_a = a.zeroVal(); + + Fusion b(std::move(a)); + + // a is now empty but valid - can create new special values + Val* zero_a_new = a.zeroVal(); + + // Different from the moved one + EXPECT_NE(zero_a_new, zero_a); + EXPECT_EQ(zero_a_new->container(), &a); +} + +TEST_F(Phase2ContainerTest, MoveAssignmentWorks) { + // Move assignment should transfer ownership + Fusion a; + FusionGuard fg_a(&a); + + auto* tv0 = makeSymbolicTensor(2); + a.addInput(tv0); + + auto* container = a.ir_container_ptr().get(); + + Fusion b; + b = std::move(a); + + // b has a's container + EXPECT_EQ(b.ir_container_ptr().get(), container); + + // a is valid but empty + EXPECT_NE(a.ir_container_ptr().get(), nullptr); + EXPECT_EQ(a.ownedVals().size(), 0); +} + +TEST_F(Phase2ContainerTest, MoveAssignmentSelfAssignment) { + // Self-assignment should be a no-op + Fusion a; + FusionGuard fg_a(&a); + + auto* tv0 = makeSymbolicTensor(2); + a.addInput(tv0); + + auto* container = a.ir_container_ptr().get(); + size_t vals_count = a.ownedVals().size(); + + // Use a reference to avoid -Wself-move warning + Fusion& a_ref = a; + a = std::move(a_ref); + + // Should be unchanged + EXPECT_EQ(a.ir_container_ptr().get(), container); + EXPECT_EQ(a.ownedVals().size(), vals_count); +} + +TEST_F(Phase2ContainerTest, SwapExchangesContainers) { + // Swap should exchange container pointers + Fusion a, b; + + auto* container_a = a.ir_container_ptr().get(); + auto* container_b = b.ir_container_ptr().get(); + + Fusion::swap(a, b); + + EXPECT_EQ(a.ir_container_ptr().get(), container_b); + EXPECT_EQ(b.ir_container_ptr().get(), container_a); +} + +TEST_F(Phase2ContainerTest, SwapUpdatesStatementOwnership) { + // Swap should exchange statement ownership + Fusion a, b; + + { + FusionGuard fg_a(&a); + auto* tv0 = makeSymbolicTensor(2); + a.addInput(tv0); + } + + { + FusionGuard fg_b(&b); + auto* tv0 = makeSymbolicTensor(3); + b.addInput(tv0); + } + + // Capture original vals + std::vector a_vals(a.ownedVals().begin(), a.ownedVals().end()); + std::vector b_vals(b.ownedVals().begin(), b.ownedVals().end()); + + Fusion::swap(a, b); + + // a's old vals now belong to b + for (auto* val : a_vals) { + EXPECT_EQ(val->container(), &b); + } + + // b's old vals now belong to a + for (auto* val : b_vals) { + EXPECT_EQ(val->container(), &a); + } +} + +TEST_F(Phase2ContainerTest, SwapExchangesSpecialValues) { + // Swap should exchange special values + Fusion a, b; + + Val* zero_a = nullptr; + Val* zero_b = nullptr; + + { + FusionGuard fg_a(&a); + zero_a = a.zeroVal(); + } + + { + FusionGuard fg_b(&b); + zero_b = b.zeroVal(); + } + + Fusion::swap(a, b); + + // Special values exchanged + EXPECT_EQ(a.zeroVal(), zero_b); + EXPECT_EQ(b.zeroVal(), zero_a); + + // Ownership updated + EXPECT_EQ(zero_a->container(), &b); + EXPECT_EQ(zero_b->container(), &a); +} + +TEST_F(Phase2ContainerTest, SwapSelfSwapIsNoop) { + // Swapping with self should be a no-op + Fusion a; + FusionGuard fg_a(&a); + + auto* tv0 = makeSymbolicTensor(2); + a.addInput(tv0); + + auto* container = a.ir_container_ptr().get(); + size_t vals_count = a.ownedVals().size(); + + Fusion::swap(a, a); + + EXPECT_EQ(a.ir_container_ptr().get(), container); + EXPECT_EQ(a.ownedVals().size(), vals_count); +} + +TEST_F(Phase2ContainerTest, MoveFromCopyPreservesOther) { + // If we copy A to B (sharing container), then move A to C, + // B should be unaffected + Fusion a; + FusionGuard fg_a(&a); + + auto* tv0 = makeSymbolicTensor(2); + a.addInput(tv0); + auto* tv1 = add(tv0, tv0); + a.addOutput(tv1); + + Fusion b(a); // Copy - shares container + + // Capture b's state + size_t b_vals_before = b.ownedVals().size(); + std::vector b_vals(b.ownedVals().begin(), b.ownedVals().end()); + + Fusion c(std::move(a)); // Move a to c + + // b should be completely unaffected + EXPECT_EQ(b.ownedVals().size(), b_vals_before); + for (auto* val : b_vals) { + EXPECT_EQ(val->container(), &b); + EXPECT_TRUE(b.ownedVals().count(val) > 0); + } +} + +TEST_F(Phase2ContainerTest, MoveFromCopyTransfersCorrectly) { + // If we copy A to B, then move A to C, + // C should have A's original statements + Fusion a; + FusionGuard fg_a(&a); + + auto* tv0 = makeSymbolicTensor(2); + a.addInput(tv0); + + // Capture a's vals before copy + std::vector a_vals(a.ownedVals().begin(), a.ownedVals().end()); + + Fusion b(a); // Copy + Fusion c(std::move(a)); // Move a to c + + // c should have a's original vals + for (auto* val : a_vals) { + EXPECT_EQ(val->container(), &c); + EXPECT_TRUE(c.ownedVals().count(val) > 0); + } +} + +TEST_F(Phase2ContainerTest, MovePreservesInputsOutputs) { + // Move should transfer inputs/outputs vectors + Fusion a; + FusionGuard fg_a(&a); + + auto* tv0 = makeSymbolicTensor(2); + a.addInput(tv0); + auto* tv1 = add(tv0, tv0); + a.addOutput(tv1); + + Val* orig_input = a.inputs()[0]; + Val* orig_output = a.outputs()[0]; + + Fusion b(std::move(a)); + + // b has the inputs/outputs + EXPECT_EQ(b.inputs().size(), 1); + EXPECT_EQ(b.outputs().size(), 1); + EXPECT_EQ(b.inputs()[0], orig_input); + EXPECT_EQ(b.outputs()[0], orig_output); + + // a is empty + EXPECT_EQ(a.inputs().size(), 0); + EXPECT_EQ(a.outputs().size(), 0); +} + } // namespace nvfuser From edff8afe0caee621900923513872a566a2dd50a0 Mon Sep 17 00:00:00 2001 From: Michael Davis Date: Wed, 4 Feb 2026 10:42:23 -0800 Subject: [PATCH 8/9] Per-Fusion Filtering for Deterministic Accessors and StatementGuard This commit fixes two issues discovered during Phase 2 shared container implementation: 1. Deterministic Accessors Not Filtering by Ownership - deterministic_vals(), deterministic_exprs(), and their map variants were returning ALL statements in the shared container instead of only those owned by the calling Fusion - Added deterministicValsOwnedBy(Fusion*) etc. to IrContainer - Updated Fusion methods to use filtered versions - Map variants now use local indices (0,1,2...) per Fusion 2. StatementGuard Incompatible with Shared Containers - removeStatementsCreatedAfter used per-Fusion counts but operated on container-level deques, causing incorrect removal - Also failed to update per_fusion_vals_/per_fusion_exprs_ tracking - Now takes Fusion* parameter and only removes owned statements - Properly updates per-Fusion tracking when removing Tests added: - 12 tests for deterministic accessor filtering - 2 tests for StatementGuard with shared containers Total Phase 2 tests: 63 passing Smoke tests: 34 passing --- csrc/fusion.h | 34 +- csrc/ir/container.cpp | 180 ++++++-- csrc/ir/container.h | 31 +- tests/cpp/test_phase2_container_sharing.cpp | 472 ++++++++++++++++++++ 4 files changed, 675 insertions(+), 42 deletions(-) diff --git a/csrc/fusion.h b/csrc/fusion.h index f4fb59be122..fc00d8564a7 100644 --- a/csrc/fusion.h +++ b/csrc/fusion.h @@ -523,31 +523,33 @@ class NVF_API Fusion : public PolymorphicBase { } // Collections access (return values in insertion order) - const std::deque deterministic_vals() const noexcept { - return ir_container()->deterministic_vals(); + // Phase 2: These return only statements owned by THIS Fusion, + // not all statements in the (possibly shared) container. + std::deque deterministic_vals() const noexcept { + return ir_container()->deterministicValsOwnedBy(const_cast(this)); } - const std::deque deterministic_exprs() const noexcept { - return ir_container()->deterministic_exprs(); + std::deque deterministic_exprs() const noexcept { + return ir_container()->deterministicExprsOwnedBy(const_cast(this)); } - const std::unordered_map deterministic_vals_map() - const noexcept { - return ir_container()->deterministic_vals_map(); + std::unordered_map deterministic_vals_map() const noexcept { + return ir_container()->deterministicValsMapOwnedBy( + const_cast(this)); } - const std::unordered_map deterministic_exprs_map() - const noexcept { - return ir_container()->deterministic_exprs_map(); + std::unordered_map deterministic_exprs_map() const noexcept { + return ir_container()->deterministicExprsMapOwnedBy( + const_cast(this)); } // Collections access (unordered sets) const std::unordered_set& unordered_exprs() const noexcept { - return ir_container()->unordered_exprs(); + return ownedExprs(); } const std::unordered_set& vals() const noexcept { - return ir_container()->vals(); + return ownedVals(); } // Per-Fusion Statement Access (Phase 2 Task 4) @@ -573,11 +575,11 @@ class NVF_API Fusion : public PolymorphicBase { // Count queries int64_t numExprs() const noexcept { - return ir_container()->numExprs(); + return ownedExprs().size(); } int64_t numVals(bool include_shortcuts) const noexcept { - return ir_container()->numVals(include_shortcuts); + return ownedVals().size(); } // Shortcut values (frequently used constants) @@ -603,11 +605,13 @@ class NVF_API Fusion : public PolymorphicBase { void assumeNonNegative(Val* val); // Statement removal + // Phase 2: Now takes this Fusion as parameter to properly handle + // shared containers. Only removes statements owned by this Fusion. void removeStatementsCreatedAfter( int64_t num_exprs_before, int64_t num_vals_before) { ir_container()->removeStatementsCreatedAfter( - num_exprs_before, num_vals_before); + this, num_exprs_before, num_vals_before); } protected: diff --git a/csrc/ir/container.cpp b/csrc/ir/container.cpp index 9f17e649dfa..655b4aaf4b2 100644 --- a/csrc/ir/container.cpp +++ b/csrc/ir/container.cpp @@ -72,6 +72,100 @@ const std::unordered_map IrContainer::deterministic_exprs_map() return exprs_map; } +// ========================================================================= +// Per-Fusion Deterministic Accessors (Phase 2) +// ========================================================================= + +std::deque IrContainer::deterministicValsOwnedBy( + Fusion* fusion) const noexcept { + std::shared_lock lock(mutex_); + std::deque result; + + // Get the set of vals owned by this Fusion for O(1) lookup + auto it = per_fusion_vals_.find(fusion); + if (it == per_fusion_vals_.end()) { + return result; // Empty - no vals owned by this Fusion + } + const auto& owned_vals = it->second; + + // Iterate in insertion order, filtering to only owned vals + for (const auto& val_up : vals_up_) { + Val* val = val_up.get(); + if (owned_vals.count(val) > 0) { + result.push_back(val); + } + } + return result; +} + +std::deque IrContainer::deterministicExprsOwnedBy( + Fusion* fusion) const noexcept { + std::shared_lock lock(mutex_); + std::deque result; + + // Get the set of exprs owned by this Fusion for O(1) lookup + auto it = per_fusion_exprs_.find(fusion); + if (it == per_fusion_exprs_.end()) { + return result; // Empty - no exprs owned by this Fusion + } + const auto& owned_exprs = it->second; + + // Iterate in insertion order, filtering to only owned exprs + for (const auto& expr_up : exprs_up_) { + Expr* expr = expr_up.get(); + if (owned_exprs.count(expr) > 0) { + result.push_back(expr); + } + } + return result; +} + +std::unordered_map IrContainer::deterministicValsMapOwnedBy( + Fusion* fusion) const noexcept { + std::shared_lock lock(mutex_); + std::unordered_map result; + + // Get the set of vals owned by this Fusion for O(1) lookup + auto it = per_fusion_vals_.find(fusion); + if (it == per_fusion_vals_.end()) { + return result; // Empty - no vals owned by this Fusion + } + const auto& owned_vals = it->second; + + // Iterate in insertion order, assigning sequential ids to owned vals + int64_t count = 0; + for (const auto& val_up : vals_up_) { + Val* val = val_up.get(); + if (owned_vals.count(val) > 0) { + result[val] = count++; + } + } + return result; +} + +std::unordered_map IrContainer::deterministicExprsMapOwnedBy( + Fusion* fusion) const noexcept { + std::shared_lock lock(mutex_); + std::unordered_map result; + + // Get the set of exprs owned by this Fusion for O(1) lookup + auto it = per_fusion_exprs_.find(fusion); + if (it == per_fusion_exprs_.end()) { + return result; // Empty - no exprs owned by this Fusion + } + const auto& owned_exprs = it->second; + + // Iterate in insertion order, assigning sequential ids to owned exprs + int64_t count = 0; + for (const auto& expr_up : exprs_up_) { + Expr* expr = expr_up.get(); + if (owned_exprs.count(expr) > 0) { + result[expr] = count++; + } + } + return result; +} + const std::unordered_set& IrContainer::unordered_exprs() const noexcept { // Note: Returns reference - caller responsible for not holding across // concurrent modifications. Lock provides snapshot consistency during call. @@ -459,36 +553,72 @@ bool IrContainer::inContainer(const Statement* const_stmt) const { // This avoids ownership conflicts when multiple Fusions share an IrContainer. void IrContainer::removeStatementsCreatedAfter( + Fusion* fusion, int64_t prev_num_exprs, int64_t prev_num_vals) { - NVF_ERROR( - exprs_up_.size() == exprs_.size(), - "exprs_up_ (size ", - exprs_up_.size(), - ") and exprs_ (size ", - exprs_.size(), - ") are out of sync."); - NVF_ERROR( - std::ssize(exprs_up_) >= prev_num_exprs, - "exprs_up_ size (", - std::ssize(exprs_up_), - ") is less than prev_num_exprs (", - prev_num_exprs, - ")."); - - // Remove expressions before values because we need to change Val::uses_. - while (std::ssize(exprs_up_) > prev_num_exprs) { - Expr* e = exprs_up_.back().get(); - for (Val* in : e->inputs()) { - in->removeUse(e); + std::unique_lock lock(mutex_); + + // Phase 2: Remove only statements owned by the specified Fusion. + // This correctly handles shared containers where multiple Fusions + // have statements interleaved in the container's deques. + + // Get current per-Fusion counts + auto vals_it = per_fusion_vals_.find(fusion); + auto exprs_it = per_fusion_exprs_.find(fusion); + + int64_t current_fusion_exprs = + (exprs_it != per_fusion_exprs_.end()) ? exprs_it->second.size() : 0; + int64_t current_fusion_vals = + (vals_it != per_fusion_vals_.end()) ? vals_it->second.size() : 0; + + // Calculate how many statements to remove from this Fusion + int64_t exprs_to_remove = current_fusion_exprs - prev_num_exprs; + int64_t vals_to_remove = current_fusion_vals - prev_num_vals; + + if (exprs_to_remove <= 0 && vals_to_remove <= 0) { + return; // Nothing to remove + } + + // Remove expressions owned by this Fusion (from back of deque) + // We iterate backwards and remove only those owned by this Fusion + int64_t exprs_removed = 0; + // Use index-based iteration to avoid iterator invalidation issues + for (int64_t i = static_cast(exprs_up_.size()) - 1; + i >= 0 && exprs_removed < exprs_to_remove; + --i) { + Expr* e = exprs_up_[i].get(); + if (e->container() == fusion) { + // Clean up use-def chains + for (Val* in : e->inputs()) { + in->removeUse(e); + } + // Remove from tracking sets + exprs_.erase(e); + if (exprs_it != per_fusion_exprs_.end()) { + exprs_it->second.erase(e); + } + // Erase from deque + exprs_up_.erase(exprs_up_.begin() + i); + exprs_removed++; } - exprs_.erase(e); - exprs_up_.pop_back(); } - while (std::ssize(vals_up_) > prev_num_vals) { - vals_.erase(vals_up_.back().get()); - vals_up_.pop_back(); + // Remove vals owned by this Fusion (from back of deque) + int64_t vals_removed = 0; + for (int64_t i = static_cast(vals_up_.size()) - 1; + i >= 0 && vals_removed < vals_to_remove; + --i) { + Val* v = vals_up_[i].get(); + if (v->container() == fusion) { + // Remove from tracking sets + vals_.erase(v); + if (vals_it != per_fusion_vals_.end()) { + vals_it->second.erase(v); + } + // Erase from deque + vals_up_.erase(vals_up_.begin() + i); + vals_removed++; + } } } diff --git a/csrc/ir/container.h b/csrc/ir/container.h index ccdd25f9242..acbbaab8707 100644 --- a/csrc/ir/container.h +++ b/csrc/ir/container.h @@ -31,7 +31,7 @@ class NamedScalar; class IrContainer { public: - NVF_API IrContainer(); + IrContainer(); // Copy/Move Constructors and Operators are deleted. IrContainer is managed // through a smart pointer in IrContainer. Semantic operations for Fusion @@ -65,6 +65,27 @@ class IrContainer { const std::unordered_map deterministic_exprs_map() const noexcept; + // ========================================================================= + // Per-Fusion Deterministic Accessors (Phase 2) + // These return statements in insertion order, filtered by ownership. + // ========================================================================= + + //! Return values owned by a specific Fusion in insertion order + std::deque deterministicValsOwnedBy(Fusion* fusion) const noexcept; + + //! Return expressions owned by a specific Fusion in insertion order + std::deque deterministicExprsOwnedBy(Fusion* fusion) const noexcept; + + //! Return mapping from value to integer id for values owned by a Fusion + //! The integer ids are local to this Fusion's values (0, 1, 2, ...) + std::unordered_map deterministicValsMapOwnedBy( + Fusion* fusion) const noexcept; + + //! Return mapping from expression to integer id for exprs owned by a Fusion + //! The integer ids are local to this Fusion's exprs (0, 1, 2, ...) + std::unordered_map deterministicExprsMapOwnedBy( + Fusion* fusion) const noexcept; + //! Return the set of Exprs registered with this fusion. Warning: This will //! return exprs outside inputs/outputs, so can be unsafe for use with //! segmented fusions. @@ -119,7 +140,7 @@ class IrContainer { //! Get Vals owned by a specific Fusion //! Returns empty set if Fusion has no vals in this container - const std::unordered_set& valsOwnedBy(Fusion* fusion) const; + NVF_API const std::unordered_set& valsOwnedBy(Fusion* fusion) const; //! Get Exprs owned by a specific Fusion //! Returns empty set if Fusion has no exprs in this container @@ -192,7 +213,13 @@ class IrContainer { // itself. // // Used by StatementGuard only. + // + // Phase 2 Note: This method now takes a Fusion pointer to properly handle + // shared containers. It removes only statements owned by the specified + // Fusion that were created after the snapshot point, preserving statements + // owned by other Fusions sharing the container. void removeStatementsCreatedAfter( + Fusion* fusion, int64_t prev_num_exprs, int64_t prev_num_vals); diff --git a/tests/cpp/test_phase2_container_sharing.cpp b/tests/cpp/test_phase2_container_sharing.cpp index c6e43b88c30..0c0a86e0467 100644 --- a/tests/cpp/test_phase2_container_sharing.cpp +++ b/tests/cpp/test_phase2_container_sharing.cpp @@ -15,6 +15,7 @@ #include "fusion.h" #include "ir/container.h" #include "ops/all_ops.h" +#include "statement_guard.h" #include "tests/cpp/utils.h" namespace nvfuser { @@ -1111,4 +1112,475 @@ TEST_F(Phase2ContainerTest, MovePreservesInputsOutputs) { EXPECT_EQ(a.outputs().size(), 0); } +// ============================================================================= +// Deterministic Accessor Tests: Per-Fusion Filtering +// ============================================================================= + +TEST_F(Phase2ContainerTest, DeterministicValsReturnsOnlyOwned) { + // With a single Fusion, deterministic_vals() should return all vals + Fusion a; + FusionGuard fg_a(&a); + + auto* tv0 = makeSymbolicTensor(2); + a.addInput(tv0); + auto* tv1 = add(tv0, tv0); + a.addOutput(tv1); + + // deterministic_vals() should return same count as ownedVals() + auto det_vals = a.deterministic_vals(); + EXPECT_EQ(det_vals.size(), a.ownedVals().size()); + + // All vals in deterministic_vals should be owned by a + for (auto* val : det_vals) { + EXPECT_EQ(val->container(), &a); + EXPECT_TRUE(a.ownedVals().count(val) > 0); + } +} + +TEST_F(Phase2ContainerTest, DeterministicExprsReturnsOnlyOwned) { + // With a single Fusion, deterministic_exprs() should return all exprs + Fusion a; + FusionGuard fg_a(&a); + + auto* tv0 = makeSymbolicTensor(2); + a.addInput(tv0); + auto* tv1 = add(tv0, tv0); + a.addOutput(tv1); + + // deterministic_exprs() should return same count as ownedExprs() + auto det_exprs = a.deterministic_exprs(); + EXPECT_EQ(det_exprs.size(), a.ownedExprs().size()); + + // All exprs in deterministic_exprs should be owned by a + for (auto* expr : det_exprs) { + EXPECT_EQ(expr->container(), &a); + EXPECT_TRUE(a.ownedExprs().count(expr) > 0); + } +} + +TEST_F( + Phase2ContainerTest, + DeterministicValsFiltersByOwnershipInSharedContainer) { + // After copy, each Fusion's deterministic_vals() returns only ITS vals + Fusion a; + FusionGuard fg_a(&a); + + auto* tv0 = makeSymbolicTensor(2); + a.addInput(tv0); + auto* tv1 = add(tv0, tv0); + a.addOutput(tv1); + + Fusion b(a); // Copy - shares container + + // Both share the same container + EXPECT_EQ(a.ir_container_ptr().get(), b.ir_container_ptr().get()); + + // But each has its own deterministic vals + auto a_det_vals = a.deterministic_vals(); + auto b_det_vals = b.deterministic_vals(); + + // Sizes should match ownedVals + EXPECT_EQ(a_det_vals.size(), a.ownedVals().size()); + EXPECT_EQ(b_det_vals.size(), b.ownedVals().size()); + + // a's deterministic_vals should all be owned by a + for (auto* val : a_det_vals) { + EXPECT_EQ(val->container(), &a); + } + + // b's deterministic_vals should all be owned by b + for (auto* val : b_det_vals) { + EXPECT_EQ(val->container(), &b); + } + + // The sets should be disjoint + std::unordered_set a_set(a_det_vals.begin(), a_det_vals.end()); + for (auto* val : b_det_vals) { + EXPECT_EQ(a_set.count(val), 0); + } +} + +TEST_F( + Phase2ContainerTest, + DeterministicExprsFiltersByOwnershipInSharedContainer) { + // After copy, each Fusion's deterministic_exprs() returns only ITS exprs + Fusion a; + FusionGuard fg_a(&a); + + auto* tv0 = makeSymbolicTensor(2); + a.addInput(tv0); + auto* tv1 = add(tv0, tv0); + a.addOutput(tv1); + + Fusion b(a); // Copy - shares container + + // Both share the same container + EXPECT_EQ(a.ir_container_ptr().get(), b.ir_container_ptr().get()); + + // But each has its own deterministic exprs + auto a_det_exprs = a.deterministic_exprs(); + auto b_det_exprs = b.deterministic_exprs(); + + // Sizes should match ownedExprs + EXPECT_EQ(a_det_exprs.size(), a.ownedExprs().size()); + EXPECT_EQ(b_det_exprs.size(), b.ownedExprs().size()); + + // a's deterministic_exprs should all be owned by a + for (auto* expr : a_det_exprs) { + EXPECT_EQ(expr->container(), &a); + } + + // b's deterministic_exprs should all be owned by b + for (auto* expr : b_det_exprs) { + EXPECT_EQ(expr->container(), &b); + } + + // The sets should be disjoint + std::unordered_set a_set(a_det_exprs.begin(), a_det_exprs.end()); + for (auto* expr : b_det_exprs) { + EXPECT_EQ(a_set.count(expr), 0); + } +} + +TEST_F(Phase2ContainerTest, DeterministicValsMapFiltersByOwnership) { + // deterministic_vals_map should only include owned vals with local indices + Fusion a; + FusionGuard fg_a(&a); + + auto* tv0 = makeSymbolicTensor(2); + a.addInput(tv0); + auto* tv1 = add(tv0, tv0); + a.addOutput(tv1); + + Fusion b(a); // Copy - shares container + + auto a_map = a.deterministic_vals_map(); + auto b_map = b.deterministic_vals_map(); + + // Maps should have same size as ownedVals + EXPECT_EQ(a_map.size(), a.ownedVals().size()); + EXPECT_EQ(b_map.size(), b.ownedVals().size()); + + // All keys in a_map should be owned by a + for (const auto& [val, idx] : a_map) { + EXPECT_EQ(val->container(), &a); + } + + // All keys in b_map should be owned by b + for (const auto& [val, idx] : b_map) { + EXPECT_EQ(val->container(), &b); + } + + // Indices should be sequential starting from 0 (local to each Fusion) + std::vector a_indices, b_indices; + for (const auto& [val, idx] : a_map) { + a_indices.push_back(idx); + } + for (const auto& [val, idx] : b_map) { + b_indices.push_back(idx); + } + + std::sort(a_indices.begin(), a_indices.end()); + std::sort(b_indices.begin(), b_indices.end()); + + // Should be 0, 1, 2, ... for each + for (size_t i = 0; i < a_indices.size(); ++i) { + EXPECT_EQ(a_indices[i], static_cast(i)); + } + for (size_t i = 0; i < b_indices.size(); ++i) { + EXPECT_EQ(b_indices[i], static_cast(i)); + } +} + +TEST_F(Phase2ContainerTest, DeterministicExprsMapFiltersByOwnership) { + // deterministic_exprs_map should only include owned exprs with local indices + Fusion a; + FusionGuard fg_a(&a); + + auto* tv0 = makeSymbolicTensor(2); + a.addInput(tv0); + auto* tv1 = add(tv0, tv0); + a.addOutput(tv1); + + Fusion b(a); // Copy - shares container + + auto a_map = a.deterministic_exprs_map(); + auto b_map = b.deterministic_exprs_map(); + + // Maps should have same size as ownedExprs + EXPECT_EQ(a_map.size(), a.ownedExprs().size()); + EXPECT_EQ(b_map.size(), b.ownedExprs().size()); + + // All keys in a_map should be owned by a + for (const auto& [expr, idx] : a_map) { + EXPECT_EQ(expr->container(), &a); + } + + // All keys in b_map should be owned by b + for (const auto& [expr, idx] : b_map) { + EXPECT_EQ(expr->container(), &b); + } + + // Indices should be sequential starting from 0 + std::vector a_indices, b_indices; + for (const auto& [expr, idx] : a_map) { + a_indices.push_back(idx); + } + for (const auto& [expr, idx] : b_map) { + b_indices.push_back(idx); + } + + std::sort(a_indices.begin(), a_indices.end()); + std::sort(b_indices.begin(), b_indices.end()); + + for (size_t i = 0; i < a_indices.size(); ++i) { + EXPECT_EQ(a_indices[i], static_cast(i)); + } + for (size_t i = 0; i < b_indices.size(); ++i) { + EXPECT_EQ(b_indices[i], static_cast(i)); + } +} + +TEST_F(Phase2ContainerTest, DeterministicValsMaintainsInsertionOrder) { + // deterministic_vals should maintain insertion order + Fusion a; + FusionGuard fg_a(&a); + + // Create multiple tensors in specific order + auto* tv0 = makeSymbolicTensor(1); + a.addInput(tv0); + auto* tv1 = makeSymbolicTensor(2); + a.addInput(tv1); + auto* tv2 = add(tv0, tv0); + auto* tv3 = add(tv1, tv1); + a.addOutput(tv2); + a.addOutput(tv3); + + auto det_vals = a.deterministic_vals(); + auto det_map = a.deterministic_vals_map(); + + // Verify deque order matches map indices + for (size_t i = 0; i < det_vals.size(); ++i) { + Val* val = det_vals[i]; + EXPECT_EQ(det_map.at(val), static_cast(i)); + } +} + +TEST_F(Phase2ContainerTest, DeterministicExprsMaintainsInsertionOrder) { + // deterministic_exprs should maintain insertion order + Fusion a; + FusionGuard fg_a(&a); + + auto* tv0 = makeSymbolicTensor(2); + a.addInput(tv0); + auto* tv1 = add(tv0, tv0); + auto* tv2 = mul(tv1, tv0); + auto* tv3 = sub(tv2, tv1); + a.addOutput(tv3); + + auto det_exprs = a.deterministic_exprs(); + auto det_map = a.deterministic_exprs_map(); + + // Verify deque order matches map indices + for (size_t i = 0; i < det_exprs.size(); ++i) { + Expr* expr = det_exprs[i]; + EXPECT_EQ(det_map.at(expr), static_cast(i)); + } +} + +TEST_F(Phase2ContainerTest, DeterministicAccessorsAfterCopyPreservesOrder) { + // After copy, deterministic order for each Fusion should be correct + Fusion a; + FusionGuard fg_a(&a); + + auto* tv0 = makeSymbolicTensor(2); + a.addInput(tv0); + auto* tv1 = add(tv0, tv0); + auto* tv2 = mul(tv1, tv1); + a.addOutput(tv2); + + // Capture a's deterministic vals + auto a_det_vals_before = a.deterministic_vals(); + auto a_det_map_before = a.deterministic_vals_map(); + + Fusion b(a); // Copy + + // a's deterministic_vals should be unchanged + auto a_det_vals_after = a.deterministic_vals(); + EXPECT_EQ(a_det_vals_before.size(), a_det_vals_after.size()); + for (size_t i = 0; i < a_det_vals_before.size(); ++i) { + EXPECT_EQ(a_det_vals_before[i], a_det_vals_after[i]); + } + + // b's deterministic_vals should have same structure (but different objects) + auto b_det_vals = b.deterministic_vals(); + EXPECT_EQ(b_det_vals.size(), a_det_vals_before.size()); +} + +TEST_F(Phase2ContainerTest, DeterministicAccessorsAfterDestroyingCopy) { + // After destroying a copy, original's deterministic accessors still work + Fusion a; + FusionGuard fg_a(&a); + + auto* tv0 = makeSymbolicTensor(2); + a.addInput(tv0); + auto* tv1 = add(tv0, tv0); + a.addOutput(tv1); + + auto a_det_vals_before = a.deterministic_vals(); + auto a_det_map_before = a.deterministic_vals_map(); + + { + Fusion b(a); // Copy + // b destroyed here + } + + // a's deterministic accessors should still work correctly + auto a_det_vals_after = a.deterministic_vals(); + auto a_det_map_after = a.deterministic_vals_map(); + + EXPECT_EQ(a_det_vals_before.size(), a_det_vals_after.size()); + EXPECT_EQ(a_det_map_before.size(), a_det_map_after.size()); + + // Same values, same order + for (size_t i = 0; i < a_det_vals_before.size(); ++i) { + EXPECT_EQ(a_det_vals_before[i], a_det_vals_after[i]); + } +} + +TEST_F(Phase2ContainerTest, DeterministicValsEmptyForNewFusion) { + // New empty Fusion should have empty deterministic vals + Fusion a; + + auto det_vals = a.deterministic_vals(); + auto det_exprs = a.deterministic_exprs(); + auto det_vals_map = a.deterministic_vals_map(); + auto det_exprs_map = a.deterministic_exprs_map(); + + EXPECT_EQ(det_vals.size(), 0); + EXPECT_EQ(det_exprs.size(), 0); + EXPECT_EQ(det_vals_map.size(), 0); + EXPECT_EQ(det_exprs_map.size(), 0); +} + +TEST_F(Phase2ContainerTest, DeterministicValsAfterClear) { + // After clear, deterministic vals should be empty + Fusion a; + FusionGuard fg_a(&a); + + auto* tv0 = makeSymbolicTensor(2); + a.addInput(tv0); + auto* tv1 = add(tv0, tv0); + a.addOutput(tv1); + + EXPECT_GT(a.deterministic_vals().size(), 0); + EXPECT_GT(a.deterministic_exprs().size(), 0); + + a.clear(); + + EXPECT_EQ(a.deterministic_vals().size(), 0); + EXPECT_EQ(a.deterministic_exprs().size(), 0); + EXPECT_EQ(a.deterministic_vals_map().size(), 0); + EXPECT_EQ(a.deterministic_exprs_map().size(), 0); +} + +// ============================================================================= +// StatementGuard Tests with Shared Containers +// ============================================================================= + +TEST_F(Phase2ContainerTest, StatementGuardWithSharedContainer) { + // Test that StatementGuard works correctly with shared containers + // Bug: StatementGuard uses per-Fusion counts but removes from container + // deques + Fusion a; + FusionGuard fg_a(&a); + + auto* tv0 = makeSymbolicTensor(2); + a.addInput(tv0); + auto* tv1 = add(tv0, tv0); + a.addOutput(tv1); + + Fusion b(a); // Copy - shares container + + // Capture b's state before StatementGuard on a + size_t b_vals_before = b.ownedVals().size(); + size_t b_exprs_before = b.ownedExprs().size(); + std::vector b_vals(b.ownedVals().begin(), b.ownedVals().end()); + + { + FusionGuard fg_inner(&a); + StatementGuard sg(&a); + + // Create temporary vals in a + auto* temp = add(tv1, tv1); + (void)temp; + + // a has more vals now + EXPECT_GT(a.ownedVals().size(), b_vals_before); + } + // StatementGuard destructor should only remove a's new vals, not b's + + // b should be completely unaffected + EXPECT_EQ(b.ownedVals().size(), b_vals_before); + EXPECT_EQ(b.ownedExprs().size(), b_exprs_before); + + // b's vals should still have correct container + for (auto* val : b_vals) { + EXPECT_EQ(val->container(), &b); + EXPECT_TRUE(b.ownedVals().count(val) > 0); + } +} + +TEST_F(Phase2ContainerTest, StatementGuardDoesNotAffectOtherFusion) { + // StatementGuard on one Fusion should not affect another sharing the + // container + Fusion a; + FusionGuard fg_a(&a); + + auto* tv0 = makeSymbolicTensor(2); + a.addInput(tv0); + auto* tv1 = add(tv0, tv0); + a.addOutput(tv1); + + size_t a_vals_before_copy = a.ownedVals().size(); + + Fusion b(a); // Copy - shares container + + // Both should have same number of vals (cloned) + EXPECT_EQ(a_vals_before_copy, b.ownedVals().size()); + + size_t b_vals_at_guard_start = 0; + size_t b_vals_in_guard = 0; + + // Use StatementGuard on b to create and remove temp statements + { + FusionGuard fg_b(&b); + StatementGuard sg(&b); + + // Note: StatementGuard constructor calls axioms() which may create + // additional vals. The snapshot is taken AFTER axioms initialization. + b_vals_at_guard_start = b.ownedVals().size(); + + // Create temp vals in b + auto* b_input = b.inputs()[0]->as(); + auto* temp = mul(b_input, b_input); + (void)temp; + + b_vals_in_guard = b.ownedVals().size(); + + // b should have more vals now (from the temp operation) + EXPECT_GT(b_vals_in_guard, b_vals_at_guard_start); + } + + size_t a_vals_after = a.ownedVals().size(); + size_t b_vals_after = b.ownedVals().size(); + + // After guard, a should be unchanged + EXPECT_EQ(a_vals_after, a_vals_before_copy); + + // b should be back to its state at guard construction time + // (which includes axioms but not the temp vals created inside the guard) + EXPECT_EQ(b_vals_after, b_vals_at_guard_start); +} + } // namespace nvfuser From 56cf217cb81390b829e518118fd990d35d11e174 Mon Sep 17 00:00:00 2001 From: Michael Davis Date: Tue, 10 Feb 2026 12:47:38 -0800 Subject: [PATCH 9/9] Per-Fusion name counters for shared container name correspondence Replace global IrContainer name counters with per-Fusion counters so cloned Fusions produce matching statement names (T0=T0, T1=T1) instead of incrementing names (T0=T10). This fixes cross-fusion name lookups in GreedyParams and normalization_utils which use tv->name() as map keys. Changes: - Add per_fusion_val_name_map_ and per_fusion_expr_name_counter_ to IrContainer - Update getValName/getExprName to use per-Fusion counter with global fallback - Update registerVal/registerExpr to pass owning Fusion to name generators - Handle counter lifecycle in swap, copy, clear, destroy, transferOwnership - Use deterministic_vals() in Fusion::copy for stable clone ordering - Add 8 new tests for name correspondence (71/71 Phase 2 tests pass) --- csrc/fusion.cpp | 9 +- csrc/fusion.h | 8 +- csrc/ir/container.cpp | 54 ++++- csrc/ir/container.h | 46 +++- tests/cpp/test_phase2_container_sharing.cpp | 233 ++++++++++++++++++++ 5 files changed, 334 insertions(+), 16 deletions(-) diff --git a/csrc/fusion.cpp b/csrc/fusion.cpp index b959464dfe7..291cfbf70fb 100644 --- a/csrc/fusion.cpp +++ b/csrc/fusion.cpp @@ -218,13 +218,16 @@ IrCloner Fusion::copy(const Fusion* from, Fusion* to) { IrCloner ir_cloner(to); // Phase 2: Clone only 'from's owned vals (not all vals in shared container) - // Use ownedVals() to get only vals belonging to 'from' - for (auto val : from->ownedVals()) { + // CRITICAL: Use deterministic_vals() to get vals in insertion order. + // Using ownedVals() (unordered_set) causes non-deterministic clone order, + // which assigns different name() values to cloned vals between runs. + // This breaks code that uses tv->name() as map keys (e.g., GreedyParams). + for (auto val : from->deterministic_vals()) { ir_cloner.clone(val); } // Update definition_ and uses_ on cloned vals - for (auto val : from->ownedVals()) { + for (auto val : from->deterministic_vals()) { ir_cloner.clone(val)->setDefinition(ir_cloner.clone(val->definition_)); ir_cloner.clone(val)->setUses(ir_cloner.clone(val->uses_)); } diff --git a/csrc/fusion.h b/csrc/fusion.h index fc00d8564a7..886191607c6 100644 --- a/csrc/fusion.h +++ b/csrc/fusion.h @@ -744,7 +744,13 @@ T* IrBuilder::clone(const T* src, IrCloner* ir_cloner) { dest_container->registerStmt(IrBuilderPasskey(dest_container), dest_stmt); - if (src_container != dest_container) { + // Phase 2 Task 10: For same-container cloning (shared IrContainer), + // per-Fusion name counters produce matching names naturally (both start + // at 0), so the name override below is NOT needed and is skipped. + // For cross-container cloning (different IrContainers), we still need + // to force the source name since the destination's global counter may + // have diverged. + if (src_container->ir_container() != dest_container->ir_container()) { dest_stmt->setName(IrBuilderPasskey(dest_container), src_stmt->name()); } diff --git a/csrc/ir/container.cpp b/csrc/ir/container.cpp index 655b4aaf4b2..2ddf9b9d2d5 100644 --- a/csrc/ir/container.cpp +++ b/csrc/ir/container.cpp @@ -268,6 +268,24 @@ void IrContainer::transferStatementOwnership(Fusion* from, Fusion* to) { to_exprs.insert(exprs_it->second.begin(), exprs_it->second.end()); per_fusion_exprs_.erase(exprs_it); } + + // Transfer per-Fusion name counters (Phase 2 Task 10) + auto val_names_it = per_fusion_val_name_map_.find(from); + if (val_names_it != per_fusion_val_name_map_.end()) { + // Merge counter maps: take max of each ValType counter + auto& to_map = per_fusion_val_name_map_[to]; + for (auto& [vtype, counter] : val_names_it->second) { + to_map[vtype] = std::max(to_map[vtype], counter); + } + per_fusion_val_name_map_.erase(val_names_it); + } + + auto expr_names_it = per_fusion_expr_name_counter_.find(from); + if (expr_names_it != per_fusion_expr_name_counter_.end()) { + auto& to_counter = per_fusion_expr_name_counter_[to]; + to_counter = std::max(to_counter, expr_names_it->second); + per_fusion_expr_name_counter_.erase(expr_names_it); + } } void IrContainer::removeStatementsOwnedBy(Fusion* fusion) { @@ -303,6 +321,10 @@ void IrContainer::removeStatementsOwnedByUnlocked(Fusion* fusion) { // Clean up per-Fusion tracking (Phase 2 Task 4) per_fusion_vals_.erase(fusion); per_fusion_exprs_.erase(fusion); + + // Clean up per-Fusion name counters (Phase 2 Task 10) + per_fusion_val_name_map_.erase(fusion); + per_fusion_expr_name_counter_.erase(fusion); } void IrContainer::swap(IrContainer& a, IrContainer& b) noexcept { @@ -334,6 +356,10 @@ void IrContainer::swap(IrContainer& a, IrContainer& b) noexcept { std::swap(a.sharing_fusions_, b.sharing_fusions_); std::swap(a.per_fusion_vals_, b.per_fusion_vals_); std::swap(a.per_fusion_exprs_, b.per_fusion_exprs_); + + // Swap per-Fusion name counters (Phase 2 Task 10) + std::swap(a.per_fusion_val_name_map_, b.per_fusion_val_name_map_); + std::swap(a.per_fusion_expr_name_counter_, b.per_fusion_expr_name_counter_); } IrCloner IrContainer::copy(const IrContainer* from, IrContainer* to) { @@ -354,6 +380,8 @@ IrCloner IrContainer::copy(const IrContainer* from, IrContainer* to) { to->expr_name_counter_ = 0; to->per_fusion_vals_.clear(); to->per_fusion_exprs_.clear(); + to->per_fusion_val_name_map_.clear(); + to->per_fusion_expr_name_counter_.clear(); // NOTE: In Phase 2, we can't use to->parent() here because parent_ might // not be set correctly for shared containers. Fusion::copy handles this. @@ -468,12 +496,16 @@ void IrContainer::registerVal(Val* val) { // Otherwise handle registration locally vals_up_.emplace_back(val); vals_.insert(val); - val->setName(IrContainerPasskey(), getValName(val->vtype())); + + // Phase 2 Task 10: Use per-Fusion counter if val has an owning Fusion. + // This ensures cloned Fusions get matching names (T0=T0, T1=T1) + // instead of incrementing global names (T0=T10, T1=T11). + Fusion* owning_fusion = val->container(); + val->setName(IrContainerPasskey(), getValName(owning_fusion, val->vtype())); // Track per-Fusion ownership (Phase 2 Task 4) - // val->container() returns the owning Fusion - if (val->container() != nullptr) { - per_fusion_vals_[val->container()].insert(val); + if (owning_fusion != nullptr) { + per_fusion_vals_[owning_fusion].insert(val); } } @@ -486,12 +518,14 @@ void IrContainer::registerExpr(Expr* expr) { // Otherwise handle registration locally exprs_up_.emplace_back(expr); exprs_.insert(expr); - expr->setName(IrContainerPasskey(), getExprName()); + + // Phase 2 Task 10: Use per-Fusion counter if expr has an owning Fusion. + Fusion* owning_fusion = expr->container(); + expr->setName(IrContainerPasskey(), getExprName(owning_fusion)); // Track per-Fusion ownership (Phase 2 Task 4) - // expr->container() returns the owning Fusion - if (expr->container() != nullptr) { - per_fusion_exprs_[expr->container()].insert(expr); + if (owning_fusion != nullptr) { + per_fusion_exprs_[owning_fusion].insert(expr); } } @@ -507,6 +541,10 @@ void IrContainer::clear() noexcept { // Clear per-Fusion tracking (Phase 2 Task 4) per_fusion_vals_.clear(); per_fusion_exprs_.clear(); + + // Clear per-Fusion name counters (Phase 2 Task 10) + per_fusion_val_name_map_.clear(); + per_fusion_expr_name_counter_.clear(); } bool IrContainer::inContainer(const Statement* const_stmt) const { diff --git a/csrc/ir/container.h b/csrc/ir/container.h index acbbaab8707..6a5666994ea 100644 --- a/csrc/ir/container.h +++ b/csrc/ir/container.h @@ -192,14 +192,43 @@ class IrContainer { //! Register expr with this container. NVF_API void registerExpr(Expr* expr); - StmtNameType getValName(ValType vtype) { + //! Get next val name, using per-Fusion counter if fusion is non-null, + //! falling back to global counter otherwise. + //! Per-Fusion counters ensure cloned Fusions produce matching names. + StmtNameType getValName(Fusion* fusion, ValType vtype) { + if (fusion != nullptr) { + auto& name_map = per_fusion_val_name_map_[fusion]; + if (name_map.find(vtype) == name_map.end()) { + name_map[vtype] = 0; + } + // Also advance global counter to keep it >= all per-Fusion counters + // This prevents conflicts if global counter is used later + auto& global = val_type_name_map_[vtype]; + auto per_fusion_name = name_map[vtype]++; + if (global <= per_fusion_name) { + global = per_fusion_name + 1; + } + return per_fusion_name; + } + // Global fallback for non-Fusion contexts if (val_type_name_map_.find(vtype) == val_type_name_map_.end()) { val_type_name_map_[vtype] = 0; } return val_type_name_map_[vtype]++; } - StmtNameType getExprName() { + //! Get next expr name, using per-Fusion counter if fusion is non-null, + //! falling back to global counter otherwise. + StmtNameType getExprName(Fusion* fusion) { + if (fusion != nullptr) { + auto& counter = per_fusion_expr_name_counter_[fusion]; + auto per_fusion_name = counter++; + // Also advance global counter + if (expr_name_counter_ <= per_fusion_name) { + expr_name_counter_ = per_fusion_name + 1; + } + return per_fusion_name; + } return expr_name_counter_++; } @@ -237,12 +266,21 @@ class IrContainer { // something like check if an Expr is in this container std::unordered_set exprs_; - // Values names counters + // Values names counters (global fallback for non-Fusion contexts) std::unordered_map val_type_name_map_; - // Expression names counter + // Expression names counter (global fallback for non-Fusion contexts) StmtNameType expr_name_counter_ = 0; + // Per-Fusion name counters (Phase 2 Task 10) + // Each Fusion gets its own counter starting at 0, so cloned Fusions + // produce matching names (T0=T0, T1=T1) instead of incrementing names. + // This is critical for GreedyParams and normalization_utils which use + // tv->name() as map keys across cloned Fusions. + std::unordered_map> + per_fusion_val_name_map_; + std::unordered_map per_fusion_expr_name_counter_; + // Note: Special values (zero_val_, one_val_, true_val_, false_val_, // magic_zero_val_) are now per-Fusion, stored in Fusion class. // This avoids ownership conflicts when multiple Fusions share an IrContainer. diff --git a/tests/cpp/test_phase2_container_sharing.cpp b/tests/cpp/test_phase2_container_sharing.cpp index 0c0a86e0467..4261ec6e501 100644 --- a/tests/cpp/test_phase2_container_sharing.cpp +++ b/tests/cpp/test_phase2_container_sharing.cpp @@ -1583,4 +1583,237 @@ TEST_F(Phase2ContainerTest, StatementGuardDoesNotAffectOtherFusion) { EXPECT_EQ(b_vals_after, b_vals_at_guard_start); } +// ============================================================================= +// Task 10 Tests: Per-Fusion Name Counters +// ============================================================================= + +TEST_F(Phase2ContainerTest, PerFusionNameCountersBasic) { + // Each Fusion gets its own name counters starting at 0 + Fusion a; + FusionGuard fg_a(&a); + + auto* tv0 = makeSymbolicTensor(2); + a.addInput(tv0); + auto* tv1 = add(tv0, tv0); + a.addOutput(tv1); + + // First TensorView in Fusion a should have name 0 + EXPECT_EQ(tv0->name(), 0); +} + +TEST_F(Phase2ContainerTest, IndependentFusionsHaveOwnCounters) { + // Two independent Fusions both start name counters at 0 + Fusion a; + { + FusionGuard fg_a(&a); + auto* tv0_a = makeSymbolicTensor(2); + a.addInput(tv0_a); + auto* tv1_a = add(tv0_a, tv0_a); + a.addOutput(tv1_a); + // a's TensorViews should have names 0, 1 + EXPECT_EQ(tv0_a->name(), 0); + EXPECT_EQ(tv1_a->name(), 1); + } + + Fusion b; + { + FusionGuard fg_b(&b); + auto* tv0_b = makeSymbolicTensor(2); + b.addInput(tv0_b); + auto* tv1_b = add(tv0_b, tv0_b); + b.addOutput(tv1_b); + // b's TensorViews should ALSO have names 0, 1 (independent counter) + EXPECT_EQ(tv0_b->name(), 0); + EXPECT_EQ(tv1_b->name(), 1); + } +} + +TEST_F(Phase2ContainerTest, CopyNameCorrespondence) { + // CRITICAL: After Fusion::copy into shared container, cloned vals have + // matching names. This is required by GreedyParams and normalization_utils. + Fusion a; + FusionGuard fg_a(&a); + + auto* tv0 = makeSymbolicTensor(2); + a.addInput(tv0); + auto* tv1 = add(tv0, tv0); + auto* tv2 = mul(tv1, tv1); + a.addOutput(tv2); + + // Record a's val names + std::vector> a_val_names; + for (auto* val : a.deterministic_vals()) { + a_val_names.push_back({val, val->name()}); + } + + // Copy a -> b (shared container) + Fusion b(a); + + // b's vals should have MATCHING names (not incremented) + auto b_vals = b.deterministic_vals(); + EXPECT_EQ(b_vals.size(), a_val_names.size()); + + // Check TensorViews specifically - these are what GreedyParams uses + std::vector a_tvs, b_tvs; + for (auto* val : a.deterministic_vals()) { + if (val->isA()) { + a_tvs.push_back(val); + } + } + for (auto* val : b.deterministic_vals()) { + if (val->isA()) { + b_tvs.push_back(val); + } + } + + EXPECT_EQ(a_tvs.size(), b_tvs.size()); + for (size_t i = 0; i < a_tvs.size(); ++i) { + // Names should match across original and clone + EXPECT_EQ(a_tvs[i]->name(), b_tvs[i]->name()) + << "TV name mismatch at index " << i << ": a=" << a_tvs[i]->name() + << " b=" << b_tvs[i]->name(); + } +} + +TEST_F(Phase2ContainerTest, CopyExprNameCorrespondence) { + // After copy, cloned expressions should also have matching names + Fusion a; + FusionGuard fg_a(&a); + + auto* tv0 = makeSymbolicTensor(2); + a.addInput(tv0); + auto* tv1 = add(tv0, tv0); + auto* tv2 = mul(tv1, tv1); + a.addOutput(tv2); + + // Record a's expr names + auto a_exprs = a.deterministic_exprs(); + std::vector a_expr_names; + for (auto* expr : a_exprs) { + a_expr_names.push_back(expr->name()); + } + + // Copy + Fusion b(a); + + auto b_exprs = b.deterministic_exprs(); + EXPECT_EQ(b_exprs.size(), a_expr_names.size()); + + for (size_t i = 0; i < a_expr_names.size(); ++i) { + EXPECT_EQ(b_exprs[i]->name(), a_expr_names[i]) + << "Expr name mismatch at index " << i; + } +} + +TEST_F(Phase2ContainerTest, MultipleCopiesHaveMatchingNames) { + // Multiple copies from the same source should all have matching names + Fusion a; + FusionGuard fg_a(&a); + + auto* tv0 = makeSymbolicTensor(2); + a.addInput(tv0); + auto* tv1 = add(tv0, tv0); + a.addOutput(tv1); + + Fusion b(a); + Fusion c(a); + + // b and c should both have matching TV names + auto a_tvs_det = a.deterministic_vals(); + auto b_tvs_det = b.deterministic_vals(); + auto c_tvs_det = c.deterministic_vals(); + + EXPECT_EQ(a_tvs_det.size(), b_tvs_det.size()); + EXPECT_EQ(a_tvs_det.size(), c_tvs_det.size()); + + for (size_t i = 0; i < a_tvs_det.size(); ++i) { + EXPECT_EQ(a_tvs_det[i]->name(), b_tvs_det[i]->name()); + EXPECT_EQ(a_tvs_det[i]->name(), c_tvs_det[i]->name()); + } +} + +TEST_F(Phase2ContainerTest, NameCountersCleanedUpOnDestroy) { + // When a Fusion is destroyed, its per-Fusion counters should be cleaned up + Fusion a; + FusionGuard fg_a(&a); + + auto* tv0 = makeSymbolicTensor(2); + a.addInput(tv0); + auto* tv1 = add(tv0, tv0); + a.addOutput(tv1); + + auto container_ptr = a.ir_container_ptr(); + + { + Fusion b(a); // Copy - shares container, creates per-Fusion counters + EXPECT_EQ(container_ptr->sharingCount(), 2); + } + // b destroyed: per-Fusion counters for b should be cleaned up + + EXPECT_EQ(container_ptr->sharingCount(), 1); + // a should still work fine + EXPECT_GT(a.ownedVals().size(), 0); +} + +TEST_F(Phase2ContainerTest, NameCountersSurviveSwap) { + // After swap, name counters should follow the data + Fusion a; + { + FusionGuard fg_a(&a); + auto* tv0 = makeSymbolicTensor(2); + a.addInput(tv0); + auto* tv1 = add(tv0, tv0); + a.addOutput(tv1); + } + + Fusion b; + { + FusionGuard fg_b(&b); + auto* tv0 = makeSymbolicTensor(3); + b.addInput(tv0); + auto* tv1 = mul(tv0, tv0); + b.addOutput(tv1); + } + + // Get TV names before swap + auto a_tv0_name = a.inputs()[0]->name(); + auto b_tv0_name = b.inputs()[0]->name(); + + Fusion::swap(a, b); + + // After swap, a has b's old data and vice versa + EXPECT_EQ(a.inputs()[0]->name(), b_tv0_name); + EXPECT_EQ(b.inputs()[0]->name(), a_tv0_name); + + // Adding new TVs after swap should work with correct counters + { + FusionGuard fg_a(&a); + auto* new_tv = + add(a.inputs()[0]->as(), a.inputs()[0]->as()); + // New TV should get a valid name (not crash) + EXPECT_GE(new_tv->name(), 0); + } +} + +TEST_F(Phase2ContainerTest, NameCountersAfterClearAndRebuild) { + // After Fusion::clear(), name counters should reset so new vals start at 0 + Fusion a; + FusionGuard fg_a(&a); + + auto* tv0 = makeSymbolicTensor(2); + a.addInput(tv0); + auto* tv1 = add(tv0, tv0); + a.addOutput(tv1); + + EXPECT_EQ(tv0->name(), 0); + EXPECT_EQ(tv1->name(), 1); + + a.clear(); + + // After clear, new vals should start at 0 again + auto* tv0_new = makeSymbolicTensor(2); + a.addInput(tv0_new); + EXPECT_EQ(tv0_new->name(), 0); +} + } // namespace nvfuser