From ed4b5e6429904b388a42f7749707e98b392ba350 Mon Sep 17 00:00:00 2001 From: Michael Davis Date: Thu, 12 Feb 2026 15:29:59 -0800 Subject: [PATCH] Copy/move/swap semantics for shared containers Copy constructor now shares the source's container pointer instead of creating a new one. Fusion::copy clones directly from per-Fusion filtered vals rather than delegating to IrContainer::copy. Swap changed from content-based (IrContainer::swap) to pointer-based with per-Fusion ownership tracking for both same-container and different-container cases. --- csrc/fusion.cpp | 120 +++++++++++++++++++++++++++++++++--------------- 1 file changed, 84 insertions(+), 36 deletions(-) diff --git a/csrc/fusion.cpp b/csrc/fusion.cpp index 8f76a1b57d0..df1e8b03b68 100644 --- a/csrc/fusion.cpp +++ b/csrc/fusion.cpp @@ -107,46 +107,88 @@ bool Fusion::sameDefinition(const Fusion& other) const { void Fusion::swap(Fusion& a, Fusion& b) noexcept { FUSER_PERF_SCOPE("Fusion swap"); - // We need to be careful to call IrContainer swap not unique_ptr swap, which - // will only swap the ptrs NOT the contents. - IrContainer::swap(*(a.ir_container()), *(b.ir_container())); + if (&a == &b) { + return; + } - // After swapping container contents, per-Fusion tracking keys point to the - // wrong Fusions. Rename: a's container had b's entries, b's had a's. - a.ir_container()->transferStatementOwnership(&b, &a); - b.ir_container()->transferStatementOwnership(&a, &b); + // Collect statements owned by each Fusion BEFORE swap + std::vector a_owned_vals, b_owned_vals; + std::vector a_owned_exprs, b_owned_exprs; if (a.ir_container_) { - for (auto val : a.vals()) { - val->ir_container_ = &a; - } - for (auto expr : a.deterministic_exprs()) { - expr->ir_container_ = &a; - } + const auto& av = a.ir_container_->valsOwnedBy(&a); + const auto& ae = a.ir_container_->exprsOwnedBy(&a); + a_owned_vals.assign(av.begin(), av.end()); + a_owned_exprs.assign(ae.begin(), ae.end()); } if (b.ir_container_) { - for (auto val : b.vals()) { - val->ir_container_ = &b; - } - for (auto expr : b.deterministic_exprs()) { - expr->ir_container_ = &b; - } + const auto& bv = b.ir_container_->valsOwnedBy(&b); + const auto& be = b.ir_container_->exprsOwnedBy(&b); + b_owned_vals.assign(bv.begin(), bv.end()); + b_owned_exprs.assign(be.begin(), be.end()); + } + + // Transfer Fusion registrations between containers before pointer swap. + // After swap, a will own b's container and b will own a's container. + if (a.ir_container_ && b.ir_container_ && + a.ir_container_.get() != b.ir_container_.get()) { + a.ir_container_->transferFusion(&a, &b); + b.ir_container_->transferFusion(&b, &a); } + // Swap container pointers + std::swap(a.ir_container_, b.ir_container_); + + // Swap all Fusion-level members std::swap(a.inputs_, b.inputs_); std::swap(a.outputs_, b.outputs_); - std::swap(a.io_alias_, b.io_alias_); - - // Swap per-Fusion special values (Phase 2) + std::swap(a.all_tv_uses_valid_, b.all_tv_uses_valid_); + std::swap(a.is_during_update_uses_, b.is_during_update_uses_); + std::swap(a.managed_data_, b.managed_data_); + std::swap(a.managed_named_data_, b.managed_named_data_); + std::swap(a.expected_dynamic_smem_bytes_, b.expected_dynamic_smem_bytes_); + std::swap(a.all_tvs_ptr_, b.all_tvs_ptr_); std::swap(a.zero_val_, b.zero_val_); std::swap(a.one_val_, b.one_val_); std::swap(a.true_val_, b.true_val_); std::swap(a.false_val_, b.false_val_); std::swap(a.magic_zero_val_, b.magic_zero_val_); - std::swap(a.axioms_, b.axioms_); std::swap(a.metadata_, b.metadata_); + + // Update Statement::ir_container_ pointers: a's old statements now belong + // to b, and b's old statements now belong to a + for (auto* val : a_owned_vals) { + val->ir_container_ = &b; + } + for (auto* expr : a_owned_exprs) { + expr->ir_container_ = &b; + } + for (auto* val : b_owned_vals) { + val->ir_container_ = &a; + } + for (auto* expr : b_owned_exprs) { + expr->ir_container_ = &a; + } + + // Update per-Fusion tracking keys in containers + if (a.ir_container_ && b.ir_container_) { + if (a.ir_container_.get() == b.ir_container_.get()) { + // Same container: directly swap per-Fusion tracking entries + auto* c = a.ir_container_.get(); + std::swap(c->per_fusion_vals_[&a], c->per_fusion_vals_[&b]); + std::swap(c->per_fusion_exprs_[&a], c->per_fusion_exprs_[&b]); + } else { + // Different containers: rename tracking keys to match new owners + a.ir_container_->transferStatementOwnership(&b, &a); + b.ir_container_->transferStatementOwnership(&a, &b); + } + } else if (a.ir_container_) { + a.ir_container_->transferStatementOwnership(&b, &a); + } else if (b.ir_container_) { + b.ir_container_->transferStatementOwnership(&a, &b); + } } std::unique_ptr Fusion::segment( @@ -158,10 +200,20 @@ std::unique_ptr Fusion::segment( IrCloner Fusion::copy(const Fusion* from, Fusion* to) { to->clear(); - auto ir_cloner = - IrContainer::copy(from->ir_container(), to->ir_container(), to); + IrCloner ir_cloner(to); + + // Clone from's vals in insertion order + for (auto val : from->deterministic_vals()) { + ir_cloner.clone(val); + } - // Remap cached special val pointers through the cloner + // Wire up definitions and uses on cloned vals + for (auto val : from->vals()) { + ir_cloner.clone(val)->setDefinition(ir_cloner.clone(val->definition_)); + ir_cloner.clone(val)->setUses(ir_cloner.clone(val->uses_)); + } + + // Remap cached special val pointers if (from->zero_val_) { to->zero_val_ = ir_cloner.clone(from->zero_val_); } @@ -179,11 +231,6 @@ IrCloner Fusion::copy(const Fusion* from, Fusion* to) { ir_cloner.clone(from->magic_zero_val_)->as(); } - for (auto val : from->vals()) { - ir_cloner.clone(val)->setDefinition(ir_cloner.clone(val->definition_)); - ir_cloner.clone(val)->setUses(ir_cloner.clone(val->uses_)); - } - to->inputs_ = ir_cloner.clone(from->inputs_); to->outputs_ = ir_cloner.clone(from->outputs_); for (auto inp : to->inputs_) { @@ -193,7 +240,6 @@ IrCloner Fusion::copy(const Fusion* from, Fusion* to) { out->setIsFusionOutput(true); } - // TODO: put this into ir_cloner instead for (Val* out : from->outputs_) { const AliasInfo& alias = from->io_alias_.get(out); if (alias.type == AllocationType::New) { @@ -206,14 +252,12 @@ IrCloner Fusion::copy(const Fusion* from, Fusion* to) { } to->all_tv_uses_valid_ = from->all_tv_uses_valid_; - // This should never be true on copy, but copying for completeness. to->is_during_update_uses_ = from->is_during_update_uses_; for (const auto& i : from->managed_data_) { if (i.first.has_value()) { to->managed_data_.emplace_back(i.second(ir_cloner, i.first), i.second); } else { - // Don't clone managed data if it has been reset to->managed_data_.emplace_back(i.first, i.second); } } @@ -256,9 +300,10 @@ Fusion::Fusion() : ir_container_(std::make_shared()) { ir_container_->addFusion(this); } -// Copy constructor -Fusion::Fusion(const Fusion& other) : Fusion() { +// Copy constructor -- shares the source's container +Fusion::Fusion(const Fusion& other) : ir_container_(other.ir_container_) { FUSER_PERF_SCOPE("Fusion copy"); + ir_container_->addFusion(this); Fusion::copy(&other, this); } @@ -278,6 +323,9 @@ Fusion& Fusion::operator=(const Fusion& other) { Fusion& Fusion::operator=(Fusion&& other) noexcept { FUSER_PERF_SCOPE("Fusion move assign"); + if (this == &other) { + return *this; + } clear(); swap(*this, other); return *this;