Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
161 changes: 161 additions & 0 deletions csrc/fusion.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
// clang-format on
#include <fusion.h>

#include <type.h>
#include <iterator>
#include <ranges>

Expand All @@ -19,7 +20,9 @@
#include <host_ir/container.h>
#include <instrumentation.h>
#include <ir/all_nodes.h>
#include <ir/builder.h>
#include <ir/cloner.h>
#include <ir/internal_nodes.h>
#include <ir/printer.h>
#include <ir/utils.h>
#include <iter_visitor.h>
Expand Down Expand Up @@ -137,6 +140,16 @@ void Fusion::swap(Fusion& a, Fusion& b) noexcept {
std::swap(a.outputs_, b.outputs_);

std::swap(a.io_alias_, b.io_alias_);

// Swap per-Fusion special values (Phase 2)
std::swap(a.zero_val_, b.zero_val_);
std::swap(a.one_val_, b.one_val_);
std::swap(a.true_val_, b.true_val_);
std::swap(a.false_val_, b.false_val_);
std::swap(a.magic_zero_val_, b.magic_zero_val_);

std::swap(a.axioms_, b.axioms_);
std::swap(a.metadata_, b.metadata_);
}

std::unique_ptr<SegmentedFusion> Fusion::segment(
Expand Down Expand Up @@ -198,6 +211,19 @@ IrCloner Fusion::copy(const Fusion* from, Fusion* to) {

to->expected_dynamic_smem_bytes_ = from->expected_dynamic_smem_bytes_;

if (from->axioms_ != nullptr) {
to->axioms_ = std::make_unique<std::vector<Val*>>();
to->axioms_->reserve(from->axioms_->size());
for (auto pred : *from->axioms_) {
to->axioms_->push_back(ir_cloner.clone(pred));
}
}

for (auto& [key, val_expr] : from->metadata_) {
to->metadata_[ir_cloner.clone(key)] = std::make_pair(
ir_cloner.clone(val_expr.first), ir_cloner.clone(val_expr.second));
}

if (from->all_tvs_ptr_ != nullptr) {
to->all_tvs_ptr_ = std::make_unique<std::vector<TensorView*>>();
to->all_tvs_ptr_->reserve(from->all_tvs_ptr_->size());
Expand Down Expand Up @@ -264,6 +290,18 @@ void Fusion::clear() noexcept {
managed_data_.clear();
managed_named_data_.clear();

// Reset per-Fusion special values (they'll be recreated lazily if needed).
// These unique_ptrs own the Val objects; ir_container()->clear() above only
// removed them from vals_ (they were already absent from vals_up_).
zero_val_.reset();
one_val_.reset();
true_val_.reset();
false_val_.reset();
magic_zero_val_.reset();

axioms_.reset();
metadata_.clear();

invalidateTvsAndUses();

is_during_update_uses_ = false;
Expand Down Expand Up @@ -297,6 +335,13 @@ void Fusion::removeExpr(Expr* expr) {
void Fusion::removeVal(Val* val) {
assertInContainer(val, "Cannot remove val ");

// Don't remove cached special vals — they are lazily created singletons
if (val == zero_val_.get() || val == one_val_.get() ||
val == true_val_.get() || val == false_val_.get() ||
val == magic_zero_val_.get()) {
return;
}

NVF_CHECK(
!val->isFusionInput(),
"Cannot remove val as it is an input of the fusion.");
Expand Down Expand Up @@ -689,6 +734,122 @@ void Fusion::printTransforms() {
t_exprs.handle(this);
}

Val* Fusion::zeroVal() {
if (!zero_val_) {
auto val = IrBuilder::createInContainer<Val>(this, 0L, DataType::Index);
NVF_ERROR(ir_container()->vals_up_.back().get() == val);
zero_val_ = std::unique_ptr<Val>(ir_container()->vals_up_.back().release());
ir_container()->vals_up_.pop_back();
}
return zero_val_.get();
}

Val* Fusion::oneVal() {
if (!one_val_) {
auto val = IrBuilder::createInContainer<Val>(this, 1L, DataType::Index);
NVF_ERROR(ir_container()->vals_up_.back().get() == val);
one_val_ = std::unique_ptr<Val>(ir_container()->vals_up_.back().release());
ir_container()->vals_up_.pop_back();
}
return one_val_.get();
}

Val* Fusion::falseVal() {
if (!false_val_) {
auto val = IrBuilder::createInContainer<Val>(this, false, DataType::Bool);
NVF_ERROR(ir_container()->vals_up_.back().get() == val);
false_val_ =
std::unique_ptr<Val>(ir_container()->vals_up_.back().release());
ir_container()->vals_up_.pop_back();
}
return false_val_.get();
}

Val* Fusion::trueVal() {
if (!true_val_) {
auto val = IrBuilder::createInContainer<Val>(this, true, DataType::Bool);
NVF_ERROR(ir_container()->vals_up_.back().get() == val);
true_val_ = std::unique_ptr<Val>(ir_container()->vals_up_.back().release());
ir_container()->vals_up_.pop_back();
}
return true_val_.get();
}

NamedScalar* Fusion::magicZeroVal() {
if (!magic_zero_val_) {
auto val = IrBuilder::createInContainer<NamedScalar>(
this, kMagicZeroName, DataType::Index);
NVF_ERROR(ir_container()->vals_up_.back().get() == val);
magic_zero_val_ = std::unique_ptr<NamedScalar>(
ir_container()->vals_up_.back().release()->as<NamedScalar>());
ir_container()->vals_up_.pop_back();
}
return magic_zero_val_.get();
}

Val* Fusion::zeroVal(DataType dtype) {
if (dtype == DataType::Index) {
return zeroVal();
} else if (isBooleanType(dtype)) {
return falseVal();
} else {
// NOTE: this does not cache values
return IrBuilder::createInContainer<Val>(this, 0L, dtype);
}
}

Val* Fusion::oneVal(DataType dtype) {
if (dtype == DataType::Index) {
return oneVal();
} else if (isBooleanType(dtype)) {
return trueVal();
} else {
// NOTE: this does not cache values
return IrBuilder::createInContainer<Val>(this, 1L, dtype);
}
}

Val* Fusion::metadataOf(Val* v) {
if (metadata_.count(v) == 0) {
auto metadata_val =
IrBuilder::createInContainer<Val>(this, metaDataTypeOf(v));
auto metadata_expr =
IrBuilder::createInContainer<GetMetaData>(this, metadata_val, v);
metadata_[v] = std::make_pair(metadata_val, metadata_expr);
}
return metadata_.at(v).first;
}

const std::vector<Val*>& Fusion::axioms() {
if (!axioms_) {
axioms_ = std::make_unique<std::vector<Val*>>();
axioms_->reserve(kParallelTypeThreads.size() * 3);
auto zero = zeroVal();
for (auto p : kParallelTypeThreads) {
auto pidx = NamedScalar::getParallelIndex(p);
auto pdim = NamedScalar::getParallelDim(p);
axioms_->push_back(SimplifyingIrBuilder::geExpr(pidx, zero));
axioms_->push_back(SimplifyingIrBuilder::gtExpr(pdim, zero));
axioms_->push_back(SimplifyingIrBuilder::ltExpr(pidx, pdim));
}
}
return *axioms_;
}

void Fusion::assumePositive(Val* val) {
NVF_ERROR(inContainer(val));
// Lazy init axioms, then add the assumption
axioms();
axioms_->emplace_back(IrBuilder::gtExpr(val, zeroVal()));
}

void Fusion::assumeNonNegative(Val* val) {
NVF_ERROR(inContainer(val));
// Lazy init axioms, then add the assumption
axioms();
axioms_->emplace_back(IrBuilder::geExpr(val, zeroVal()));
}

void Fusion::registerVal(Val* val) {
if (inContainer(val)) {
return;
Expand Down
66 changes: 25 additions & 41 deletions csrc/fusion.h
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@ namespace nvfuser {
//! checks.

class Fusion;
class NamedScalar;
class TensorView;

class SegmentCandidateFinder;
Expand Down Expand Up @@ -549,55 +550,28 @@ class NVF_API Fusion : public PolymorphicBase {
return ir_container()->numExprs();
}

int64_t numVals(bool include_shortcuts) const noexcept {
// When include_shortcuts is true, count cached special vals (zeroVal, etc.)
// which live outside vals_up_ but inside vals_.
int64_t numVals(bool include_shortcuts = true) const noexcept {
return ir_container()->numVals(include_shortcuts);
}

// Shortcut values (frequently used constants)
Val* zeroVal() {
return ir_container()->zeroVal();
}

Val* oneVal() {
return ir_container()->oneVal();
}
Val* zeroVal();
Val* oneVal();
Val* falseVal();
Val* trueVal();
NamedScalar* magicZeroVal();
Val* zeroVal(DataType dtype);
Val* oneVal(DataType dtype);

Val* falseVal() {
return ir_container()->falseVal();
}

Val* trueVal() {
return ir_container()->trueVal();
}

NamedScalar* magicZeroVal() {
return ir_container()->magicZeroVal();
}

Val* zeroVal(DataType dtype) {
return ir_container()->zeroVal(dtype);
}

Val* oneVal(DataType dtype) {
return ir_container()->oneVal(dtype);
}

Val* metadataOf(Val* val) {
return ir_container()->metadataOf(val);
}
Val* metadataOf(Val* val);

// Axioms (CUDA programming assumptions)
const std::vector<Val*>& axioms() {
return ir_container()->axioms();
}

void assumePositive(Val* val) {
ir_container()->assumePositive(val);
}
const std::vector<Val*>& axioms();

void assumeNonNegative(Val* val) {
ir_container()->assumeNonNegative(val);
}
void assumePositive(Val* val);
void assumeNonNegative(Val* val);

// Statement removal
void removeStatementsCreatedAfter(
Expand Down Expand Up @@ -667,6 +641,16 @@ class NVF_API Fusion : public PolymorphicBase {

inline static const std::string exact_mappings_key = "exact_mappings";
std::unique_ptr<IrContainer> ir_container_;

std::unique_ptr<Val> zero_val_;
std::unique_ptr<Val> one_val_;
std::unique_ptr<Val> true_val_;
std::unique_ptr<Val> false_val_;
std::unique_ptr<NamedScalar> magic_zero_val_;

std::unique_ptr<std::vector<Val*>> axioms_;

std::unordered_map<Val*, std::pair<Val*, Expr*>> metadata_;
};

// Template implementations for Fusion::manage<T>() that use IrCloner
Expand Down
Loading