223 changes: 221 additions & 2 deletions csrc/scheduler/transpose.cpp
@@ -811,8 +811,13 @@ std::unique_ptr<TransposeParams> getTransposeHeuristics(
grouped_inputs_outputs[1],
max_unroll_factor);
}

tparams->lparams.bind(tparams->getThreadsPerBlock(), ParallelType::TIDx);
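// Experimental opt-in: if the USE_TMA environment variable is set, take the
// TMA transpose path below. TIDx is left unbound in that case because
// scheduleTransposeTMA sets up its own thread parallelization.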
if (std::getenv("USE_TMA") != nullptr) {
tparams->use_tma_load = true;
tparams->use_tma_store = true;
}
if (!tparams->use_tma_load) {
tparams->lparams.bind(tparams->getThreadsPerBlock(), ParallelType::TIDx);
}

if (isDebugDumpEnabled(DebugDumpOption::SchedulerDebug)) {
debug() << "\n===== Transpose Stats ========\n"
@@ -843,7 +848,221 @@ std::unique_ptr<TransposeParams> getTransposeHeuristics(
return tparams;
}

void scheduleTransposeTMA(Fusion* fusion, const TransposeParams* tparams) {
FusionGuard fg(fusion);

// Make sure we don't have global memory set on intermediate tensors from
// fusion segmentation
scheduler_utils::clearMemorySpace(fusion);

// Maybe the has_reduction check for scheduling should be done on a
// per-output-tensor basis.
NVF_ERROR(
!ir_utils::hasAnyReductionOps(fusion),
"This scheduler only handles pointwise ops.");

// Cache inputs
auto cached_inputs = scheduler_utils::cacheInputs(fusion, true);

// Cache and fork outputs
auto cached_outputs = scheduler_utils::cacheAndForkOutputs(fusion, true);

scheduler_utils::prepareForMemoryTypePromotion(fusion);

std::vector<TensorView*> input_tvs;
{
auto filtered_tvs = ir_utils::filterByType<TensorView>(fusion->inputs());
// Remove hanging tensor views
for (auto tv : filtered_tvs) {
if (tv->uses().empty()) {
continue;
}
input_tvs.push_back(tv);
}
}
auto output_tvs = ir_utils::filterByType<TensorView>(fusion->outputs());

int64_t max_dims = 0;
for (auto inp : input_tvs) {
max_dims = std::max(scheduler_utils::nLogicalDims(inp), max_dims);
}

for (auto out : output_tvs) {
max_dims = std::max(scheduler_utils::nLogicalDims(out), max_dims);
}

// If everything is zero dim tensors, just return.
if (max_dims == 0) {
return;
}

scheduler_tools::TransposeDomainMap domain_map(fusion);
auto grouped_inputs_outputs = domain_map.groupInputsOutputsByInnerDim();
NVF_ERROR(grouped_inputs_outputs.size() >= 2);
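
// groupInputsOutputsByInnerDim partitions the inputs/outputs by which logical
// dimension is innermost; a transpose fusion has at least two such groups, and
// group 2 (grouped_inputs_outputs[1]) is the one staged through shared memory.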

/*
* We need something similar to `cacheFork` for input tensors in group 2,
* because we will want to propagate to the entire DAG except group 2 and its
* cached inputs, so we need to make sure the DAG is still connected after
* group 2 and its cached inputs are removed. For example:
* t0
* |
* cache
* / \
* t1 t2
* if groups = {{t1, t2}, {t0}}, then removing {t0, cache} from the DAG will
* make it disconnected.
*/
std::vector<TensorView*> input_smem_tvs;
std::vector<TensorView*> output_smem_tvs;
TensorView* input_reference = nullptr;
std::unordered_set<TensorView*> group2_and_cached_inputs(
grouped_inputs_outputs[1].begin(), grouped_inputs_outputs[1].end());
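// For each fusion input in group 2, make sure there is a shared-memory cache
// of it: when the existing cache is only consumed once, convert it in place
// into a shared-memory tensor loaded via TMA (CpAsyncBulkTensorTile);
// otherwise insert an additional shared-memory cache with tv->cacheAfter().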
for (auto tv : grouped_inputs_outputs[1]) {
if (tv->isFusionInput()) {
input_reference = tv;
auto existing_cache = ir_utils::consumerTvsOf(tv)[0];
if (ir_utils::consumerTvsOf(existing_cache).size() > 1) {
auto new_cache = tv->cacheAfter();
new_cache->setMemoryType(MemoryType::Shared);
input_smem_tvs.push_back(new_cache);
std::cout << "input_smem_tvs: " << new_cache->toString() << std::endl;
group2_and_cached_inputs.emplace(new_cache);
} else {
existing_cache->definition()->as<LoadStoreOp>()->setOpType(
LoadStoreOpType::CpAsyncBulkTensorTile);
existing_cache->setMemoryType(MemoryType::Shared);
input_smem_tvs.push_back(existing_cache);
std::cout << "input_smem_tvs: " << existing_cache->toString()
<< std::endl;
group2_and_cached_inputs.emplace(existing_cache);
}
}
}
// set cached outputs of group 2 to shared memory
TensorView* output_reference = nullptr;
TensorView* output_reg_cache = nullptr;
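// Stage the cached outputs in shared memory, turn the shared -> global copy
// into a TMA store, and add a register cache between the math and the
// shared-memory staging buffer. Note the group-2 membership check below is
// currently short-circuited with `true ||`, so this applies to every output.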
for (const auto& [cached_output, output_idx] : cached_outputs) {
auto output = fusion->outputs()[output_idx]->as<TensorView>();
output_reference = output;
if (true || group2_and_cached_inputs.count(output) > 0) {
output->definition()->as<LoadStoreOp>()->setOpType(
LoadStoreOpType::CpAsyncBulkTensorTile);
cached_output->setMemoryType(MemoryType::Shared);
output_smem_tvs.push_back(cached_output);
std::cout << "output_smem_tvs: " << cached_output->toString()
<< std::endl;
output_reg_cache = cached_output->cacheBefore();
}
}

TensorView* reference1 =
domain_map.findReferenceFor(grouped_inputs_outputs[0]);
TensorView* reference2 =
domain_map.findReferenceFor(grouped_inputs_outputs[1]);

std::cout << "reference1: " << reference1->toString() << std::endl;
std::cout << "reference2: " << reference2->toString() << std::endl;

// output: [I0, I1] -> [I0/tile_0 * I1/tile_1, tile_0, tile_1]
// smem -> register is vectorized
// register -> smem is unrolled
int64_t tile_0 = 32, tile_1 = 32, n_warps = 4, unroll_vect = 4;
int64_t tiles_per_block = 4;
// int64_t n_serial_loop = tile_0 / unroll_vect / n_warps;
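// With these defaults a CTA of n_warps * 32 = 128 threads handles
// tiles_per_block tiles of 32x32 elements; each thread moves unroll_vect = 4
// elements per access and 32 * 32 / 128 = 8 elements per tile overall.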
output_reference->split(1, tile_1);
output_reference->split(0, tile_0);
output_reference->reorder({{-2, 0}});
output_reference->merge(0);
// [I0/tile_0 * I1/tile_1/tiles_per_block, tiles_per_block, tile_0, tile_1]
output_reference->split(0, tiles_per_block);
int64_t pos_bdimx = 0, tile0_pos = 2;
output_reference->axis(pos_bdimx)->parallelize(ParallelType::BIDx);
// output_reference->axis(pos_bdimx + 1)->parallelize(ParallelType::Unroll);
using Options =
scheduler_utils::BoundedDirectionalTransformPropagator::Options;
scheduler_utils::BoundedDirectionalTransformPropagator::backward(
output_reference,
-1,
{input_reference},
Options{}.propagateParallelType());
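// Propagating backward from the fusion output (bounded by the group-2 input
// reference) gives every tensor in between, including the shared-memory
// caches, the same [BIDx, tiles_per_block, tile_0, tile_1] loop structure and
// the BIDx binding.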
// For fusion output, we just use TMA to store the entire tile back to global
// memory. There is no need to further schedule the output tensor.
output_reference->axis(tile0_pos)->parallelize(ParallelType::Bulk);
output_reference->axis(tile0_pos + 1)->parallelize(ParallelType::Bulk);

fusion->printMath();

for (auto output_smem_cache : output_smem_tvs) {
// [BIDx, tiles_per_block, 32', 32]
output_smem_cache->setAllocationDomain(
output_smem_cache->getLoopDomain(), true);
// [BIDx, tiles_per_block, tile_0, tile_1] -> [BIDx, tiles_per_block, tile_0/unroll_vect, unroll_vect, tile_1]
output_smem_cache->split(tile0_pos, unroll_vect);
// [BIDx, tiles_per_block, tile_0/unroll_vect/n_warps, n_warps, unroll_vect, tile_1]
output_smem_cache->split(tile0_pos, n_warps);

scheduler_utils::BoundedDirectionalTransformPropagator::backward(
output_smem_cache, -1, {input_reference});
output_smem_cache->axis(tile0_pos + 1)->parallelize(ParallelType::TIDy);
output_smem_cache->axis(tile0_pos + 3)->parallelize(ParallelType::TIDx);
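// n_warps is bound to TIDy and tile_1 to TIDx, so each of the 128 threads
// covers 8 elements along tile_0: 2 serial iterations times the
// unroll/vectorize factor of 4.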
scheduler_utils::BoundedDirectionalTransformPropagator::backward(
output_smem_cache,
-1,
{input_smem_tvs.at(0)},
Options{}.propagateParallelType());
output_smem_cache->axis(tile0_pos + 2)->parallelize(ParallelType::Unroll);
output_reg_cache->axis(tile0_pos + 2)->parallelize(ParallelType::Vectorize);
output_reg_cache->setAllocationDomain(
output_reg_cache->getLoopDomain(), true);
}
for (auto input_smem_cache : input_smem_tvs) {
// Schedule the memory format for 128-byte swizzle
// After backward propagation and reorder:
// [BIDx, tiles_per_block, tile_1, tile_0/unroll_vect/n_warps, n_warps, unroll_vect]
// = [BIDx, tilesPerCTA, 32, 2, 4, 4]
int64_t tile1_pos = 2;
input_smem_cache->reorder({{-1, 2}});

std::cout << "input_smem_cache after reorder: "
<< input_smem_cache->toString() << std::endl;

// Split tile_1 (32) by 8 to create the first dimension of size 8
// [BIDx, tilesPerCTA, 32, 2, 4, 4] -> [BIDx, tilesPerCTA, 4, 8, 2, 4, 4]
input_smem_cache->split(tile1_pos, 8);

// Merge the 2×4 to create the second dimension of size 8
// [BIDx, tilesPerCTA, 4, 8, 2, 4, 4] -> [BIDx, tilesPerCTA, 4, 8, 8, 4]
input_smem_cache->merge(tile1_pos + 2, tile1_pos + 3);

std::cout << "input_smem_cache before swizzle: "
<< input_smem_cache->toString() << std::endl;

// Swizzle the two 8's at positions 3 and 4
// [BIDx, tilesPerCTA, 4, 8, 8, 4] with XOR swizzle on dimensions 3 and 4
input_smem_cache->swizzle(SwizzleType::XOR, tile1_pos + 1, tile1_pos + 2);
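
// With 4-byte elements, each 128-byte row of the tile is 8 chunks of 16
// bytes; XOR-ing the two size-8 dimensions permutes the chunks within each
// row, so a transposed (column-wise) access touches 8 different bank groups
// instead of repeatedly hitting the same one, avoiding shared-memory bank
// conflicts.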

// Set allocation domain to match the swizzled layout
input_smem_cache->setAllocationDomain(
input_smem_cache->getLoopDomain(), true);

// Parallelize all dimensions as Bulk for TMA
input_smem_cache->axis(tile1_pos)->parallelize(ParallelType::Bulk);
input_smem_cache->axis(tile1_pos + 1)->parallelize(ParallelType::Bulk);
input_smem_cache->axis(tile1_pos + 2)->parallelize(ParallelType::Bulk);
input_smem_cache->axis(tile1_pos + 3)->parallelize(ParallelType::Bulk);
// [BIDx, tilesPerCTA, Bulk, Bulk, Bulk, Bulk]
}

inlineMost();

fusion->print();
}

void scheduleTranspose(Fusion* fusion, const TransposeParams* tparams) {
if (tparams->use_tma_store || tparams->use_tma_load) {
return scheduleTransposeTMA(fusion, tparams);
}
FusionGuard fg(fusion);

// Make sure we don't have global memory set on intermediate tensors from
20 changes: 18 additions & 2 deletions csrc/scheduler/transpose_heuristic.h
@@ -51,6 +51,16 @@ class TransposeParams : public HeuristicParams {
// Tile size for the innermost dim of tensors in the second group
int64_t tile_size2 = getDefaultTileSize();

// Use Hopper+ TMA (cp.async.bulk.tensor) for group2 cached-input loads
// (global -> shared). This is intentionally conservative and only affects
// the shared-memory staging tensors used for the thread-binding swap.
bool use_tma_load = false;

// Use Hopper+ TMA (cp.async.bulk.tensor) for group1 cached-output stores
// (shared -> global). When enabled with use_tma_load, implements the full
// bank-conflict-free transpose pattern with 128-byte swizzle.
bool use_tma_store = false;

using HeuristicParams::HeuristicParams;

// Warning: Does not check launch parameters!
@@ -65,7 +75,9 @@ class TransposeParams : public HeuristicParams {
other->dims_merged_with_2 == dims_merged_with_2 &&
other->vectorize_factor1 == vectorize_factor1 &&
other->vectorize_factor2 == vectorize_factor2 &&
other->tile_size1 == tile_size1 && other->tile_size2 == tile_size2;
other->tile_size1 == tile_size1 && other->tile_size2 == tile_size2 &&
other->use_tma_load == use_tma_load &&
other->use_tma_store == use_tma_store;
return attr_equal;
}

@@ -76,6 +88,8 @@ class TransposeParams : public HeuristicParams {
<< " BlckX: " << lparams.bdimx() << "\n";
ss << " input tile size: " << tile_size1 << "\n";
ss << " output tile size: " << tile_size2 << "\n";
ss << " use_tma_load: " << use_tma_load << "\n";
ss << " use_tma_store: " << use_tma_store << "\n";
int64_t elements_per_tile = tile_size1 * tile_size2;
ss << " elements per tile: " << elements_per_tile << "\n";
int64_t elements_per_thread = elements_per_tile / lparams.bdimx();
@@ -146,7 +160,9 @@ class TransposeParams : public HeuristicParams {
vectorize_factor1,
vectorize_factor2,
tile_size1,
tile_size2);
tile_size2,
use_tma_load,
use_tma_store);
}

std::unique_ptr<HeuristicParams> clone() const override {