Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 8 additions & 10 deletions csrc/fusion.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -116,8 +116,6 @@ void Fusion::swap(Fusion& a, Fusion& b) noexcept {
// update the parent backpointers in those containers to point to their new
// owners
if (a.ir_container_) {
// Also update all Statement ir_container_ pointers to point to new owner
a.ir_container()->parent_ = &a;
for (auto val : a.vals()) {
val->ir_container_ = &a;
}
Expand All @@ -126,8 +124,6 @@ void Fusion::swap(Fusion& a, Fusion& b) noexcept {
}
}
if (b.ir_container_) {
// Also update all Statement ir_container_ pointers to point to new owner
b.ir_container()->parent_ = &b;
for (auto val : b.vals()) {
val->ir_container_ = &b;
}
Expand Down Expand Up @@ -161,7 +157,8 @@ std::unique_ptr<SegmentedFusion> Fusion::segment(
IrCloner Fusion::copy(const Fusion* from, Fusion* to) {
to->clear();

auto ir_cloner = IrContainer::copy(from->ir_container(), to->ir_container());
auto ir_cloner =
IrContainer::copy(from->ir_container(), to->ir_container(), to);

// Remap cached special val pointers through the cloner
if (from->zero_val_) {
Expand Down Expand Up @@ -254,8 +251,8 @@ IrCloner Fusion::copy(const Fusion* from, Fusion* to) {
}

// Default constructor
Fusion::Fusion() : ir_container_(std::make_unique<IrContainer>()) {
ir_container_->parent_ = this;
Fusion::Fusion() : ir_container_(std::make_shared<IrContainer>()) {
ir_container_->addFusion(this);
}

// Copy constructor
Expand Down Expand Up @@ -287,6 +284,9 @@ Fusion& Fusion::operator=(Fusion&& other) noexcept {

Fusion::~Fusion() {
clear();
if (ir_container_) {
ir_container_->removeFusion(this);
}
}

void Fusion::clear() noexcept {
Expand Down Expand Up @@ -350,9 +350,7 @@ void Fusion::removeExpr(Expr* expr) {
auto expr_in_deque = std::find_if(
c->exprs_up_.begin(),
c->exprs_up_.end(),
[expr](std::unique_ptr<Expr>& expr_up) {
return expr_up.get() == expr;
});
[expr](std::unique_ptr<Expr>& expr_up) { return expr_up.get() == expr; });
NVF_ERROR(
expr_in_deque != c->exprs_up_.end(),
"Wanted to remove an expression but its unique ptr is missing.");
Expand Down
7 changes: 5 additions & 2 deletions csrc/fusion.h
Original file line number Diff line number Diff line change
Expand Up @@ -148,7 +148,6 @@ class NVF_API Fusion : public PolymorphicBase {
typedef std::unordered_map<int, std::vector<int64_t>> PermutationMap;

protected:
// Direct access to underlying container
IrContainer* ir_container() {
NVF_ERROR(
ir_container_.get() != nullptr,
Expand All @@ -163,6 +162,10 @@ class NVF_API Fusion : public PolymorphicBase {
return ir_container_.get();
}

std::shared_ptr<IrContainer> ir_container_ptr() const {
return ir_container_;
}

public:
// Registration (public API with passkey)
virtual void registerStmt(IrBuilderPasskey, Statement* stmt) {
Expand Down Expand Up @@ -635,7 +638,7 @@ class NVF_API Fusion : public PolymorphicBase {
std::unique_ptr<std::vector<TensorView*>> all_tvs_ptr_ = nullptr;

inline static const std::string exact_mappings_key = "exact_mappings";
std::unique_ptr<IrContainer> ir_container_;
std::shared_ptr<IrContainer> ir_container_;

Val* zero_val_ = nullptr;
Val* one_val_ = nullptr;
Expand Down
36 changes: 31 additions & 5 deletions csrc/ir/container.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -80,14 +80,15 @@ void IrContainer::swap(IrContainer& a, IrContainer& b) noexcept {

std::swap(a.val_type_name_map_, b.val_type_name_map_);
std::swap(a.expr_name_counter_, b.expr_name_counter_);

std::swap(a.parent_, b.parent_);
}

IrCloner IrContainer::copy(const IrContainer* from, IrContainer* to) {
IrCloner IrContainer::copy(
const IrContainer* from,
IrContainer* to,
Fusion* dest_fusion) {
to->clear();

IrCloner ir_cloner(to->parent());
IrCloner ir_cloner(dest_fusion);

// Copy values in deterministic order
for (auto val : from->deterministic_vals()) {
Expand Down Expand Up @@ -138,7 +139,7 @@ bool IrContainer::inContainer(const Statement* const_stmt) const {
}

NVF_ERROR(
const_stmt->container() == this->parent(),
sharing_fusions_.count(const_stmt->container()) > 0,
"Container claims to own stmt, but stmt disagrees.");

// NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast)
Expand All @@ -157,4 +158,29 @@ bool IrContainer::inContainer(const Statement* const_stmt) const {
return true;
}

void IrContainer::addFusion(Fusion* fusion) {
sharing_fusions_.insert(fusion);
}

void IrContainer::removeFusion(Fusion* fusion) {
sharing_fusions_.erase(fusion);
}

void IrContainer::transferFusion(Fusion* from, Fusion* to) {
sharing_fusions_.erase(from);
sharing_fusions_.insert(to);
}

size_t IrContainer::sharingCount() const {
return sharing_fusions_.size();
}

bool IrContainer::hasMultipleFusions() const {
return sharing_fusions_.size() > 1;
}

const std::unordered_set<Fusion*>& IrContainer::sharingFusions() const {
return sharing_fusions_;
}

} // namespace nvfuser
20 changes: 11 additions & 9 deletions csrc/ir/container.h
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,10 @@ class IrContainer {
}

protected:
static IrCloner copy(const IrContainer* from, IrContainer* to);
static IrCloner copy(
const IrContainer* from,
IrContainer* to,
Fusion* dest_fusion);

static void swap(IrContainer& a, IrContainer& b) noexcept;

Expand Down Expand Up @@ -127,16 +130,15 @@ class IrContainer {
StmtNameType expr_name_counter_ = 0;

public:
Fusion* parent() const {
NVF_ERROR(
parent_ != nullptr, "Call to IrContainer::parent() holds nullptr.")
return parent_;
}
void addFusion(Fusion* fusion);
void removeFusion(Fusion* fusion);
void transferFusion(Fusion* from, Fusion* to);
size_t sharingCount() const;
bool hasMultipleFusions() const;
const std::unordered_set<Fusion*>& sharingFusions() const;

private:
// Parent Fusion that owns this container (for pure composition pattern)
// Used by Statement::fusion() to navigate back to owning Fusion
Fusion* parent_ = nullptr;
std::unordered_set<Fusion*> sharing_fusions_;
};

} // namespace nvfuser
9 changes: 7 additions & 2 deletions csrc/runtime/fusion_kernel_runtime.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,9 @@

namespace nvfuser {

// TODO: Remove when std::shared_mutex is added to IrContainer.
constexpr bool kPhase2DisableParallelCompile = true;

namespace {
// Replace CUDA tensor with Meta tensor because storing tensors can cause
// out-of-memory issues. Other arguments are returned as-is.
Expand Down Expand Up @@ -454,7 +457,8 @@ void FusionKernelRuntime::compileFusionParallel(KernelArgumentHolder args) {
try {
for (const auto& [group_to_run, group_runtime_inputs] :
zip(runtime_workspace_.group_run_order, all_runtime_inputs)) {
if (num_groups == 1 || isOptionDisabled(DisableOption::ParallelCompile)) {
if (num_groups == 1 || kPhase2DisableParallelCompile ||
isOptionDisabled(DisableOption::ParallelCompile)) {
compileKernel(group_runtime_inputs, group_to_run);
} else {
// launch compileKernel thread here
Expand Down Expand Up @@ -488,7 +492,8 @@ void FusionKernelRuntime::compileFusionParallel(KernelArgumentHolder args) {
throw;
}

if (num_groups != 1 && !isOptionDisabled(DisableOption::ParallelCompile)) {
if (num_groups != 1 && !kPhase2DisableParallelCompile &&
!isOptionDisabled(DisableOption::ParallelCompile)) {
// Wait until all segments finish compiling
getThreadPool()->waitWorkComplete();
NVF_ERROR(
Expand Down
Loading