Skip to content
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
120 changes: 84 additions & 36 deletions csrc/fusion.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -107,46 +107,88 @@ bool Fusion::sameDefinition(const Fusion& other) const {
// Exchanges the complete state of Fusions `a` and `b`:
//   - the IrContainer pointers held by each Fusion,
//   - every Fusion-level member (inputs/outputs, aliases, cached special
//     values, managed data, axioms, metadata, ...),
//   - the ir_container_ back-pointer stored on each owned Val/Expr, and
//   - the per-Fusion tracking entries kept inside the containers.
// Swapping a Fusion with itself is a no-op.
//
// Only the container *pointers* are swapped; container contents are not
// swapped separately, since doing both would undo the exchange.
void Fusion::swap(Fusion& a, Fusion& b) noexcept {
  FUSER_PERF_SCOPE("Fusion swap");

  if (&a == &b) {
    return;
  }

  // Collect statements owned by each Fusion BEFORE swap. The lists are copied
  // into local vectors so they remain usable after the containers' tracking
  // entries are swapped/renamed further below.
  std::vector<Val*> a_owned_vals, b_owned_vals;
  std::vector<Expr*> a_owned_exprs, b_owned_exprs;

  if (a.ir_container_) {
    const auto& av = a.ir_container_->valsOwnedBy(&a);
    const auto& ae = a.ir_container_->exprsOwnedBy(&a);
    a_owned_vals.assign(av.begin(), av.end());
    a_owned_exprs.assign(ae.begin(), ae.end());
  }
  if (b.ir_container_) {
    const auto& bv = b.ir_container_->valsOwnedBy(&b);
    const auto& be = b.ir_container_->exprsOwnedBy(&b);
    b_owned_vals.assign(bv.begin(), bv.end());
    b_owned_exprs.assign(be.begin(), be.end());
  }

  // Transfer Fusion registrations between containers before pointer swap.
  // After swap, a will own b's container and b will own a's container.
  // Skipped when both Fusions share one container, since both are then
  // already registered with it.
  // NOTE(review): this only runs when BOTH containers are non-null, while the
  // one-sided branches at the end of this function still rename statement
  // ownership. Confirm whether the Fusion registration must also move when
  // exactly one side has a container.
  if (a.ir_container_ && b.ir_container_ &&
      a.ir_container_.get() != b.ir_container_.get()) {
    a.ir_container_->transferFusion(&a, &b);
    b.ir_container_->transferFusion(&b, &a);
  }

  // Swap container pointers
  std::swap(a.ir_container_, b.ir_container_);

  // Swap all Fusion-level members
  std::swap(a.inputs_, b.inputs_);
  std::swap(a.outputs_, b.outputs_);

  std::swap(a.io_alias_, b.io_alias_);

  // Swap per-Fusion special values (Phase 2)
  std::swap(a.all_tv_uses_valid_, b.all_tv_uses_valid_);
  std::swap(a.is_during_update_uses_, b.is_during_update_uses_);
  std::swap(a.managed_data_, b.managed_data_);
  std::swap(a.managed_named_data_, b.managed_named_data_);
  std::swap(a.expected_dynamic_smem_bytes_, b.expected_dynamic_smem_bytes_);
  std::swap(a.all_tvs_ptr_, b.all_tvs_ptr_);
  std::swap(a.zero_val_, b.zero_val_);
  std::swap(a.one_val_, b.one_val_);
  std::swap(a.true_val_, b.true_val_);
  std::swap(a.false_val_, b.false_val_);
  std::swap(a.magic_zero_val_, b.magic_zero_val_);

  std::swap(a.axioms_, b.axioms_);
  std::swap(a.metadata_, b.metadata_);

  // Update Statement::ir_container_ pointers: a's old statements now belong
  // to b, and b's old statements now belong to a
  for (auto* val : a_owned_vals) {
    val->ir_container_ = &b;
  }
  for (auto* expr : a_owned_exprs) {
    expr->ir_container_ = &b;
  }
  for (auto* val : b_owned_vals) {
    val->ir_container_ = &a;
  }
  for (auto* expr : b_owned_exprs) {
    expr->ir_container_ = &a;
  }

  // Update per-Fusion tracking keys in containers
  if (a.ir_container_ && b.ir_container_) {
    if (a.ir_container_.get() == b.ir_container_.get()) {
      // Same container: directly swap per-Fusion tracking entries
      auto* c = a.ir_container_.get();
      std::swap(c->per_fusion_vals_[&a], c->per_fusion_vals_[&b]);
      std::swap(c->per_fusion_exprs_[&a], c->per_fusion_exprs_[&b]);
    } else {
      // Different containers: rename tracking keys to match new owners.
      // a now holds b's former container (entries keyed by &b), and vice
      // versa.
      a.ir_container_->transferStatementOwnership(&b, &a);
      b.ir_container_->transferStatementOwnership(&a, &b);
    }
  } else if (a.ir_container_) {
    // Only a holds a container after the swap (it was b's): re-key b -> a.
    a.ir_container_->transferStatementOwnership(&b, &a);
  } else if (b.ir_container_) {
    // Only b holds a container after the swap (it was a's): re-key a -> b.
    b.ir_container_->transferStatementOwnership(&a, &b);
  }
}

std::unique_ptr<SegmentedFusion> Fusion::segment(
Expand All @@ -158,10 +200,20 @@ std::unique_ptr<SegmentedFusion> Fusion::segment(
IrCloner Fusion::copy(const Fusion* from, Fusion* to) {
to->clear();

auto ir_cloner =
IrContainer::copy(from->ir_container(), to->ir_container(), to);
IrCloner ir_cloner(to);

// Clone from's vals in insertion order
for (auto val : from->deterministic_vals()) {
ir_cloner.clone(val);
}

// Remap cached special val pointers through the cloner
// Wire up definitions and uses on cloned vals
for (auto val : from->vals()) {
ir_cloner.clone(val)->setDefinition(ir_cloner.clone(val->definition_));
ir_cloner.clone(val)->setUses(ir_cloner.clone(val->uses_));
}

// Remap cached special val pointers
if (from->zero_val_) {
to->zero_val_ = ir_cloner.clone(from->zero_val_);
}
Expand All @@ -179,11 +231,6 @@ IrCloner Fusion::copy(const Fusion* from, Fusion* to) {
ir_cloner.clone(from->magic_zero_val_)->as<NamedScalar>();
}

for (auto val : from->vals()) {
ir_cloner.clone(val)->setDefinition(ir_cloner.clone(val->definition_));
ir_cloner.clone(val)->setUses(ir_cloner.clone(val->uses_));
}

to->inputs_ = ir_cloner.clone(from->inputs_);
to->outputs_ = ir_cloner.clone(from->outputs_);
for (auto inp : to->inputs_) {
Expand All @@ -193,7 +240,6 @@ IrCloner Fusion::copy(const Fusion* from, Fusion* to) {
out->setIsFusionOutput(true);
}

// TODO: put this into ir_cloner instead
for (Val* out : from->outputs_) {
const AliasInfo& alias = from->io_alias_.get(out);
if (alias.type == AllocationType::New) {
Expand All @@ -206,14 +252,12 @@ IrCloner Fusion::copy(const Fusion* from, Fusion* to) {
}

to->all_tv_uses_valid_ = from->all_tv_uses_valid_;
// This should never be true on copy, but copying for completeness.
to->is_during_update_uses_ = from->is_during_update_uses_;

for (const auto& i : from->managed_data_) {
if (i.first.has_value()) {
to->managed_data_.emplace_back(i.second(ir_cloner, i.first), i.second);
} else {
// Don't clone managed data if it has been reset
to->managed_data_.emplace_back(i.first, i.second);
}
}
Expand Down Expand Up @@ -256,9 +300,10 @@ Fusion::Fusion() : ir_container_(std::make_shared<IrContainer>()) {
ir_container_->addFusion(this);
}

// Copy constructor -- shares the source's container.
// The new Fusion holds the same IrContainer pointer as `other`, registers
// itself with that container via addFusion(), and then populates its own
// state from `other` through Fusion::copy().
Fusion::Fusion(const Fusion& other) : ir_container_(other.ir_container_) {
  FUSER_PERF_SCOPE("Fusion copy");
  // Register with the shared container before copying so that the statements
  // created by Fusion::copy are made against a registered Fusion.
  // NOTE(review): presumably required by the container's bookkeeping; confirm.
  ir_container_->addFusion(this);
  Fusion::copy(&other, this);
}

Expand All @@ -278,6 +323,9 @@ Fusion& Fusion::operator=(const Fusion& other) {

Fusion& Fusion::operator=(Fusion&& other) noexcept {
FUSER_PERF_SCOPE("Fusion move assign");
if (this == &other) {
return *this;
}
clear();
swap(*this, other);
return *this;
Expand Down
Loading