From 0a69cd13488940de3867c574b37be120af273246 Mon Sep 17 00:00:00 2001 From: Liqiang Lu Date: Wed, 28 Jan 2026 10:31:17 -0800 Subject: [PATCH 01/11] tma load --- csrc/scheduler/transpose.cpp | 140 ++++++++++++++++++++++++--- csrc/scheduler/transpose_heuristic.h | 12 ++- tests/cpp/test_transpose.cpp | 22 +++++ 3 files changed, 159 insertions(+), 15 deletions(-) diff --git a/csrc/scheduler/transpose.cpp b/csrc/scheduler/transpose.cpp index 766213ca484..f8a47416404 100644 --- a/csrc/scheduler/transpose.cpp +++ b/csrc/scheduler/transpose.cpp @@ -18,6 +18,8 @@ #include #include +#include + namespace nvfuser { bool TransposeScheduler::canScheduleCompileTime(Fusion* fusion) { @@ -663,6 +665,16 @@ std::unique_ptr getTransposeHeuristics( tparams->tag = "Transpose heuristics"; tparams->cparams.index_type = index_type; + // Enable TMA (cp.async.bulk.tensor) only on Hopper+. + // Keep conservative constraints aligned with current nvFuser TMA support. + // NOTE: This only affects group2 cached-input loads (gmem->smem) used for the + // transpose thread-binding swap. + constexpr int64_t kMaxElementsPerTmaTileDim = 256; + const auto* props = at::cuda::getCurrentDeviceProperties(); + tparams->use_tma_load = (props->major >= 9) && + (tparams->tile_size1 <= kMaxElementsPerTmaTileDim) && + (tparams->tile_size2 <= kMaxElementsPerTmaTileDim); + // Expand inner-most dims to virtual inner-most dims so that the inner-most // dims has at least tile_size elements // See note [Supporting small transpose dimensions] @@ -720,6 +732,10 @@ std::unique_ptr getTransposeHeuristics( scan_max_dtype_size(fusion->inputs()); scan_max_dtype_size(fusion->outputs()); + // //set tile size + // tparams->tile_size1 = 256; + // tparams->tile_size2 = 256; + auto max_unroll_factor = ceilDiv( // Available unrolling based on size of data type kSixteen / max_io_dtype_size, @@ -910,6 +926,10 @@ void scheduleTranspose(Fusion* fusion, const TransposeParams* tparams) { */ std::unordered_set group2_and_cached_inputs( grouped_inputs_outputs[1].begin(), grouped_inputs_outputs[1].end()); + std::vector smem_tvs; + // Subset of shared-memory staging TVs that are gmem->smem cached-input loads + // and will be scheduled with TMA (CpAsyncBulkTensorTile). + std::vector tma_smem_load_tvs; for (auto tv : grouped_inputs_outputs[1]) { if (tv->isFusionInput()) { auto existing_cache = ir_utils::consumerTvsOf(tv)[0]; @@ -917,9 +937,11 @@ void scheduleTranspose(Fusion* fusion, const TransposeParams* tparams) { auto new_cache = tv->cacheAfter(); new_cache->setMemoryType(MemoryType::Shared); group2_and_cached_inputs.emplace(new_cache); + smem_tvs.push_back(new_cache); } else { existing_cache->setMemoryType(MemoryType::Shared); group2_and_cached_inputs.emplace(existing_cache); + smem_tvs.push_back(existing_cache); } } } @@ -928,6 +950,26 @@ void scheduleTranspose(Fusion* fusion, const TransposeParams* tparams) { auto output = fusion->outputs()[output_idx]->as(); if (group2_and_cached_inputs.count(output) > 0) { cached_output->setMemoryType(MemoryType::Shared); + smem_tvs.push_back(cached_output); + } + } + + // Configure TMA loads for the gmem->smem cached-input TVs. + // Important: TMA needs the two tiling dimensions preserved as separate Bulk + // axes (i.e., do not merge tile1*tile2 like the normal group2 compute path). + if (tparams->use_tma_load) { + for (auto smem_tv : smem_tvs) { + auto ldst = dynamic_cast(smem_tv->definition()); + if (ldst == nullptr) { + continue; + } + // Only enable for gmem->smem loads + auto in_tv = dynamic_cast(ldst->in()); + if (in_tv == nullptr || !in_tv->isFusionInput()) { + continue; + } + ldst->setOpType(LoadStoreOpType::CpAsyncBulkTensorTile); + tma_smem_load_tvs.push_back(smem_tv); } } @@ -949,6 +991,9 @@ void scheduleTranspose(Fusion* fusion, const TransposeParams* tparams) { auto inner_most_id1 = scheduler_utils::innerMostAllocDim(reference1); auto inner_most_id2 = scheduler_utils::innerMostAllocDim(reference2); + std::cout << "ref1 " << reference1->toString() << std::endl; + std::cout << "ref2 " << reference2->toString() << std::endl; + ////////////////////////////////////////// // Step 1: Make virtual inner most dims // ////////////////////////////////////////// @@ -1026,18 +1071,23 @@ void scheduleTranspose(Fusion* fusion, const TransposeParams* tparams) { rhs_i = lhs_i; } - reference1->split(rhs_i, 1); - // [r.., merged_dim, 1, tile1, tile2] + // reference1->split(rhs_i, 1); + // // [r.., merged_dim, 1, tile1, tile2] - // parallelize non-tile dimensions - reference1->axis(rhs_i + 1)->parallelize(ParallelType::Unswitch); + // // parallelize non-tile dimensions + // reference1->axis(rhs_i + 1)->parallelize(ParallelType::Unswitch); reference1->axis(rhs_i)->parallelize(ParallelType::BIDx); // [r.., BIDx, Unswitch, tile1, tile2] - // Propagate transformations so far to the entire DAG - TransformPropagator propagator(reference1); - MaxLogicalDomainInfoSpanningTree entire_dag(reference1); - entire_dag.traverse(&propagator); + // Propagate transformations so far (including outer-dim merges and BIDx / + // Unswitch) to the entire DAG. This ensures TMA staging TVs have the same + // untiled-domain schedule as other TVs. TMA staging TVs are still excluded + // later from the group2 compute scheduling that merges tile1*tile2. + { + TransformPropagator propagator(reference1); + MaxLogicalDomainInfoSpanningTree entire_dag(reference1); + entire_dag.traverse(&propagator); + } // We may be propagating a reshape during the above transformation. // T0[i0 * i1] -> View -> T1[i0 i1] (Root=[i0*i1]) @@ -1061,6 +1111,40 @@ void scheduleTranspose(Fusion* fusion, const TransposeParams* tparams) { /*propagate_padding=*/true, /*parallelize_inputs_on_did=*/true); + // Apply TMA-specific 2D Bulk parallelization for the cached-input smem TVs. + // Keep it limited to the tile dims only, as the rest of the schedule is + // shared with non-TMA tensors. + if (!tma_smem_load_tvs.empty()) { + // Identify the two tiling axes by exact-ID mapping against reference1's + // tiling axes. This is robust even when tile_size1 == tile_size2. + ComputeAtMap ca_map(fusion); + IterDomain* ref_tile1 = reference1->axis(-2); + IterDomain* ref_tile2 = reference1->axis(-1); + auto bulkParallelizeTileDims = [&](TensorView* tv) { + IterDomain* tile1_id = nullptr; + IterDomain* tile2_id = nullptr; + for (auto id : tv->getLoopDomain()) { + if (tile1_id == nullptr && + ca_map.areMapped(id, ref_tile1, IdMappingMode::EXACT)) { + tile1_id = id; + } + if (tile2_id == nullptr && + ca_map.areMapped(id, ref_tile2, IdMappingMode::EXACT)) { + tile2_id = id; + } + } + NVF_ERROR( + tile1_id != nullptr && tile2_id != nullptr, + "Failed to identify tiling axes for TMA smem TV: ", + tv->toString()); + tile1_id->parallelize(ParallelType::Bulk); + tile2_id->parallelize(ParallelType::Bulk); + }; + for (auto tv : tma_smem_load_tvs) { + bulkParallelizeTileDims(tv); + } + } + // For a transpose scheduling, all we need is to bind threadIdx.x differently // for inputs and outputs. This swap of binding could happen at any tensor on // the path from input to output, especially, it does not have to be in the @@ -1134,6 +1218,22 @@ void scheduleTranspose(Fusion* fusion, const TransposeParams* tparams) { auto all_tvs_except1 = ir_utils::allTvsExcept( fusion, {grouped_inputs_outputs[0].begin(), grouped_inputs_outputs[0].end()}); + // Keep TMA smem staging tensors out of the group2 compute scheduling, as it + // merges tile1*tile2 and would destroy the 2D Bulk structure required by + // CpAsyncBulkTensorTile. + if (!tma_smem_load_tvs.empty()) { + all_tvs_except1.erase( + std::remove_if( + all_tvs_except1.begin(), + all_tvs_except1.end(), + [&](TensorView* tv) { + return std::find( + tma_smem_load_tvs.begin(), + tma_smem_load_tvs.end(), + tv) != tma_smem_load_tvs.end(); + }), + all_tvs_except1.end()); + } SetSelector selector({all_tvs_except1.begin(), all_tvs_except1.end()}); MaxLogicalDomainInfoSpanningTree entire_dag_except1(reference2, &selector); TransformPropagator propagator(reference2); @@ -1153,16 +1253,30 @@ void scheduleTranspose(Fusion* fusion, const TransposeParams* tparams) { ComputeAtMap ca_map(fusion); + // Exclude TMA staging tensors; they keep Bulk-parallel tile axes. + std::vector group2_sched_tvs( + group2_and_cached_inputs.begin(), group2_and_cached_inputs.end()); + if (!tma_smem_load_tvs.empty()) { + group2_sched_tvs.erase( + std::remove_if( + group2_sched_tvs.begin(), + group2_sched_tvs.end(), + [&](TensorView* tv) { + return std::find( + tma_smem_load_tvs.begin(), + tma_smem_load_tvs.end(), + tv) != tma_smem_load_tvs.end(); + }), + group2_sched_tvs.end()); + } scheduler_utils::parallelizeAllLike( - reference2, - {group2_and_cached_inputs.begin(), group2_and_cached_inputs.end()}, - {ParallelType::TIDx}); + reference2, group2_sched_tvs, {ParallelType::TIDx}); // Only vectorize the axes that exactly maps to the vectorized axes // on reference as support for permissively mapped axes are not // yet clearly defined. std::vector vectorized_group2_cached_inputs; - for (auto gin : group2_and_cached_inputs) { + for (auto gin : group2_sched_tvs) { if (std::any_of( gin->getLoopDomain().begin(), gin->getLoopDomain().end(), @@ -1184,7 +1298,7 @@ void scheduleTranspose(Fusion* fusion, const TransposeParams* tparams) { // on reference as support for permissively mapped axes are not // yet clearly defined. std::vector unrolled_group2_cached_inputs; - for (auto gin : group2_and_cached_inputs) { + for (auto gin : group2_sched_tvs) { if (std::any_of( gin->getLoopDomain().begin(), gin->getLoopDomain().end(), diff --git a/csrc/scheduler/transpose_heuristic.h b/csrc/scheduler/transpose_heuristic.h index f4de63bdd60..bbc45b9757f 100644 --- a/csrc/scheduler/transpose_heuristic.h +++ b/csrc/scheduler/transpose_heuristic.h @@ -51,6 +51,11 @@ class TransposeParams : public HeuristicParams { // Tile size for the inner most dim of tensors in the second group int64_t tile_size2 = getDefaultTileSize(); + // Use Hopper+ TMA (cp.async.bulk.tensor) for group2 cached-input loads + // (global -> shared). This is intentionally conservative and only affects + // the shared-memory staging tensors used for the thread-binding swap. + bool use_tma_load = false; + using HeuristicParams::HeuristicParams; // Warning: Does not check launch parameters! @@ -65,7 +70,8 @@ class TransposeParams : public HeuristicParams { other->dims_merged_with_2 == dims_merged_with_2 && other->vectorize_factor1 == vectorize_factor1 && other->vectorize_factor2 == vectorize_factor2 && - other->tile_size1 == tile_size1 && other->tile_size2 == tile_size2; + other->tile_size1 == tile_size1 && other->tile_size2 == tile_size2 && + other->use_tma_load == use_tma_load; return attr_equal; } @@ -76,6 +82,7 @@ class TransposeParams : public HeuristicParams { << " BlckX: " << lparams.bdimx() << "\n"; ss << " input tile size: " << tile_size1 << "\n"; ss << " output tile size: " << tile_size2 << "\n"; + ss << " use_tma_load: " << use_tma_load << "\n"; int64_t elements_per_tile = tile_size1 * tile_size2; ss << " elements per tile: " << elements_per_tile << "\n"; int64_t elements_per_thread = elements_per_tile / lparams.bdimx(); @@ -146,7 +153,8 @@ class TransposeParams : public HeuristicParams { vectorize_factor1, vectorize_factor2, tile_size1, - tile_size2); + tile_size2, + use_tma_load); } std::unique_ptr clone() const override { diff --git a/tests/cpp/test_transpose.cpp b/tests/cpp/test_transpose.cpp index 3d0bb16b87a..d90deb83f81 100644 --- a/tests/cpp/test_transpose.cpp +++ b/tests/cpp/test_transpose.cpp @@ -1409,4 +1409,26 @@ TEST_F(TransposeTest, DanglingBroadcastIssue4957) { testValidate(executor_cache.fusion(), outputs, {t0}, __LINE__, __FILE__); } +TEST_F(TransposeTest, Tma) { + auto fusion_ptr = std::make_unique(); + FusionGuard fg(fusion_ptr.get()); + Fusion& fusion = *fusion_ptr; + + auto dtype = DataType::BFloat16; + auto tv0 = makeContigConcreteTensor({262144, 5120}, dtype); + fusion.addInput(tv0); + auto tv1 = castOp(DataType::Float, tv0); + auto tv2 = add(tv1, tv1); + auto tv3 = transpose(tv2, 0, 1); + auto tv4 = mul(tv3, tv3); + auto tv5 = castOp(dtype, tv4); + fusion.addOutput(tv5); + + auto options = + at::TensorOptions().dtype(data_type_to_aten(dtype)).device(at::kCUDA, 0); + at::Tensor input0 = at::randn({262144, 5120}, options); + FusionExecutorCache executor_cache(std::move(fusion_ptr)); + auto outputs = executor_cache.runFusionWithInputs({input0}); + testValidate(executor_cache.fusion(), outputs, {input0}, __LINE__, __FILE__); +} } // namespace nvfuser From fce7b02cac97be27b1628eea5fd1fbbfd62dc000 Mon Sep 17 00:00:00 2001 From: Liqiang Lu Date: Thu, 29 Jan 2026 14:28:53 -0800 Subject: [PATCH 02/11] add tma load/store --- csrc/scheduler/transpose.cpp | 322 ++++++++++++++++----------- csrc/scheduler/transpose_heuristic.h | 12 +- tests/cpp/test_transpose.cpp | 97 ++++++++ tests/cpp/test_tutorial.cpp | 1 + 4 files changed, 305 insertions(+), 127 deletions(-) diff --git a/csrc/scheduler/transpose.cpp b/csrc/scheduler/transpose.cpp index f8a47416404..c972e219ec7 100644 --- a/csrc/scheduler/transpose.cpp +++ b/csrc/scheduler/transpose.cpp @@ -18,8 +18,6 @@ #include #include -#include - namespace nvfuser { bool TransposeScheduler::canScheduleCompileTime(Fusion* fusion) { @@ -665,16 +663,6 @@ std::unique_ptr getTransposeHeuristics( tparams->tag = "Transpose heuristics"; tparams->cparams.index_type = index_type; - // Enable TMA (cp.async.bulk.tensor) only on Hopper+. - // Keep conservative constraints aligned with current nvFuser TMA support. - // NOTE: This only affects group2 cached-input loads (gmem->smem) used for the - // transpose thread-binding swap. - constexpr int64_t kMaxElementsPerTmaTileDim = 256; - const auto* props = at::cuda::getCurrentDeviceProperties(); - tparams->use_tma_load = (props->major >= 9) && - (tparams->tile_size1 <= kMaxElementsPerTmaTileDim) && - (tparams->tile_size2 <= kMaxElementsPerTmaTileDim); - // Expand inner-most dims to virtual inner-most dims so that the inner-most // dims has at least tile_size elements // See note [Supporting small transpose dimensions] @@ -732,10 +720,6 @@ std::unique_ptr getTransposeHeuristics( scan_max_dtype_size(fusion->inputs()); scan_max_dtype_size(fusion->outputs()); - // //set tile size - // tparams->tile_size1 = 256; - // tparams->tile_size2 = 256; - auto max_unroll_factor = ceilDiv( // Available unrolling based on size of data type kSixteen / max_io_dtype_size, @@ -827,8 +811,13 @@ std::unique_ptr getTransposeHeuristics( grouped_inputs_outputs[1], max_unroll_factor); } - - tparams->lparams.bind(tparams->getThreadsPerBlock(), ParallelType::TIDx); + if (std::getenv("USE_TMA") != nullptr) { + tparams->use_tma_load = true; + tparams->use_tma_store = true; + } + if (!tparams->use_tma_load) { + tparams->lparams.bind(tparams->getThreadsPerBlock(), ParallelType::TIDx); + } if (isDebugDumpEnabled(DebugDumpOption::SchedulerDebug)) { debug() << "\n===== Transpose Stats ========\n" @@ -859,7 +848,7 @@ std::unique_ptr getTransposeHeuristics( return tparams; } -void scheduleTranspose(Fusion* fusion, const TransposeParams* tparams) { +void scheduleTransposeTMA(Fusion* fusion, const TransposeParams* tparams) { FusionGuard fg(fusion); // Make sure we don't have global memory set on intermediate tensors from @@ -924,52 +913,207 @@ void scheduleTranspose(Fusion* fusion, const TransposeParams* tparams) { * if groups = {{t1, t2}, {t0}}, then removing {t0, cache} from the DAG will * make it disconnected. */ + std::vector input_smem_tvs; + std::vector output_smem_tvs; + TensorView* input_reference = nullptr; std::unordered_set group2_and_cached_inputs( grouped_inputs_outputs[1].begin(), grouped_inputs_outputs[1].end()); - std::vector smem_tvs; - // Subset of shared-memory staging TVs that are gmem->smem cached-input loads - // and will be scheduled with TMA (CpAsyncBulkTensorTile). - std::vector tma_smem_load_tvs; for (auto tv : grouped_inputs_outputs[1]) { if (tv->isFusionInput()) { + input_reference = tv; auto existing_cache = ir_utils::consumerTvsOf(tv)[0]; if (ir_utils::consumerTvsOf(existing_cache).size() > 1) { auto new_cache = tv->cacheAfter(); new_cache->setMemoryType(MemoryType::Shared); + input_smem_tvs.push_back(new_cache); + std::cout << "input_smem_tvs: " << new_cache->toString() << std::endl; group2_and_cached_inputs.emplace(new_cache); - smem_tvs.push_back(new_cache); } else { + existing_cache->definition()->as()->setOpType( + LoadStoreOpType::CpAsyncBulkTensorTile); existing_cache->setMemoryType(MemoryType::Shared); + input_smem_tvs.push_back(existing_cache); + std::cout << "input_smem_tvs: " << existing_cache->toString() + << std::endl; group2_and_cached_inputs.emplace(existing_cache); - smem_tvs.push_back(existing_cache); } } } // set cached outputs of group 2 to shared memory + TensorView* output_reference = nullptr; + TensorView* output_reg_cache = nullptr; for (const auto& [cached_output, output_idx] : cached_outputs) { auto output = fusion->outputs()[output_idx]->as(); - if (group2_and_cached_inputs.count(output) > 0) { + output_reference = output; + if (true || group2_and_cached_inputs.count(output) > 0) { + output->definition()->as()->setOpType( + LoadStoreOpType::CpAsyncBulkTensorTile); cached_output->setMemoryType(MemoryType::Shared); - smem_tvs.push_back(cached_output); + output_smem_tvs.push_back(cached_output); + std::cout << "output_smem_tvs: " << cached_output->toString() + << std::endl; + output_reg_cache = cached_output->cacheBefore(); } } - // Configure TMA loads for the gmem->smem cached-input TVs. - // Important: TMA needs the two tiling dimensions preserved as separate Bulk - // axes (i.e., do not merge tile1*tile2 like the normal group2 compute path). - if (tparams->use_tma_load) { - for (auto smem_tv : smem_tvs) { - auto ldst = dynamic_cast(smem_tv->definition()); - if (ldst == nullptr) { + TensorView* reference1 = + domain_map.findReferenceFor(grouped_inputs_outputs[0]); + TensorView* reference2 = + domain_map.findReferenceFor(grouped_inputs_outputs[1]); + + std::cout << "reference1: " << reference1->toString() << std::endl; + std::cout << "reference2: " << reference2->toString() << std::endl; + fusion->print(); + + // scheduler reference1 + output_reference->split(1, 32); + output_reference->split(0, 32); + output_reference->reorder({{-2, 0}}); + output_reference->merge(0); + // [I0/32 * I1/32', 32', 32] + output_reference->axis(0)->parallelize(ParallelType::BIDx); + using Options = + scheduler_utils::BoundedDirectionalTransformPropagator::Options; + scheduler_utils::BoundedDirectionalTransformPropagator::backward( + output_reference, + -1, + {input_reference}, + Options{}.propagateParallelType()); + // For fusion output, we just use TMA to store the entire tile back to global + // memory. There is no need to further schedule the output tensor. + output_reference->axis(1)->parallelize(ParallelType::Bulk); + output_reference->axis(2)->parallelize(ParallelType::Bulk); + + for (auto output_smem_cache : output_smem_tvs) { + // [BIDx, 32', 32] + output_smem_cache->setAllocationDomain( + output_smem_cache->getLoopDomain(), true); + output_smem_cache->split(1, 4); + // [BIDx, 8', 4', 32] + scheduler_utils::BoundedDirectionalTransformPropagator::backward( + output_smem_cache, -1, {input_reference}); + output_smem_cache->merge(1, 3); + // [BIDx, 256, 4'] + output_smem_cache->axis(1)->parallelize(ParallelType::TIDx); + scheduler_utils::BoundedDirectionalTransformPropagator::backward( + output_smem_cache, + -1, + {input_smem_tvs.at(0)}, + Options{}.propagateParallelType()); + output_smem_cache->axis(2)->parallelize(ParallelType::Unroll); + output_reg_cache->axis(2)->parallelize(ParallelType::Vectorize); + output_reg_cache->setAllocationDomain( + output_reg_cache->getLoopDomain(), true); + } + for (auto input_smem_cache : input_smem_tvs) { + // Schedule the memory format for 128 byte swizzle + // [BIDx, 8', 4', 32] + input_smem_cache->reorder({{-1, 1}}); + // [BIDx, 32, 8', 4'] + input_smem_cache->split(1, 8); + // [BIDx, 4, 8, 8', 4'] + input_smem_cache->swizzle(SwizzleType::XOR, 2, 3); + // [BIDx, 4, 8, 8', 4'] + input_smem_cache->setAllocationDomain( + input_smem_cache->getLoopDomain(), true); + input_smem_cache->axis(1)->parallelize(ParallelType::Bulk); + input_smem_cache->axis(2)->parallelize(ParallelType::Bulk); + input_smem_cache->axis(3)->parallelize(ParallelType::Bulk); + input_smem_cache->axis(4)->parallelize(ParallelType::Bulk); + // [BIDx, Bulk, Bulk, Bulk, Bulk] + } + + fusion->print(); +} + +void scheduleTranspose(Fusion* fusion, const TransposeParams* tparams) { + if (tparams->use_tma_store || tparams->use_tma_load) { + return scheduleTransposeTMA(fusion, tparams); + } + FusionGuard fg(fusion); + + // Make sure we don't have global memory set on intermediate tensors from + // fusion segmentation + scheduler_utils::clearMemorySpace(fusion); + + // maybe has_reduction for scheduling should be done on a per output tensor + // basis. + NVF_ERROR( + !ir_utils::hasAnyReductionOps(fusion), + "This scheduler only handles pointwise ops."); + + // Cache inputs + auto cached_inputs = scheduler_utils::cacheInputs(fusion, true); + + // Cache and fork outputs + auto cached_outputs = scheduler_utils::cacheAndForkOutputs(fusion, true); + + scheduler_utils::prepareForMemoryTypePromotion(fusion); + + std::vector input_tvs; + { + auto filtered_tvs = ir_utils::filterByType(fusion->inputs()); + // Remove hanging tensor views + for (auto tv : filtered_tvs) { + if (tv->uses().empty()) { continue; } - // Only enable for gmem->smem loads - auto in_tv = dynamic_cast(ldst->in()); - if (in_tv == nullptr || !in_tv->isFusionInput()) { - continue; + input_tvs.push_back(tv); + } + } + auto output_tvs = ir_utils::filterByType(fusion->outputs()); + + int64_t max_dims = 0; + for (auto inp : input_tvs) { + max_dims = std::max(scheduler_utils::nLogicalDims(inp), max_dims); + } + + for (auto out : output_tvs) { + max_dims = std::max(scheduler_utils::nLogicalDims(out), max_dims); + } + + // If everything is zero dim tensors, just return. + if (max_dims == 0) { + return; + } + + scheduler_tools::TransposeDomainMap domain_map(fusion); + auto grouped_inputs_outputs = domain_map.groupInputsOutputsByInnerDim(); + NVF_ERROR(grouped_inputs_outputs.size() >= 2); + + /* + * We need something similar to `cacheFork` for input tensors in group 2. We + * need this because we will want to propagate to the entire DAG except group + * 2 and its cached inputs, so we need to make sure the DAG is still connected + * if we remove group and its cached inputs. For example + * t0 + * | + * cache + * / \ + * t1 t2 + * if groups = {{t1, t2}, {t0}}, then removing {t0, cache} from the DAG will + * make it disconnected. + */ + std::unordered_set group2_and_cached_inputs( + grouped_inputs_outputs[1].begin(), grouped_inputs_outputs[1].end()); + for (auto tv : grouped_inputs_outputs[1]) { + if (tv->isFusionInput()) { + auto existing_cache = ir_utils::consumerTvsOf(tv)[0]; + if (ir_utils::consumerTvsOf(existing_cache).size() > 1) { + auto new_cache = tv->cacheAfter(); + new_cache->setMemoryType(MemoryType::Shared); + group2_and_cached_inputs.emplace(new_cache); + } else { + existing_cache->setMemoryType(MemoryType::Shared); + group2_and_cached_inputs.emplace(existing_cache); } - ldst->setOpType(LoadStoreOpType::CpAsyncBulkTensorTile); - tma_smem_load_tvs.push_back(smem_tv); + } + } + // set cached outputs of group 2 to shared memory + for (const auto& [cached_output, output_idx] : cached_outputs) { + auto output = fusion->outputs()[output_idx]->as(); + if (group2_and_cached_inputs.count(output) > 0) { + cached_output->setMemoryType(MemoryType::Shared); } } @@ -991,9 +1135,6 @@ void scheduleTranspose(Fusion* fusion, const TransposeParams* tparams) { auto inner_most_id1 = scheduler_utils::innerMostAllocDim(reference1); auto inner_most_id2 = scheduler_utils::innerMostAllocDim(reference2); - std::cout << "ref1 " << reference1->toString() << std::endl; - std::cout << "ref2 " << reference2->toString() << std::endl; - ////////////////////////////////////////// // Step 1: Make virtual inner most dims // ////////////////////////////////////////// @@ -1071,23 +1212,18 @@ void scheduleTranspose(Fusion* fusion, const TransposeParams* tparams) { rhs_i = lhs_i; } - // reference1->split(rhs_i, 1); - // // [r.., merged_dim, 1, tile1, tile2] + reference1->split(rhs_i, 1); + // [r.., merged_dim, 1, tile1, tile2] - // // parallelize non-tile dimensions - // reference1->axis(rhs_i + 1)->parallelize(ParallelType::Unswitch); + // parallelize non-tile dimensions + reference1->axis(rhs_i + 1)->parallelize(ParallelType::Unswitch); reference1->axis(rhs_i)->parallelize(ParallelType::BIDx); // [r.., BIDx, Unswitch, tile1, tile2] - // Propagate transformations so far (including outer-dim merges and BIDx / - // Unswitch) to the entire DAG. This ensures TMA staging TVs have the same - // untiled-domain schedule as other TVs. TMA staging TVs are still excluded - // later from the group2 compute scheduling that merges tile1*tile2. - { - TransformPropagator propagator(reference1); - MaxLogicalDomainInfoSpanningTree entire_dag(reference1); - entire_dag.traverse(&propagator); - } + // Propagate transformations so far to the entire DAG + TransformPropagator propagator(reference1); + MaxLogicalDomainInfoSpanningTree entire_dag(reference1); + entire_dag.traverse(&propagator); // We may be propagating a reshape during the above transformation. // T0[i0 * i1] -> View -> T1[i0 i1] (Root=[i0*i1]) @@ -1111,40 +1247,6 @@ void scheduleTranspose(Fusion* fusion, const TransposeParams* tparams) { /*propagate_padding=*/true, /*parallelize_inputs_on_did=*/true); - // Apply TMA-specific 2D Bulk parallelization for the cached-input smem TVs. - // Keep it limited to the tile dims only, as the rest of the schedule is - // shared with non-TMA tensors. - if (!tma_smem_load_tvs.empty()) { - // Identify the two tiling axes by exact-ID mapping against reference1's - // tiling axes. This is robust even when tile_size1 == tile_size2. - ComputeAtMap ca_map(fusion); - IterDomain* ref_tile1 = reference1->axis(-2); - IterDomain* ref_tile2 = reference1->axis(-1); - auto bulkParallelizeTileDims = [&](TensorView* tv) { - IterDomain* tile1_id = nullptr; - IterDomain* tile2_id = nullptr; - for (auto id : tv->getLoopDomain()) { - if (tile1_id == nullptr && - ca_map.areMapped(id, ref_tile1, IdMappingMode::EXACT)) { - tile1_id = id; - } - if (tile2_id == nullptr && - ca_map.areMapped(id, ref_tile2, IdMappingMode::EXACT)) { - tile2_id = id; - } - } - NVF_ERROR( - tile1_id != nullptr && tile2_id != nullptr, - "Failed to identify tiling axes for TMA smem TV: ", - tv->toString()); - tile1_id->parallelize(ParallelType::Bulk); - tile2_id->parallelize(ParallelType::Bulk); - }; - for (auto tv : tma_smem_load_tvs) { - bulkParallelizeTileDims(tv); - } - } - // For a transpose scheduling, all we need is to bind threadIdx.x differently // for inputs and outputs. This swap of binding could happen at any tensor on // the path from input to output, especially, it does not have to be in the @@ -1218,22 +1320,6 @@ void scheduleTranspose(Fusion* fusion, const TransposeParams* tparams) { auto all_tvs_except1 = ir_utils::allTvsExcept( fusion, {grouped_inputs_outputs[0].begin(), grouped_inputs_outputs[0].end()}); - // Keep TMA smem staging tensors out of the group2 compute scheduling, as it - // merges tile1*tile2 and would destroy the 2D Bulk structure required by - // CpAsyncBulkTensorTile. - if (!tma_smem_load_tvs.empty()) { - all_tvs_except1.erase( - std::remove_if( - all_tvs_except1.begin(), - all_tvs_except1.end(), - [&](TensorView* tv) { - return std::find( - tma_smem_load_tvs.begin(), - tma_smem_load_tvs.end(), - tv) != tma_smem_load_tvs.end(); - }), - all_tvs_except1.end()); - } SetSelector selector({all_tvs_except1.begin(), all_tvs_except1.end()}); MaxLogicalDomainInfoSpanningTree entire_dag_except1(reference2, &selector); TransformPropagator propagator(reference2); @@ -1253,30 +1339,16 @@ void scheduleTranspose(Fusion* fusion, const TransposeParams* tparams) { ComputeAtMap ca_map(fusion); - // Exclude TMA staging tensors; they keep Bulk-parallel tile axes. - std::vector group2_sched_tvs( - group2_and_cached_inputs.begin(), group2_and_cached_inputs.end()); - if (!tma_smem_load_tvs.empty()) { - group2_sched_tvs.erase( - std::remove_if( - group2_sched_tvs.begin(), - group2_sched_tvs.end(), - [&](TensorView* tv) { - return std::find( - tma_smem_load_tvs.begin(), - tma_smem_load_tvs.end(), - tv) != tma_smem_load_tvs.end(); - }), - group2_sched_tvs.end()); - } scheduler_utils::parallelizeAllLike( - reference2, group2_sched_tvs, {ParallelType::TIDx}); + reference2, + {group2_and_cached_inputs.begin(), group2_and_cached_inputs.end()}, + {ParallelType::TIDx}); // Only vectorize the axes that exactly maps to the vectorized axes // on reference as support for permissively mapped axes are not // yet clearly defined. std::vector vectorized_group2_cached_inputs; - for (auto gin : group2_sched_tvs) { + for (auto gin : group2_and_cached_inputs) { if (std::any_of( gin->getLoopDomain().begin(), gin->getLoopDomain().end(), @@ -1298,7 +1370,7 @@ void scheduleTranspose(Fusion* fusion, const TransposeParams* tparams) { // on reference as support for permissively mapped axes are not // yet clearly defined. std::vector unrolled_group2_cached_inputs; - for (auto gin : group2_sched_tvs) { + for (auto gin : group2_and_cached_inputs) { if (std::any_of( gin->getLoopDomain().begin(), gin->getLoopDomain().end(), diff --git a/csrc/scheduler/transpose_heuristic.h b/csrc/scheduler/transpose_heuristic.h index bbc45b9757f..29fc49d62cf 100644 --- a/csrc/scheduler/transpose_heuristic.h +++ b/csrc/scheduler/transpose_heuristic.h @@ -56,6 +56,11 @@ class TransposeParams : public HeuristicParams { // the shared-memory staging tensors used for the thread-binding swap. bool use_tma_load = false; + // Use Hopper+ TMA (cp.async.bulk.tensor) for group1 cached-output stores + // (shared -> global). When enabled with use_tma_load, implements the full + // bank-conflict-free transpose pattern with 128-byte swizzle. + bool use_tma_store = false; + using HeuristicParams::HeuristicParams; // Warning: Does not check launch parameters! @@ -71,7 +76,8 @@ class TransposeParams : public HeuristicParams { other->vectorize_factor1 == vectorize_factor1 && other->vectorize_factor2 == vectorize_factor2 && other->tile_size1 == tile_size1 && other->tile_size2 == tile_size2 && - other->use_tma_load == use_tma_load; + other->use_tma_load == use_tma_load && + other->use_tma_store == use_tma_store; return attr_equal; } @@ -83,6 +89,7 @@ class TransposeParams : public HeuristicParams { ss << " input tile size: " << tile_size1 << "\n"; ss << " output tile size: " << tile_size2 << "\n"; ss << " use_tma_load: " << use_tma_load << "\n"; + ss << " use_tma_store: " << use_tma_store << "\n"; int64_t elements_per_tile = tile_size1 * tile_size2; ss << " elements per tile: " << elements_per_tile << "\n"; int64_t elements_per_thread = elements_per_tile / lparams.bdimx(); @@ -154,7 +161,8 @@ class TransposeParams : public HeuristicParams { vectorize_factor2, tile_size1, tile_size2, - use_tma_load); + use_tma_load, + use_tma_store); } std::unique_ptr clone() const override { diff --git a/tests/cpp/test_transpose.cpp b/tests/cpp/test_transpose.cpp index d90deb83f81..7abaa3c8a6f 100644 --- a/tests/cpp/test_transpose.cpp +++ b/tests/cpp/test_transpose.cpp @@ -1431,4 +1431,101 @@ TEST_F(TransposeTest, Tma) { auto outputs = executor_cache.runFusionWithInputs({input0}); testValidate(executor_cache.fusion(), outputs, {input0}, __LINE__, __FILE__); } + +TEST_F(TransposeTest, TmaTranspose) { + auto fusion_ptr = std::make_unique(); + FusionGuard fg(fusion_ptr.get()); + Fusion& fusion = *fusion_ptr; + + auto dtype = DataType::Float; + auto tv0 = makeContigConcreteTensor({262144, 5120}, dtype); + fusion.addInput(tv0); + auto tv1 = transpose(tv0, 0, 1); + fusion.addOutput(tv1); + + auto options = + at::TensorOptions().dtype(data_type_to_aten(dtype)).device(at::kCUDA, 0); + at::Tensor input0 = at::randn({262144, 5120}, options); + + bool auto_schedule = false; + if (std::getenv("USE_AUTO") != nullptr) { + auto_schedule = true; + } + if (auto_schedule) { + // H100, non-tma, 82%, tma 71% + // non-tma, GB200, 62%, 2.18 ms, vect = 4, unroll = 2 + // tma, load and store, 71% H100 + FusionExecutorCache executor_cache(std::move(fusion_ptr)); + auto outputs = executor_cache.runFusionWithInputs({input0}); + testValidate( + executor_cache.fusion(), outputs, {input0}, __LINE__, __FILE__); + } else { + auto tv0_smem = tv0->cacheAfter(LoadStoreOpType::CpAsyncBulkTensorTile); + auto tv1_regs = tv1->cacheBefore(); + tv0_smem->setMemoryType(MemoryType::Shared); + int64_t t0 = 32, t1 = 32; + int64_t unroll = 2, vectorization = 4; + int64_t bdimy = t0 / unroll, bdimx = t1 / vectorization; + NVF_ERROR(bdimx % 32 == 0); + NVF_ERROR(bdimy * unroll == t0); + // [i0, i1] + tv0_smem->split(1, t1); + tv0_smem->split(0, t0); + // [i0/t0, t0, i1/t1, t1] -> [i0/t0 * i1/t1, t0, t1] + tv0_smem->reorder({{1, 2}}); + tv0_smem->merge(0); + tv0_smem->axis(0)->parallelize(ParallelType::BIDx); + // Schedule the memory format for 128 byte swizzle + // [BIDx, t0, t1] + tv0_smem->split(1, 8); + // [BIDx, 4, 8, 8', 4'] + tv0_smem->swizzle(SwizzleType::XOR, 2, 3); + // [BIDx, 4, 8, 8', 4'] + tv0_smem->setAllocationDomain(tv0_smem->getLoopDomain(), true); + tv0_smem->axis(1)->parallelize(ParallelType::Bulk); + tv0_smem->axis(2)->parallelize(ParallelType::Bulk); + tv0_smem->axis(3)->parallelize(ParallelType::Bulk); + tv0_smem->axis(4)->parallelize(ParallelType::Bulk); + // Apply TMA swizzle to avoid bank conflicts on transpose + // The XOR swizzle makes BOTH row-wise and column-wise access + // bank-conflict-free So we don't need ldmatrix/stmatrix for a simple + // transpose! + MmaInputSmemSwizzle swizzle_type = + mma_utils::tmaSwizzleSharedMemory(tv0_smem); + tv0_smem->applyMmaSwizzleForTMALoad(swizzle_type); + + // propagate to other tensors + TransformPropagator propagator(tv0_smem); + MaxLogicalDomainInfoSpanningTree(tv0_smem).traverse(&propagator); + + // schedule tv0_smem + tv0_smem->axis(0)->parallelize(ParallelType::BIDy); + tv0_smem->axis(1)->parallelize(ParallelType::BIDx); + scheduler_utils::parallelizeAllLike(tv0_smem); + + // bulk parallelize tv0 + tv0_smem->axis(2)->parallelize(ParallelType::Bulk); + tv0_smem->axis(3)->parallelize(ParallelType::Bulk); + + // tidx and tidy + // With swizzled memory, these accesses are now bank-conflict-free! + for (auto tv : {tv1_regs, tv1}) { + // [i0/t0, i1/t1, t0, t1] -> [i0/t0, i1/t1, u, t0/u, t1/v, v] + tv->split(3, vectorization); + tv->split(2, unroll, false); + tv->axis(2)->parallelize(ParallelType::Unroll); + tv->axis(3)->parallelize(ParallelType::TIDy); + tv->axis(4)->parallelize(ParallelType::TIDx); + if (tv == tv1) { + tv->axis(5)->parallelize(ParallelType::Vectorize); + } + } + inlineMost(); + fusion.print(); + KernelExecutor ke; + ke.compile(&fusion, {input0}); + auto outputs = ke.run({input0}); + testValidate(&fusion, outputs, {input0}, __LINE__, __FILE__); + } +} } // namespace nvfuser diff --git a/tests/cpp/test_tutorial.cpp b/tests/cpp/test_tutorial.cpp index 1acb7209f91..4d3e5fd5f76 100644 --- a/tests/cpp/test_tutorial.cpp +++ b/tests/cpp/test_tutorial.cpp @@ -1497,6 +1497,7 @@ TEST_F(Tutorial, TMABankConflictFreeTranspose) { output->axis(1)->parallelize(ParallelType::Bulk); output->axis(2)->parallelize(ParallelType::Bulk); // [BIDx, Bulk, Bulk] + fusion.print(); // output_smem_cache and output_reg_cache are scheduled in the same way. // We use each warp to load one column of input_smem_cache. We vectorize From 411654f3e12468820c9f8cc9708eaa4511bea64d Mon Sep 17 00:00:00 2001 From: Liqiang Lu Date: Fri, 30 Jan 2026 08:31:21 -0800 Subject: [PATCH 03/11] 128 --- csrc/scheduler/transpose.cpp | 75 +++++++++++++++++++++++++++--------- tests/cpp/test_transpose.cpp | 3 +- 2 files changed, 57 insertions(+), 21 deletions(-) diff --git a/csrc/scheduler/transpose.cpp b/csrc/scheduler/transpose.cpp index c972e219ec7..3d4ec0b06ed 100644 --- a/csrc/scheduler/transpose.cpp +++ b/csrc/scheduler/transpose.cpp @@ -963,14 +963,17 @@ void scheduleTransposeTMA(Fusion* fusion, const TransposeParams* tparams) { std::cout << "reference1: " << reference1->toString() << std::endl; std::cout << "reference2: " << reference2->toString() << std::endl; - fusion->print(); - // scheduler reference1 - output_reference->split(1, 32); - output_reference->split(0, 32); + // output: [I0, I1] -> [I0/tile_0 * I1/tile_1, tile_0, tile_1] + // smem -> register is vectorized + // register -> smem is unrolled + int64_t tile_0 = 32, tile_1 = 32, n_warps = 4, unroll_vect = 4; + // int64_t n_serial_loop = tile_0 / unroll_vect / n_warps; + output_reference->split(1, tile_1); + output_reference->split(0, tile_0); output_reference->reorder({{-2, 0}}); output_reference->merge(0); - // [I0/32 * I1/32', 32', 32] + // [I0/tile_0 * I1/tile_1, tile_0, tile_1] output_reference->axis(0)->parallelize(ParallelType::BIDx); using Options = scheduler_utils::BoundedDirectionalTransformPropagator::Options; @@ -984,38 +987,71 @@ void scheduleTransposeTMA(Fusion* fusion, const TransposeParams* tparams) { output_reference->axis(1)->parallelize(ParallelType::Bulk); output_reference->axis(2)->parallelize(ParallelType::Bulk); + fusion->printMath(); + for (auto output_smem_cache : output_smem_tvs) { // [BIDx, 32', 32] output_smem_cache->setAllocationDomain( output_smem_cache->getLoopDomain(), true); - output_smem_cache->split(1, 4); - // [BIDx, 8', 4', 32] + // [BDIx, tile_0, tile_1] -> [BDIx, tile_0/unroll_vect, unroll_vect, tile_1] + output_smem_cache->split(1, unroll_vect); + // [BDIx, tile_0/unroll_vect/n_warps, n_warps, unroll_vect, tile_1] + output_smem_cache->split(1, n_warps); + scheduler_utils::BoundedDirectionalTransformPropagator::backward( output_smem_cache, -1, {input_reference}); - output_smem_cache->merge(1, 3); - // [BIDx, 256, 4'] - output_smem_cache->axis(1)->parallelize(ParallelType::TIDx); + output_smem_cache->axis(2)->parallelize(ParallelType::TIDy); + output_smem_cache->axis(4)->parallelize(ParallelType::TIDx); scheduler_utils::BoundedDirectionalTransformPropagator::backward( output_smem_cache, -1, {input_smem_tvs.at(0)}, Options{}.propagateParallelType()); - output_smem_cache->axis(2)->parallelize(ParallelType::Unroll); - output_reg_cache->axis(2)->parallelize(ParallelType::Vectorize); + output_smem_cache->axis(3)->parallelize(ParallelType::Unroll); + output_reg_cache->axis(3)->parallelize(ParallelType::Vectorize); output_reg_cache->setAllocationDomain( output_reg_cache->getLoopDomain(), true); } for (auto input_smem_cache : input_smem_tvs) { // Schedule the memory format for 128 byte swizzle - // [BIDx, 8', 4', 32] + // After backward propagation and reorder: + // [BIDx, tile_1, tile_0/unroll_vect/n_warps, n_warps, unroll_vect] + // = [BIDx, 32, 2, 4, 4] input_smem_cache->reorder({{-1, 1}}); - // [BIDx, 32, 8', 4'] - input_smem_cache->split(1, 8); - // [BIDx, 4, 8, 8', 4'] - input_smem_cache->swizzle(SwizzleType::XOR, 2, 3); - // [BIDx, 4, 8, 8', 4'] - input_smem_cache->setAllocationDomain( + + :cout << "input_smem_cache after reorder: " << + nput_smem_cache->toString() << std::endl; + + tile_1(32) by 8 to create first dimension of size 8 + // [BIDx, 32, 2, 4, 4] -> [BIDx, 4, 8, 2, 4, 4] + input_smem_cache->split(1, 8); + + // + + 2×4 to create second dimension of size 8 + // [BIDx, 4, 8, 2, 4, 4] -> [BIDx, 4, 8, 8, 4] + input_smem_cache->merge(3, 4); + + std::c + + put_smem_cache before swizzle : " << input_smem_c + che->toString() + << std::endl; + + // Swizzle + + 's at positions 2 and 3 + // [BIDx, 4, 8, 8, 4] with XOR swizzle on dimensions 2 and 3 + input_smem_cache->swizzle(SwizzleType::XOR, 2, 3); + + // Set allocat + + to match the swizzled layout input_smem_cache->setAllocationDomain( input_smem_cache->getLoopDomain(), true); + + // Parallelize all + + s as Bulk for TMA input_smem_cache->axis(1)->parallelize(ParallelType::Bulk); input_smem_cache->axis(2)->parallelize(ParallelType::Bulk); input_smem_cache->axis(3)->parallelize(ParallelType::Bulk); @@ -1537,3 +1573,4 @@ void TransposeScheduler::schedule( scheduleTranspose(fusion, tparams); } } // namespace nvfuser + diff --git a/tests/cpp/test_transpose.cpp b/tests/cpp/test_transpose.cpp index 7abaa3c8a6f..9d1a251e058 100644 --- a/tests/cpp/test_transpose.cpp +++ b/tests/cpp/test_transpose.cpp @@ -1452,9 +1452,8 @@ TEST_F(TransposeTest, TmaTranspose) { auto_schedule = true; } if (auto_schedule) { - // H100, non-tma, 82%, tma 71% + // H100, non-tma, 82%, tma-256, 71%, tma-128, 78% // non-tma, GB200, 62%, 2.18 ms, vect = 4, unroll = 2 - // tma, load and store, 71% H100 FusionExecutorCache executor_cache(std::move(fusion_ptr)); auto outputs = executor_cache.runFusionWithInputs({input0}); testValidate( From e1c90ccf024b4852541f380ffeac8471416de060 Mon Sep 17 00:00:00 2001 From: Liqiang Lu Date: Fri, 30 Jan 2026 08:33:13 -0800 Subject: [PATCH 04/11] 128 --- csrc/scheduler/transpose.cpp | 43 ++++++++++++++---------------------- 1 file changed, 16 insertions(+), 27 deletions(-) diff --git a/csrc/scheduler/transpose.cpp b/csrc/scheduler/transpose.cpp index 3d4ec0b06ed..2f7f7f297ac 100644 --- a/csrc/scheduler/transpose.cpp +++ b/csrc/scheduler/transpose.cpp @@ -1019,39 +1019,29 @@ void scheduleTransposeTMA(Fusion* fusion, const TransposeParams* tparams) { // = [BIDx, 32, 2, 4, 4] input_smem_cache->reorder({{-1, 1}}); - :cout << "input_smem_cache after reorder: " << - nput_smem_cache->toString() << std::endl; + std::cout << "input_smem_cache after reorder: " + << input_smem_cache->toString() << std::endl; - tile_1(32) by 8 to create first dimension of size 8 - // [BIDx, 32, 2, 4, 4] -> [BIDx, 4, 8, 2, 4, 4] - input_smem_cache->split(1, 8); + // Split tile_1 (32) by 8 to create first dimension of size 8 + // [BIDx, 32, 2, 4, 4] -> [BIDx, 4, 8, 2, 4, 4] + input_smem_cache->split(1, 8); - // - - 2×4 to create second dimension of size 8 - // [BIDx, 4, 8, 2, 4, 4] -> [BIDx, 4, 8, 8, 4] - input_smem_cache->merge(3, 4); - - std::c + // Merge the 2×4 to create second dimension of size 8 + // [BIDx, 4, 8, 2, 4, 4] -> [BIDx, 4, 8, 8, 4] + input_smem_cache->merge(3, 4); - put_smem_cache before swizzle : " << input_smem_c - che->toString() - << std::endl; + std::cout << "input_smem_cache before swizzle: " + << input_smem_cache->toString() << std::endl; - // Swizzle + // Swizzle the two 8's at positions 2 and 3 + // [BIDx, 4, 8, 8, 4] with XOR swizzle on dimensions 2 and 3 + input_smem_cache->swizzle(SwizzleType::XOR, 2, 3); - 's at positions 2 and 3 - // [BIDx, 4, 8, 8, 4] with XOR swizzle on dimensions 2 and 3 - input_smem_cache->swizzle(SwizzleType::XOR, 2, 3); - - // Set allocat - - to match the swizzled layout input_smem_cache->setAllocationDomain( + // Set allocation domain to match the swizzled layout + input_smem_cache->setAllocationDomain( input_smem_cache->getLoopDomain(), true); - // Parallelize all - - s as Bulk for TMA + // Parallelize all dimensions as Bulk for TMA input_smem_cache->axis(1)->parallelize(ParallelType::Bulk); input_smem_cache->axis(2)->parallelize(ParallelType::Bulk); input_smem_cache->axis(3)->parallelize(ParallelType::Bulk); @@ -1573,4 +1563,3 @@ void TransposeScheduler::schedule( scheduleTranspose(fusion, tparams); } } // namespace nvfuser - From 8f1d8f7382b10f2ecd0c2d27a19d311840328e92 Mon Sep 17 00:00:00 2001 From: Liqiang Lu Date: Fri, 30 Jan 2026 11:20:00 -0800 Subject: [PATCH 05/11] multiple tma loads per cta --- csrc/scheduler/transpose.cpp | 47 +++++++++++++++++++++--------------- tests/cpp/test_transpose.cpp | 2 +- 2 files changed, 28 insertions(+), 21 deletions(-) diff --git a/csrc/scheduler/transpose.cpp b/csrc/scheduler/transpose.cpp index 2f7f7f297ac..e4557c3fdfe 100644 --- a/csrc/scheduler/transpose.cpp +++ b/csrc/scheduler/transpose.cpp @@ -968,13 +968,17 @@ void scheduleTransposeTMA(Fusion* fusion, const TransposeParams* tparams) { // smem -> register is vectorized // register -> smem is unrolled int64_t tile_0 = 32, tile_1 = 32, n_warps = 4, unroll_vect = 4; + int64_t tiles_per_block = 4; // int64_t n_serial_loop = tile_0 / unroll_vect / n_warps; output_reference->split(1, tile_1); output_reference->split(0, tile_0); output_reference->reorder({{-2, 0}}); output_reference->merge(0); - // [I0/tile_0 * I1/tile_1, tile_0, tile_1] - output_reference->axis(0)->parallelize(ParallelType::BIDx); + // [I0/tile_0 * I1/tile_1/tiles_per_block, tiles_per_block, tile_0, tile_1] + output_reference->split(0, tiles_per_block); + int64_t pos_bdimx = 0, tile0_pos = 2; + output_reference->axis(pos_bdimx)->parallelize(ParallelType::BIDx); + // output_reference->axis(pos_bdimx + 1)->parallelize(ParallelType::Unroll); using Options = scheduler_utils::BoundedDirectionalTransformPropagator::Options; scheduler_utils::BoundedDirectionalTransformPropagator::backward( @@ -984,8 +988,8 @@ void scheduleTransposeTMA(Fusion* fusion, const TransposeParams* tparams) { Options{}.propagateParallelType()); // For fusion output, we just use TMA to store the entire tile back to global // memory. There is no need to further schedule the output tensor. - output_reference->axis(1)->parallelize(ParallelType::Bulk); - output_reference->axis(2)->parallelize(ParallelType::Bulk); + output_reference->axis(tile0_pos)->parallelize(ParallelType::Bulk); + output_reference->axis(tile0_pos + 1)->parallelize(ParallelType::Bulk); fusion->printMath(); @@ -994,21 +998,21 @@ void scheduleTransposeTMA(Fusion* fusion, const TransposeParams* tparams) { output_smem_cache->setAllocationDomain( output_smem_cache->getLoopDomain(), true); // [BDIx, tile_0, tile_1] -> [BDIx, tile_0/unroll_vect, unroll_vect, tile_1] - output_smem_cache->split(1, unroll_vect); + output_smem_cache->split(tile0_pos, unroll_vect); // [BDIx, tile_0/unroll_vect/n_warps, n_warps, unroll_vect, tile_1] - output_smem_cache->split(1, n_warps); + output_smem_cache->split(tile0_pos, n_warps); scheduler_utils::BoundedDirectionalTransformPropagator::backward( output_smem_cache, -1, {input_reference}); - output_smem_cache->axis(2)->parallelize(ParallelType::TIDy); - output_smem_cache->axis(4)->parallelize(ParallelType::TIDx); + output_smem_cache->axis(tile0_pos + 1)->parallelize(ParallelType::TIDy); + output_smem_cache->axis(tile0_pos + 3)->parallelize(ParallelType::TIDx); scheduler_utils::BoundedDirectionalTransformPropagator::backward( output_smem_cache, -1, {input_smem_tvs.at(0)}, Options{}.propagateParallelType()); - output_smem_cache->axis(3)->parallelize(ParallelType::Unroll); - output_reg_cache->axis(3)->parallelize(ParallelType::Vectorize); + output_smem_cache->axis(tile0_pos + 2)->parallelize(ParallelType::Unroll); + output_reg_cache->axis(tile0_pos + 2)->parallelize(ParallelType::Vectorize); output_reg_cache->setAllocationDomain( output_reg_cache->getLoopDomain(), true); } @@ -1017,38 +1021,41 @@ void scheduleTransposeTMA(Fusion* fusion, const TransposeParams* tparams) { // After backward propagation and reorder: // [BIDx, tile_1, tile_0/unroll_vect/n_warps, n_warps, unroll_vect] // = [BIDx, 32, 2, 4, 4] - input_smem_cache->reorder({{-1, 1}}); + int64_t tile1_pos = 2; + input_smem_cache->reorder({{-1, 2}}); std::cout << "input_smem_cache after reorder: " << input_smem_cache->toString() << std::endl; // Split tile_1 (32) by 8 to create first dimension of size 8 // [BIDx, 32, 2, 4, 4] -> [BIDx, 4, 8, 2, 4, 4] - input_smem_cache->split(1, 8); + input_smem_cache->split(tile1_pos, 8); // Merge the 2×4 to create second dimension of size 8 - // [BIDx, 4, 8, 2, 4, 4] -> [BIDx, 4, 8, 8, 4] - input_smem_cache->merge(3, 4); + // [BIDx, tilesPerCTA, 4, 8, 2, 4, 4] -> [BIDx, tilesPerCTA, 4, 8, 8, 4] + input_smem_cache->merge(tile1_pos + 2, tile1_pos + 3); std::cout << "input_smem_cache before swizzle: " << input_smem_cache->toString() << std::endl; // Swizzle the two 8's at positions 2 and 3 - // [BIDx, 4, 8, 8, 4] with XOR swizzle on dimensions 2 and 3 - input_smem_cache->swizzle(SwizzleType::XOR, 2, 3); + // [BIDx, tilesPerCTA, 4, 8, 8, 4] with XOR swizzle on dimensions 2 and 3 + input_smem_cache->swizzle(SwizzleType::XOR, tile1_pos + 1, tile1_pos + 2); // Set allocation domain to match the swizzled layout input_smem_cache->setAllocationDomain( input_smem_cache->getLoopDomain(), true); // Parallelize all dimensions as Bulk for TMA - input_smem_cache->axis(1)->parallelize(ParallelType::Bulk); - input_smem_cache->axis(2)->parallelize(ParallelType::Bulk); - input_smem_cache->axis(3)->parallelize(ParallelType::Bulk); - input_smem_cache->axis(4)->parallelize(ParallelType::Bulk); + input_smem_cache->axis(tile1_pos)->parallelize(ParallelType::Bulk); + input_smem_cache->axis(tile1_pos + 1)->parallelize(ParallelType::Bulk); + input_smem_cache->axis(tile1_pos + 2)->parallelize(ParallelType::Bulk); + input_smem_cache->axis(tile1_pos + 3)->parallelize(ParallelType::Bulk); // [BIDx, Bulk, Bulk, Bulk, Bulk] } + inlineMost(); + fusion->print(); } diff --git a/tests/cpp/test_transpose.cpp b/tests/cpp/test_transpose.cpp index 9d1a251e058..8f1028474ca 100644 --- a/tests/cpp/test_transpose.cpp +++ b/tests/cpp/test_transpose.cpp @@ -1452,7 +1452,7 @@ TEST_F(TransposeTest, TmaTranspose) { auto_schedule = true; } if (auto_schedule) { - // H100, non-tma, 82%, tma-256, 71%, tma-128, 78% + // H100, non-tma, 82%, tma-256, 71%, tma-128, 78%, tma-128-2tilesPerCTA, 80% // non-tma, GB200, 62%, 2.18 ms, vect = 4, unroll = 2 FusionExecutorCache executor_cache(std::move(fusion_ptr)); auto outputs = executor_cache.runFusionWithInputs({input0}); From 0d336d930deec902ed0b11a75909a0d73143179e Mon Sep 17 00:00:00 2001 From: Liqiang Lu Date: Sat, 31 Jan 2026 17:16:36 -0800 Subject: [PATCH 06/11] manual schedule --- tests/cpp/test_transpose.cpp | 106 +++++++++++++++++++++++++++++++++++ 1 file changed, 106 insertions(+) diff --git a/tests/cpp/test_transpose.cpp b/tests/cpp/test_transpose.cpp index 8f1028474ca..a0d24c784a5 100644 --- a/tests/cpp/test_transpose.cpp +++ b/tests/cpp/test_transpose.cpp @@ -1527,4 +1527,110 @@ TEST_F(TransposeTest, TmaTranspose) { testValidate(&fusion, outputs, {input0}, __LINE__, __FILE__); } } + +TEST_F(TransposeTest, TmaTransposeSimple) { + auto fusion_ptr = std::make_unique(); + FusionGuard fg(fusion_ptr.get()); + Fusion& fusion = *fusion_ptr; + + auto dtype = DataType::Float; + auto tv0 = makeContigConcreteTensor({262144, 5120}, dtype); + fusion.addInput(tv0); + auto tv1 = transpose(tv0, 0, 1); + fusion.addOutput(tv1); + + auto options = + at::TensorOptions().dtype(data_type_to_aten(dtype)).device(at::kCUDA, 0); + at::Tensor input0 = at::randn({262144, 5120}, options); + bool auto_schedule = false; + if (std::getenv("USE_AUTO") != nullptr) { + auto_schedule = true; + } + if (auto_schedule) { + // H100, non-tma, 82%, tma-256, 71%, tma-128, 78%, tma-128-2tilesPerCTA, 80% + // non-tma, GB200, 62%, 2.18 ms, vect = 4, unroll = 2 + FusionExecutorCache executor_cache(std::move(fusion_ptr)); + auto outputs = executor_cache.runFusionWithInputs({input0}); + testValidate( + executor_cache.fusion(), outputs, {input0}, __LINE__, __FILE__); + } else { + auto input_cache = tv0->cacheAfter(); + auto output_cache = tv1->cacheBefore(); + input_cache->setMemoryType(MemoryType::Shared); + + // global schedule + auto reference1 = tv1; + int64_t inner_most_pos1_in_ref1 = 1, inner_most_pos2_in_ref1 = 0; + int64_t tile_size1 = 32, tile_size2 = 32; + // [i1, i2] -> [i1, i2/tile1, tile1] + reference1->split(inner_most_pos1_in_ref1, tile_size1); + reference1->reorder({{inner_most_pos1_in_ref1 + 1, -1}}); + reference1->split(inner_most_pos2_in_ref1, tile_size2); + reference1->reorder({{inner_most_pos2_in_ref1 + 1, -1}}); + reference1->merge(0); + + int64_t rhs_i = 0; + reference1->split(rhs_i, 1); + // [r.., merged_dim, 1, tile1, tile2] + + // parallelize non-tile dimensions + reference1->axis(rhs_i + 1)->parallelize(ParallelType::Unswitch); + reference1->axis(rhs_i)->parallelize(ParallelType::BIDx); + // [r.., BIDx, Unswitch, tile1, tile2] + + // Propagate transformations so far to the entire DAG + TransformPropagator propagator(reference1); + MaxLogicalDomainInfoSpanningTree entire_dag(reference1); + entire_dag.traverse(&propagator); + + scheduler_utils::parallelizeAllLike( + reference1, + /*selected_tvs=*/{}, + /*selected_parallel_types=*/{}, + /*propagate_padding=*/true, + /*parallelize_inputs_on_did=*/true); + + std::cout << output_cache->toString() << std::endl; + ////////////////////////////// + // Step 3: Schedule group 2 // + ////////////////////////////// + // [BIDx, Unswitch, tile1, tile2] + int64_t pos = 2; + int64_t vectorize_factor = 4, threads_per_block = 128; + for (auto tv : {tv0, input_cache}) { + tv->merge(pos); + tv->split(pos, vectorize_factor); + tv->split(pos, threads_per_block); + // [BIDx, Unswitch, Unroll, TIDx, Vectorize] + tv->axis(2)->parallelize(ParallelType::Unroll); + tv->axis(3)->parallelize(ParallelType::TIDx); + if (tv == input_cache) { + tv->axis(4)->parallelize(ParallelType::Vectorize); + } + } + ////////////////////////////// + // Step 4: Schedule group 1 // + ////////////////////////////// + // [BIDx, Unswitch, tile1, tile2] + for (auto tv : {output_cache, tv1}) { + tv->reorder({{-2, -1}}); + // [..., tile2, tile1] + tv->merge(pos); + tv->split(pos, vectorize_factor); + tv->split(pos, threads_per_block); + tv->axis(2)->parallelize(ParallelType::Unroll); + tv->axis(3)->parallelize(ParallelType::TIDx); + if (tv == tv1) { + tv->axis(4)->parallelize(ParallelType::Vectorize); + } + } + inlineMost(); + fusion.print(); + KernelExecutor ke; + ke.compile(&fusion, {input0}); + auto outputs = ke.run({input0}); + testValidate(&fusion, outputs, {input0}, __LINE__, __FILE__); + } +} + } // namespace nvfuser From ce4e394ff60d0d499c2e18efab0bed2b2eed681d Mon Sep 17 00:00:00 2001 From: Liqiang Lu Date: Sat, 31 Jan 2026 17:54:01 -0800 Subject: [PATCH 07/11] clean --- tests/cpp/test_transpose.cpp | 16 +++++++--------- 1 file changed, 7 insertions(+), 9 deletions(-) diff --git a/tests/cpp/test_transpose.cpp b/tests/cpp/test_transpose.cpp index a0d24c784a5..4529da91fc1 100644 --- a/tests/cpp/test_transpose.cpp +++ b/tests/cpp/test_transpose.cpp @@ -1560,22 +1560,20 @@ TEST_F(TransposeTest, TmaTransposeSimple) { // global schedule auto reference1 = tv1; - int64_t inner_most_pos1_in_ref1 = 1, inner_most_pos2_in_ref1 = 0; int64_t tile_size1 = 32, tile_size2 = 32; // [i1, i2] -> [i1, i2/tile1, tile1] - reference1->split(inner_most_pos1_in_ref1, tile_size1); - reference1->reorder({{inner_most_pos1_in_ref1 + 1, -1}}); - reference1->split(inner_most_pos2_in_ref1, tile_size2); - reference1->reorder({{inner_most_pos2_in_ref1 + 1, -1}}); + reference1->split(1, tile_size1); + reference1->reorder({{2, -1}}); + reference1->split(0, tile_size2); + reference1->reorder({{1, -1}}); reference1->merge(0); - int64_t rhs_i = 0; - reference1->split(rhs_i, 1); + reference1->split(0, 1); // [r.., merged_dim, 1, tile1, tile2] // parallelize non-tile dimensions - reference1->axis(rhs_i + 1)->parallelize(ParallelType::Unswitch); - reference1->axis(rhs_i)->parallelize(ParallelType::BIDx); + reference1->axis(1)->parallelize(ParallelType::Unswitch); + reference1->axis(0)->parallelize(ParallelType::BIDx); // [r.., BIDx, Unswitch, tile1, tile2] // Propagate transformations so far to the entire DAG From 33055d23d028a4a783e1c760dec844b8b17d9597 Mon Sep 17 00:00:00 2001 From: Liqiang Lu Date: Sat, 31 Jan 2026 18:07:46 -0800 Subject: [PATCH 08/11] two options --- tests/cpp/test_transpose.cpp | 101 ++++++++++++++++++++++++----------- 1 file changed, 69 insertions(+), 32 deletions(-) diff --git a/tests/cpp/test_transpose.cpp b/tests/cpp/test_transpose.cpp index 4529da91fc1..ae8e6eb7457 100644 --- a/tests/cpp/test_transpose.cpp +++ b/tests/cpp/test_transpose.cpp @@ -1558,41 +1558,78 @@ TEST_F(TransposeTest, TmaTransposeSimple) { auto output_cache = tv1->cacheBefore(); input_cache->setMemoryType(MemoryType::Shared); - // global schedule - auto reference1 = tv1; - int64_t tile_size1 = 32, tile_size2 = 32; - // [i1, i2] -> [i1, i2/tile1, tile1] - reference1->split(1, tile_size1); - reference1->reorder({{2, -1}}); - reference1->split(0, tile_size2); - reference1->reorder({{1, -1}}); - reference1->merge(0); - - reference1->split(0, 1); - // [r.., merged_dim, 1, tile1, tile2] - - // parallelize non-tile dimensions - reference1->axis(1)->parallelize(ParallelType::Unswitch); - reference1->axis(0)->parallelize(ParallelType::BIDx); - // [r.., BIDx, Unswitch, tile1, tile2] - - // Propagate transformations so far to the entire DAG - TransformPropagator propagator(reference1); - MaxLogicalDomainInfoSpanningTree entire_dag(reference1); - entire_dag.traverse(&propagator); - - scheduler_utils::parallelizeAllLike( - reference1, - /*selected_tvs=*/{}, - /*selected_parallel_types=*/{}, - /*propagate_padding=*/true, - /*parallelize_inputs_on_did=*/true); - + int64_t tile_size1 = 32, tile_size2 = 16; + bool use_propagate_schedule = true; + if (std::getenv("USE_PROPAGATE") != nullptr) { + use_propagate_schedule = true; + } + if (std::getenv("USE_PER_TENSOR") != nullptr) { + use_propagate_schedule = false; + } + if (use_propagate_schedule) { + // Propagate-based schedule from a single reference. + auto reference1 = tv1; + // [i1, i2] -> [i1, i2/tile1, tile1] + reference1->split(1, tile_size1); + reference1->reorder({{2, -1}}); + reference1->split(0, tile_size2); + reference1->reorder({{1, -1}}); + reference1->merge(0); + + reference1->split(0, 1); + // [r.., merged_dim, 1, tile1, tile2] + + // parallelize non-tile dimensions + reference1->axis(1)->parallelize(ParallelType::Unswitch); + reference1->axis(0)->parallelize(ParallelType::BIDx); + // [r.., BIDx, Unswitch, tile1, tile2] + + // Propagate transformations so far to the entire DAG + TransformPropagator propagator(reference1); + MaxLogicalDomainInfoSpanningTree entire_dag(reference1); + entire_dag.traverse(&propagator); + + scheduler_utils::parallelizeAllLike( + reference1, + /*selected_tvs=*/{}, + /*selected_parallel_types=*/{}, + /*propagate_padding=*/true, + /*parallelize_inputs_on_did=*/true); + } else { + // Per-tensor schedule mirroring TransformPrinter order in 1.log. + // Group-2 tensors: input side uses [x, y] layout. + for (auto tv : {tv0, input_cache}) { + // [x, y] -> [x/32, 32, y/16, 16] + tv->split(1, tile_size2); + tv->split(0, tile_size1); + // [y/16, x/32, 32, 16] + tv->reorder({{2, 0}, {0, 1}, {1, 2}}); + // [y/16 * x/32, 32, 16] + tv->merge(0); + tv->split(0, 1); + tv->axis(1)->parallelize(ParallelType::Unswitch); + tv->axis(0)->parallelize(ParallelType::BIDx); + } + // Group-1 tensors: output side uses [y, x] layout. + for (auto tv : {output_cache, tv1}) { + // [y, x] -> [y/16, 16, x/32, 32] + tv->split(1, tile_size1); + tv->split(0, tile_size2); + // [y/16, x/32, 32, 16] + tv->reorder({{1, 3}, {3, 2}, {2, 1}}); + // [y/16 * x/32, 32, 16] + tv->merge(0); + tv->split(0, 1); + tv->axis(1)->parallelize(ParallelType::Unswitch); + tv->axis(0)->parallelize(ParallelType::BIDx); + } + } + // Print to verify the loop domain matches the expected transform order. std::cout << output_cache->toString() << std::endl; ////////////////////////////// // Step 3: Schedule group 2 // ////////////////////////////// - // [BIDx, Unswitch, tile1, tile2] + // Vectorize/Unroll for group 2. Only the shared cache is vectorized. int64_t pos = 2; int64_t vectorize_factor = 4, threads_per_block = 128; for (auto tv : {tv0, input_cache}) { @@ -1609,7 +1646,7 @@ TEST_F(TransposeTest, TmaTransposeSimple) { ////////////////////////////// // Step 4: Schedule group 1 // ////////////////////////////// - // [BIDx, Unswitch, tile1, tile2] + // Vectorize/Unroll for group 1. Only the output is vectorized. for (auto tv : {output_cache, tv1}) { tv->reorder({{-2, -1}}); // [..., tile2, tile1] From 903ab45b55febb3b0fbd7d53ab9376b32155b308 Mon Sep 17 00:00:00 2001 From: Liqiang Lu Date: Mon, 2 Feb 2026 08:28:38 -0800 Subject: [PATCH 09/11] add swizzle, still has conflict --- tests/cpp/test_transpose.cpp | 72 +++++++++++++++++++++--------------- 1 file changed, 42 insertions(+), 30 deletions(-) diff --git a/tests/cpp/test_transpose.cpp b/tests/cpp/test_transpose.cpp index ae8e6eb7457..2874f996fa1 100644 --- a/tests/cpp/test_transpose.cpp +++ b/tests/cpp/test_transpose.cpp @@ -1558,14 +1558,9 @@ TEST_F(TransposeTest, TmaTransposeSimple) { auto output_cache = tv1->cacheBefore(); input_cache->setMemoryType(MemoryType::Shared); - int64_t tile_size1 = 32, tile_size2 = 16; - bool use_propagate_schedule = true; - if (std::getenv("USE_PROPAGATE") != nullptr) { - use_propagate_schedule = true; - } - if (std::getenv("USE_PER_TENSOR") != nullptr) { - use_propagate_schedule = false; - } + int64_t tile_size1 = 32, tile_size2 = 32; + bool use_propagate_schedule = false; + if (use_propagate_schedule) { // Propagate-based schedule from a single reference. auto reference1 = tv1; @@ -1596,32 +1591,35 @@ TEST_F(TransposeTest, TmaTransposeSimple) { /*propagate_padding=*/true, /*parallelize_inputs_on_did=*/true); } else { - // Per-tensor schedule mirroring TransformPrinter order in 1.log. - // Group-2 tensors: input side uses [x, y] layout. - for (auto tv : {tv0, input_cache}) { - // [x, y] -> [x/32, 32, y/16, 16] - tv->split(1, tile_size2); - tv->split(0, tile_size1); - // [y/16, x/32, 32, 16] - tv->reorder({{2, 0}, {0, 1}, {1, 2}}); - // [y/16 * x/32, 32, 16] + // Group-1 tensors: output side uses [y, x] layout. + for (auto tv : {output_cache, tv1}) { + // [y, x] -> [y/tile_size2, tile_size2, x/tile_size1, tile_size1] + tv->split(1, tile_size1); + tv->split(0, tile_size2); + // [x/tile_size1, y/tile_size2, tile_size1, tile_size2] + tv->reorder({{0, 1}, {1, 3}, {2, 0}, {3, 2}}); + std::cout << tv->toString() << std::endl; + // [x/tile_size1 * y/tile_size2, tile_size1, tile_size2] tv->merge(0); tv->split(0, 1); tv->axis(1)->parallelize(ParallelType::Unswitch); tv->axis(0)->parallelize(ParallelType::BIDx); + std::cout << tv->toString() << std::endl; } - // Group-1 tensors: output side uses [y, x] layout. - for (auto tv : {output_cache, tv1}) { - // [y, x] -> [y/16, 16, x/32, 32] - tv->split(1, tile_size1); - tv->split(0, tile_size2); - // [y/16, x/32, 32, 16] - tv->reorder({{1, 3}, {3, 2}, {2, 1}}); - // [y/16 * x/32, 32, 16] + // Group-2 tensors: input side uses [x, y] layout. + for (auto tv : {tv0, input_cache}) { + // [x, y] -> [x/tile_size1, tile_size1, y/tile_size2, tile_size2] + tv->split(1, tile_size2); + tv->split(0, tile_size1); + // [x/tile_size1, y/tile_size2, tile_size1, tile_size2] + tv->reorder({{1, 2}, {2, 1}}); + std::cout << tv->toString() << std::endl; + // [x/tile_size1 * y/tile_size2, tile_size1, tile_size2] tv->merge(0); tv->split(0, 1); tv->axis(1)->parallelize(ParallelType::Unswitch); tv->axis(0)->parallelize(ParallelType::BIDx); + std::cout << tv->toString() << std::endl; } } // Print to verify the loop domain matches the expected transform order. @@ -1629,19 +1627,33 @@ TEST_F(TransposeTest, TmaTransposeSimple) { ////////////////////////////// // Step 3: Schedule group 2 // ////////////////////////////// - // Vectorize/Unroll for group 2. Only the shared cache is vectorized. int64_t pos = 2; int64_t vectorize_factor = 4, threads_per_block = 128; - for (auto tv : {tv0, input_cache}) { + // schedule input cache + // [BIDx, Unswitch, tile_size1, tile_size2] + input_cache->split(3, 4); + // [BIDx, Unswitch, tile_size1, tile_size2/4, 4] + input_cache->split(2, 8); + // [BIDx, Unswitch, tile_size1/8, 8, tile_size2/4, 4] + input_cache->swizzle(SwizzleType::XOR, 3, 4); + input_cache->merge(2); + input_cache->merge(2); + input_cache->split(2, threads_per_block); + // [BIDx, Unswitch, Unroll, TIDx, Vectorize] + input_cache->setAllocationDomain(input_cache->getLoopDomain(), true); + input_cache->axis(2)->parallelize(ParallelType::Unroll); + input_cache->axis(3)->parallelize(ParallelType::TIDx); + input_cache->axis(4)->parallelize(ParallelType::Vectorize); + // [BIDx, Bulk, Bulk, Bulk, Bulk] + // Vectorize/Unroll for group 2. Only the shared cache is vectorized. + + for (auto tv : {tv0}) { tv->merge(pos); tv->split(pos, vectorize_factor); tv->split(pos, threads_per_block); // [BIDx, Unswitch, Unroll, TIDx, Vectorize] tv->axis(2)->parallelize(ParallelType::Unroll); tv->axis(3)->parallelize(ParallelType::TIDx); - if (tv == input_cache) { - tv->axis(4)->parallelize(ParallelType::Vectorize); - } } ////////////////////////////// // Step 4: Schedule group 1 // From 9c28b62599046e8399aac5080f27c9af4b093443 Mon Sep 17 00:00:00 2001 From: Liqiang Lu Date: Mon, 2 Feb 2026 10:34:10 -0800 Subject: [PATCH 10/11] no bank conflicts --- tests/cpp/test_transpose.cpp | 108 +++-------------------------------- 1 file changed, 7 insertions(+), 101 deletions(-) diff --git a/tests/cpp/test_transpose.cpp b/tests/cpp/test_transpose.cpp index 2874f996fa1..218a7a1e27c 100644 --- a/tests/cpp/test_transpose.cpp +++ b/tests/cpp/test_transpose.cpp @@ -1432,102 +1432,6 @@ TEST_F(TransposeTest, Tma) { testValidate(executor_cache.fusion(), outputs, {input0}, __LINE__, __FILE__); } -TEST_F(TransposeTest, TmaTranspose) { - auto fusion_ptr = std::make_unique(); - FusionGuard fg(fusion_ptr.get()); - Fusion& fusion = *fusion_ptr; - - auto dtype = DataType::Float; - auto tv0 = makeContigConcreteTensor({262144, 5120}, dtype); - fusion.addInput(tv0); - auto tv1 = transpose(tv0, 0, 1); - fusion.addOutput(tv1); - - auto options = - at::TensorOptions().dtype(data_type_to_aten(dtype)).device(at::kCUDA, 0); - at::Tensor input0 = at::randn({262144, 5120}, options); - - bool auto_schedule = false; - if (std::getenv("USE_AUTO") != nullptr) { - auto_schedule = true; - } - if (auto_schedule) { - // H100, non-tma, 82%, tma-256, 71%, tma-128, 78%, tma-128-2tilesPerCTA, 80% - // non-tma, GB200, 62%, 2.18 ms, vect = 4, unroll = 2 - FusionExecutorCache executor_cache(std::move(fusion_ptr)); - auto outputs = executor_cache.runFusionWithInputs({input0}); - testValidate( - executor_cache.fusion(), outputs, {input0}, __LINE__, __FILE__); - } else { - auto tv0_smem = tv0->cacheAfter(LoadStoreOpType::CpAsyncBulkTensorTile); - auto tv1_regs = tv1->cacheBefore(); - tv0_smem->setMemoryType(MemoryType::Shared); - int64_t t0 = 32, t1 = 32; - int64_t unroll = 2, vectorization = 4; - int64_t bdimy = t0 / unroll, bdimx = t1 / vectorization; - NVF_ERROR(bdimx % 32 == 0); - NVF_ERROR(bdimy * unroll == t0); - // [i0, i1] - tv0_smem->split(1, t1); - tv0_smem->split(0, t0); - // [i0/t0, t0, i1/t1, t1] -> [i0/t0 * i1/t1, t0, t1] - tv0_smem->reorder({{1, 2}}); - tv0_smem->merge(0); - tv0_smem->axis(0)->parallelize(ParallelType::BIDx); - // Schedule the memory format for 128 byte swizzle - // [BIDx, t0, t1] - tv0_smem->split(1, 8); - // [BIDx, 4, 8, 8', 4'] - tv0_smem->swizzle(SwizzleType::XOR, 2, 3); - // [BIDx, 4, 8, 8', 4'] - tv0_smem->setAllocationDomain(tv0_smem->getLoopDomain(), true); - tv0_smem->axis(1)->parallelize(ParallelType::Bulk); - tv0_smem->axis(2)->parallelize(ParallelType::Bulk); - tv0_smem->axis(3)->parallelize(ParallelType::Bulk); - tv0_smem->axis(4)->parallelize(ParallelType::Bulk); - // Apply TMA swizzle to avoid bank conflicts on transpose - // The XOR swizzle makes BOTH row-wise and column-wise access - // bank-conflict-free So we don't need ldmatrix/stmatrix for a simple - // transpose! - MmaInputSmemSwizzle swizzle_type = - mma_utils::tmaSwizzleSharedMemory(tv0_smem); - tv0_smem->applyMmaSwizzleForTMALoad(swizzle_type); - - // propagate to other tensors - TransformPropagator propagator(tv0_smem); - MaxLogicalDomainInfoSpanningTree(tv0_smem).traverse(&propagator); - - // schedule tv0_smem - tv0_smem->axis(0)->parallelize(ParallelType::BIDy); - tv0_smem->axis(1)->parallelize(ParallelType::BIDx); - scheduler_utils::parallelizeAllLike(tv0_smem); - - // bulk parallelize tv0 - tv0_smem->axis(2)->parallelize(ParallelType::Bulk); - tv0_smem->axis(3)->parallelize(ParallelType::Bulk); - - // tidx and tidy - // With swizzled memory, these accesses are now bank-conflict-free! - for (auto tv : {tv1_regs, tv1}) { - // [i0/t0, i1/t1, t0, t1] -> [i0/t0, i1/t1, u, t0/u, t1/v, v] - tv->split(3, vectorization); - tv->split(2, unroll, false); - tv->axis(2)->parallelize(ParallelType::Unroll); - tv->axis(3)->parallelize(ParallelType::TIDy); - tv->axis(4)->parallelize(ParallelType::TIDx); - if (tv == tv1) { - tv->axis(5)->parallelize(ParallelType::Vectorize); - } - } - inlineMost(); - fusion.print(); - KernelExecutor ke; - ke.compile(&fusion, {input0}); - auto outputs = ke.run({input0}); - testValidate(&fusion, outputs, {input0}, __LINE__, __FILE__); - } -} - TEST_F(TransposeTest, TmaTransposeSimple) { auto fusion_ptr = std::make_unique(); FusionGuard fg(fusion_ptr.get()); @@ -1631,11 +1535,13 @@ TEST_F(TransposeTest, TmaTransposeSimple) { int64_t vectorize_factor = 4, threads_per_block = 128; // schedule input cache // [BIDx, Unswitch, tile_size1, tile_size2] - input_cache->split(3, 4); - // [BIDx, Unswitch, tile_size1, tile_size2/4, 4] - input_cache->split(2, 8); - // [BIDx, Unswitch, tile_size1/8, 8, tile_size2/4, 4] - input_cache->swizzle(SwizzleType::XOR, 3, 4); + input_cache->split(3, vectorize_factor); + // [BIDx, Unswitch, tile_size1, tile_size2/vectorize_factor, + // vectorize_factor] + input_cache->split(2, vectorize_factor); + // [BIDx, Unswitch, tile_size1/vectorize_factor, vectorize_factor, + // tile_size2/vectorize_factor, vectorize_factor] + input_cache->swizzle(SwizzleType::XOR, 2, 4); input_cache->merge(2); input_cache->merge(2); input_cache->split(2, threads_per_block); From 9707e932317b9c3ac1a9b225ea8980ec59bab15b Mon Sep 17 00:00:00 2001 From: Liqiang Lu Date: Mon, 2 Feb 2026 11:40:42 -0800 Subject: [PATCH 11/11] add doc --- doc/dev/transpose_access_map.md | 159 ++++++++++++++++++++++++++++++++ 1 file changed, 159 insertions(+) create mode 100644 doc/dev/transpose_access_map.md diff --git a/doc/dev/transpose_access_map.md b/doc/dev/transpose_access_map.md new file mode 100644 index 00000000000..66c61339e8d --- /dev/null +++ b/doc/dev/transpose_access_map.md @@ -0,0 +1,159 @@ +# Transpose 32x32 Access Maps (All Warps Combined) + +Cells show tids (3 digits); empty cells are `...`. + +**Summary** +- No bank conflicts: within each warp, the swizzle permutes lanes so each thread lands on a distinct bank for the shared accesses shown in `write to smem` and `regs (from smem read)`. +- Coalesced gmem: the input and output gmem plots show contiguous columns for each warp group, so global reads/writes are aligned and coalesced. + +## input gmem (read) + +``` + 000 001 002 003 004 005 006 007 008 009 010 011 012 013 014 015 016 017 018 019 020 021 022 023 024 025 026 027 028 029 030 031 +000: 000 ... ... ... 001 ... ... ... 002 ... ... ... 003 ... ... ... 004 ... ... ... 005 ... ... ... 006 ... ... ... 007 ... ... ... +001: 008 ... ... ... 009 ... ... ... 010 ... ... ... 011 ... ... ... 012 ... ... ... 013 ... ... ... 014 ... ... ... 015 ... ... ... +002: 016 ... ... ... 017 ... ... ... 018 ... ... ... 019 ... ... ... 020 ... ... ... 021 ... ... ... 022 ... ... ... 023 ... ... ... +003: 024 ... ... ... 025 ... ... ... 026 ... ... ... 027 ... ... ... 028 ... ... ... 029 ... ... ... 030 ... ... ... 031 ... ... ... +004: 033 ... ... ... 032 ... ... ... 035 ... ... ... 034 ... ... ... 037 ... ... ... 036 ... ... ... 039 ... ... ... 038 ... ... ... +005: 041 ... ... ... 040 ... ... ... 043 ... ... ... 042 ... ... ... 045 ... ... ... 044 ... ... ... 047 ... ... ... 046 ... ... ... +006: 049 ... ... ... 048 ... ... ... 051 ... ... ... 050 ... ... ... 053 ... ... ... 052 ... ... ... 055 ... ... ... 054 ... ... ... +007: 057 ... ... ... 056 ... ... ... 059 ... ... ... 058 ... ... ... 061 ... ... ... 060 ... ... ... 063 ... ... ... 062 ... ... ... +008: 066 ... ... ... 067 ... ... ... 064 ... ... ... 065 ... ... ... 070 ... ... ... 071 ... ... ... 068 ... ... ... 069 ... ... ... +009: 074 ... ... ... 075 ... ... ... 072 ... ... ... 073 ... ... ... 078 ... ... ... 079 ... ... ... 076 ... ... ... 077 ... ... ... +010: 082 ... ... ... 083 ... ... ... 080 ... ... ... 081 ... ... ... 086 ... ... ... 087 ... ... ... 084 ... ... ... 085 ... ... ... +011: 090 ... ... ... 091 ... ... ... 088 ... ... ... 089 ... ... ... 094 ... ... ... 095 ... ... ... 092 ... ... ... 093 ... ... ... +012: 099 ... ... ... 098 ... ... ... 097 ... ... ... 096 ... ... ... 103 ... ... ... 102 ... ... ... 101 ... ... ... 100 ... ... ... +013: 107 ... ... ... 106 ... ... ... 105 ... ... ... 104 ... ... ... 111 ... ... ... 110 ... ... ... 109 ... ... ... 108 ... ... ... +014: 115 ... ... ... 114 ... ... ... 113 ... ... ... 112 ... ... ... 119 ... ... ... 118 ... ... ... 117 ... ... ... 116 ... ... ... +015: 123 ... ... ... 122 ... ... ... 121 ... ... ... 120 ... ... ... 127 ... ... ... 126 ... ... ... 125 ... ... ... 124 ... ... ... +016: 004 ... ... ... 005 ... ... ... 006 ... ... ... 007 ... ... ... 000 ... ... ... 001 ... ... ... 002 ... ... ... 003 ... ... ... +017: 012 ... ... ... 013 ... ... ... 014 ... ... ... 015 ... ... ... 008 ... ... ... 009 ... ... ... 010 ... ... ... 011 ... ... ... +018: 020 ... ... ... 021 ... ... ... 022 ... ... ... 023 ... ... ... 016 ... ... ... 017 ... ... ... 018 ... ... ... 019 ... ... ... +019: 028 ... ... ... 029 ... ... ... 030 ... ... ... 031 ... ... ... 024 ... ... ... 025 ... ... ... 026 ... ... ... 027 ... ... ... +020: 037 ... ... ... 036 ... ... ... 039 ... ... ... 038 ... ... ... 033 ... ... ... 032 ... ... ... 035 ... ... ... 034 ... ... ... +021: 045 ... ... ... 044 ... ... ... 047 ... ... ... 046 ... ... ... 041 ... ... ... 040 ... ... ... 043 ... ... ... 042 ... ... ... +022: 053 ... ... ... 052 ... ... ... 055 ... ... ... 054 ... ... ... 049 ... ... ... 048 ... ... ... 051 ... ... ... 050 ... ... ... +023: 061 ... ... ... 060 ... ... ... 063 ... ... ... 062 ... ... ... 057 ... ... ... 056 ... ... ... 059 ... ... ... 058 ... ... ... +024: 070 ... ... ... 071 ... ... ... 068 ... ... ... 069 ... ... ... 066 ... ... ... 067 ... ... ... 064 ... ... ... 065 ... ... ... +025: 078 ... ... ... 079 ... ... ... 076 ... ... ... 077 ... ... ... 074 ... ... ... 075 ... ... ... 072 ... ... ... 073 ... ... ... +026: 086 ... ... ... 087 ... ... ... 084 ... ... ... 085 ... ... ... 082 ... ... ... 083 ... ... ... 080 ... ... ... 081 ... ... ... +027: 094 ... ... ... 095 ... ... ... 092 ... ... ... 093 ... ... ... 090 ... ... ... 091 ... ... ... 088 ... ... ... 089 ... ... ... +028: 103 ... ... ... 102 ... ... ... 101 ... ... ... 100 ... ... ... 099 ... ... ... 098 ... ... ... 097 ... ... ... 096 ... ... ... +029: 111 ... ... ... 110 ... ... ... 109 ... ... ... 108 ... ... ... 107 ... ... ... 106 ... ... ... 105 ... ... ... 104 ... ... ... +030: 119 ... ... ... 118 ... ... ... 117 ... ... ... 116 ... ... ... 115 ... ... ... 114 ... ... ... 113 ... ... ... 112 ... ... ... +031: 127 ... ... ... 126 ... ... ... 125 ... ... ... 124 ... ... ... 123 ... ... ... 122 ... ... ... 121 ... ... ... 120 ... ... ... +``` + +## write to smem + +``` + 000 001 002 003 004 005 006 007 008 009 010 011 012 013 014 015 016 017 018 019 020 021 022 023 024 025 026 027 028 029 030 031 +000: 000 ... ... ... 001 ... ... ... 002 ... ... ... 003 ... ... ... 004 ... ... ... 005 ... ... ... 006 ... ... ... 007 ... ... ... +001: 008 ... ... ... 009 ... ... ... 010 ... ... ... 011 ... ... ... 012 ... ... ... 013 ... ... ... 014 ... ... ... 015 ... ... ... +002: 016 ... ... ... 017 ... ... ... 018 ... ... ... 019 ... ... ... 020 ... ... ... 021 ... ... ... 022 ... ... ... 023 ... ... ... +003: 024 ... ... ... 025 ... ... ... 026 ... ... ... 027 ... ... ... 028 ... ... ... 029 ... ... ... 030 ... ... ... 031 ... ... ... +004: 032 ... ... ... 033 ... ... ... 034 ... ... ... 035 ... ... ... 036 ... ... ... 037 ... ... ... 038 ... ... ... 039 ... ... ... +005: 040 ... ... ... 041 ... ... ... 042 ... ... ... 043 ... ... ... 044 ... ... ... 045 ... ... ... 046 ... ... ... 047 ... ... ... +006: 048 ... ... ... 049 ... ... ... 050 ... ... ... 051 ... ... ... 052 ... ... ... 053 ... ... ... 054 ... ... ... 055 ... ... ... +007: 056 ... ... ... 057 ... ... ... 058 ... ... ... 059 ... ... ... 060 ... ... ... 061 ... ... ... 062 ... ... ... 063 ... ... ... +008: 064 ... ... ... 065 ... ... ... 066 ... ... ... 067 ... ... ... 068 ... ... ... 069 ... ... ... 070 ... ... ... 071 ... ... ... +009: 072 ... ... ... 073 ... ... ... 074 ... ... ... 075 ... ... ... 076 ... ... ... 077 ... ... ... 078 ... ... ... 079 ... ... ... +010: 080 ... ... ... 081 ... ... ... 082 ... ... ... 083 ... ... ... 084 ... ... ... 085 ... ... ... 086 ... ... ... 087 ... ... ... +011: 088 ... ... ... 089 ... ... ... 090 ... ... ... 091 ... ... ... 092 ... ... ... 093 ... ... ... 094 ... ... ... 095 ... ... ... +012: 096 ... ... ... 097 ... ... ... 098 ... ... ... 099 ... ... ... 100 ... ... ... 101 ... ... ... 102 ... ... ... 103 ... ... ... +013: 104 ... ... ... 105 ... ... ... 106 ... ... ... 107 ... ... ... 108 ... ... ... 109 ... ... ... 110 ... ... ... 111 ... ... ... +014: 112 ... ... ... 113 ... ... ... 114 ... ... ... 115 ... ... ... 116 ... ... ... 117 ... ... ... 118 ... ... ... 119 ... ... ... +015: 120 ... ... ... 121 ... ... ... 122 ... ... ... 123 ... ... ... 124 ... ... ... 125 ... ... ... 126 ... ... ... 127 ... ... ... +016: 000 ... ... ... 001 ... ... ... 002 ... ... ... 003 ... ... ... 004 ... ... ... 005 ... ... ... 006 ... ... ... 007 ... ... ... +017: 008 ... ... ... 009 ... ... ... 010 ... ... ... 011 ... ... ... 012 ... ... ... 013 ... ... ... 014 ... ... ... 015 ... ... ... +018: 016 ... ... ... 017 ... ... ... 018 ... ... ... 019 ... ... ... 020 ... ... ... 021 ... ... ... 022 ... ... ... 023 ... ... ... +019: 024 ... ... ... 025 ... ... ... 026 ... ... ... 027 ... ... ... 028 ... ... ... 029 ... ... ... 030 ... ... ... 031 ... ... ... +020: 032 ... ... ... 033 ... ... ... 034 ... ... ... 035 ... ... ... 036 ... ... ... 037 ... ... ... 038 ... ... ... 039 ... ... ... +021: 040 ... ... ... 041 ... ... ... 042 ... ... ... 043 ... ... ... 044 ... ... ... 045 ... ... ... 046 ... ... ... 047 ... ... ... +022: 048 ... ... ... 049 ... ... ... 050 ... ... ... 051 ... ... ... 052 ... ... ... 053 ... ... ... 054 ... ... ... 055 ... ... ... +023: 056 ... ... ... 057 ... ... ... 058 ... ... ... 059 ... ... ... 060 ... ... ... 061 ... ... ... 062 ... ... ... 063 ... ... ... +024: 064 ... ... ... 065 ... ... ... 066 ... ... ... 067 ... ... ... 068 ... ... ... 069 ... ... ... 070 ... ... ... 071 ... ... ... +025: 072 ... ... ... 073 ... ... ... 074 ... ... ... 075 ... ... ... 076 ... ... ... 077 ... ... ... 078 ... ... ... 079 ... ... ... +026: 080 ... ... ... 081 ... ... ... 082 ... ... ... 083 ... ... ... 084 ... ... ... 085 ... ... ... 086 ... ... ... 087 ... ... ... +027: 088 ... ... ... 089 ... ... ... 090 ... ... ... 091 ... ... ... 092 ... ... ... 093 ... ... ... 094 ... ... ... 095 ... ... ... +028: 096 ... ... ... 097 ... ... ... 098 ... ... ... 099 ... ... ... 100 ... ... ... 101 ... ... ... 102 ... ... ... 103 ... ... ... +029: 104 ... ... ... 105 ... ... ... 106 ... ... ... 107 ... ... ... 108 ... ... ... 109 ... ... ... 110 ... ... ... 111 ... ... ... +030: 112 ... ... ... 113 ... ... ... 114 ... ... ... 115 ... ... ... 116 ... ... ... 117 ... ... ... 118 ... ... ... 119 ... ... ... +031: 120 ... ... ... 121 ... ... ... 122 ... ... ... 123 ... ... ... 124 ... ... ... 125 ... ... ... 126 ... ... ... 127 ... ... ... +``` + +## regs (from smem read) + +``` + 000 001 002 003 004 005 006 007 008 009 010 011 012 013 014 015 016 017 018 019 020 021 022 023 024 025 026 027 028 029 030 031 +000: 000 008 016 024 032 040 048 056 064 072 080 088 096 104 112 120 000 008 016 024 032 040 048 056 064 072 080 088 096 104 112 120 +001: 000 008 016 024 032 040 048 056 064 072 080 088 096 104 112 120 000 008 016 024 032 040 048 056 064 072 080 088 096 104 112 120 +002: 000 008 016 024 032 040 048 056 064 072 080 088 096 104 112 120 000 008 016 024 032 040 048 056 064 072 080 088 096 104 112 120 +003: 000 008 016 024 032 040 048 056 064 072 080 088 096 104 112 120 000 008 016 024 032 040 048 056 064 072 080 088 096 104 112 120 +004: 033 041 049 057 001 009 017 025 097 105 113 121 065 073 081 089 033 041 049 057 001 009 017 025 097 105 113 121 065 073 081 089 +005: 033 041 049 057 001 009 017 025 097 105 113 121 065 073 081 089 033 041 049 057 001 009 017 025 097 105 113 121 065 073 081 089 +006: 033 041 049 057 001 009 017 025 097 105 113 121 065 073 081 089 033 041 049 057 001 009 017 025 097 105 113 121 065 073 081 089 +007: 033 041 049 057 001 009 017 025 097 105 113 121 065 073 081 089 033 041 049 057 001 009 017 025 097 105 113 121 065 073 081 089 +008: 066 074 082 090 098 106 114 122 002 010 018 026 034 042 050 058 066 074 082 090 098 106 114 122 002 010 018 026 034 042 050 058 +009: 066 074 082 090 098 106 114 122 002 010 018 026 034 042 050 058 066 074 082 090 098 106 114 122 002 010 018 026 034 042 050 058 +010: 066 074 082 090 098 106 114 122 002 010 018 026 034 042 050 058 066 074 082 090 098 106 114 122 002 010 018 026 034 042 050 058 +011: 066 074 082 090 098 106 114 122 002 010 018 026 034 042 050 058 066 074 082 090 098 106 114 122 002 010 018 026 034 042 050 058 +012: 099 107 115 123 067 075 083 091 035 043 051 059 003 011 019 027 099 107 115 123 067 075 083 091 035 043 051 059 003 011 019 027 +013: 099 107 115 123 067 075 083 091 035 043 051 059 003 011 019 027 099 107 115 123 067 075 083 091 035 043 051 059 003 011 019 027 +014: 099 107 115 123 067 075 083 091 035 043 051 059 003 011 019 027 099 107 115 123 067 075 083 091 035 043 051 059 003 011 019 027 +015: 099 107 115 123 067 075 083 091 035 043 051 059 003 011 019 027 099 107 115 123 067 075 083 091 035 043 051 059 003 011 019 027 +016: 004 012 020 028 036 044 052 060 068 076 084 092 100 108 116 124 004 012 020 028 036 044 052 060 068 076 084 092 100 108 116 124 +017: 004 012 020 028 036 044 052 060 068 076 084 092 100 108 116 124 004 012 020 028 036 044 052 060 068 076 084 092 100 108 116 124 +018: 004 012 020 028 036 044 052 060 068 076 084 092 100 108 116 124 004 012 020 028 036 044 052 060 068 076 084 092 100 108 116 124 +019: 004 012 020 028 036 044 052 060 068 076 084 092 100 108 116 124 004 012 020 028 036 044 052 060 068 076 084 092 100 108 116 124 +020: 037 045 053 061 005 013 021 029 101 109 117 125 069 077 085 093 037 045 053 061 005 013 021 029 101 109 117 125 069 077 085 093 +021: 037 045 053 061 005 013 021 029 101 109 117 125 069 077 085 093 037 045 053 061 005 013 021 029 101 109 117 125 069 077 085 093 +022: 037 045 053 061 005 013 021 029 101 109 117 125 069 077 085 093 037 045 053 061 005 013 021 029 101 109 117 125 069 077 085 093 +023: 037 045 053 061 005 013 021 029 101 109 117 125 069 077 085 093 037 045 053 061 005 013 021 029 101 109 117 125 069 077 085 093 +024: 070 078 086 094 102 110 118 126 006 014 022 030 038 046 054 062 070 078 086 094 102 110 118 126 006 014 022 030 038 046 054 062 +025: 070 078 086 094 102 110 118 126 006 014 022 030 038 046 054 062 070 078 086 094 102 110 118 126 006 014 022 030 038 046 054 062 +026: 070 078 086 094 102 110 118 126 006 014 022 030 038 046 054 062 070 078 086 094 102 110 118 126 006 014 022 030 038 046 054 062 +027: 070 078 086 094 102 110 118 126 006 014 022 030 038 046 054 062 070 078 086 094 102 110 118 126 006 014 022 030 038 046 054 062 +028: 103 111 119 127 071 079 087 095 039 047 055 063 007 015 023 031 103 111 119 127 071 079 087 095 039 047 055 063 007 015 023 031 +029: 103 111 119 127 071 079 087 095 039 047 055 063 007 015 023 031 103 111 119 127 071 079 087 095 039 047 055 063 007 015 023 031 +030: 103 111 119 127 071 079 087 095 039 047 055 063 007 015 023 031 103 111 119 127 071 079 087 095 039 047 055 063 007 015 023 031 +031: 103 111 119 127 071 079 087 095 039 047 055 063 007 015 023 031 103 111 119 127 071 079 087 095 039 047 055 063 007 015 023 031 +``` + +## output gmem (write) + +``` + 000 001 002 003 004 005 006 007 008 009 010 011 012 013 014 015 016 017 018 019 020 021 022 023 024 025 026 027 028 029 030 031 +000: 000 ... ... ... 001 ... ... ... 002 ... ... ... 003 ... ... ... 004 ... ... ... 005 ... ... ... 006 ... ... ... 007 ... ... ... +001: 008 ... ... ... 009 ... ... ... 010 ... ... ... 011 ... ... ... 012 ... ... ... 013 ... ... ... 014 ... ... ... 015 ... ... ... +002: 016 ... ... ... 017 ... ... ... 018 ... ... ... 019 ... ... ... 020 ... ... ... 021 ... ... ... 022 ... ... ... 023 ... ... ... +003: 024 ... ... ... 025 ... ... ... 026 ... ... ... 027 ... ... ... 028 ... ... ... 029 ... ... ... 030 ... ... ... 031 ... ... ... +004: 032 ... ... ... 033 ... ... ... 034 ... ... ... 035 ... ... ... 036 ... ... ... 037 ... ... ... 038 ... ... ... 039 ... ... ... +005: 040 ... ... ... 041 ... ... ... 042 ... ... ... 043 ... ... ... 044 ... ... ... 045 ... ... ... 046 ... ... ... 047 ... ... ... +006: 048 ... ... ... 049 ... ... ... 050 ... ... ... 051 ... ... ... 052 ... ... ... 053 ... ... ... 054 ... ... ... 055 ... ... ... +007: 056 ... ... ... 057 ... ... ... 058 ... ... ... 059 ... ... ... 060 ... ... ... 061 ... ... ... 062 ... ... ... 063 ... ... ... +008: 064 ... ... ... 065 ... ... ... 066 ... ... ... 067 ... ... ... 068 ... ... ... 069 ... ... ... 070 ... ... ... 071 ... ... ... +009: 072 ... ... ... 073 ... ... ... 074 ... ... ... 075 ... ... ... 076 ... ... ... 077 ... ... ... 078 ... ... ... 079 ... ... ... +010: 080 ... ... ... 081 ... ... ... 082 ... ... ... 083 ... ... ... 084 ... ... ... 085 ... ... ... 086 ... ... ... 087 ... ... ... +011: 088 ... ... ... 089 ... ... ... 090 ... ... ... 091 ... ... ... 092 ... ... ... 093 ... ... ... 094 ... ... ... 095 ... ... ... +012: 096 ... ... ... 097 ... ... ... 098 ... ... ... 099 ... ... ... 100 ... ... ... 101 ... ... ... 102 ... ... ... 103 ... ... ... +013: 104 ... ... ... 105 ... ... ... 106 ... ... ... 107 ... ... ... 108 ... ... ... 109 ... ... ... 110 ... ... ... 111 ... ... ... +014: 112 ... ... ... 113 ... ... ... 114 ... ... ... 115 ... ... ... 116 ... ... ... 117 ... ... ... 118 ... ... ... 119 ... ... ... +015: 120 ... ... ... 121 ... ... ... 122 ... ... ... 123 ... ... ... 124 ... ... ... 125 ... ... ... 126 ... ... ... 127 ... ... ... +016: 000 ... ... ... 001 ... ... ... 002 ... ... ... 003 ... ... ... 004 ... ... ... 005 ... ... ... 006 ... ... ... 007 ... ... ... +017: 008 ... ... ... 009 ... ... ... 010 ... ... ... 011 ... ... ... 012 ... ... ... 013 ... ... ... 014 ... ... ... 015 ... ... ... +018: 016 ... ... ... 017 ... ... ... 018 ... ... ... 019 ... ... ... 020 ... ... ... 021 ... ... ... 022 ... ... ... 023 ... ... ... +019: 024 ... ... ... 025 ... ... ... 026 ... ... ... 027 ... ... ... 028 ... ... ... 029 ... ... ... 030 ... ... ... 031 ... ... ... +020: 032 ... ... ... 033 ... ... ... 034 ... ... ... 035 ... ... ... 036 ... ... ... 037 ... ... ... 038 ... ... ... 039 ... ... ... +021: 040 ... ... ... 041 ... ... ... 042 ... ... ... 043 ... ... ... 044 ... ... ... 045 ... ... ... 046 ... ... ... 047 ... ... ... +022: 048 ... ... ... 049 ... ... ... 050 ... ... ... 051 ... ... ... 052 ... ... ... 053 ... ... ... 054 ... ... ... 055 ... ... ... +023: 056 ... ... ... 057 ... ... ... 058 ... ... ... 059 ... ... ... 060 ... ... ... 061 ... ... ... 062 ... ... ... 063 ... ... ... +024: 064 ... ... ... 065 ... ... ... 066 ... ... ... 067 ... ... ... 068 ... ... ... 069 ... ... ... 070 ... ... ... 071 ... ... ... +025: 072 ... ... ... 073 ... ... ... 074 ... ... ... 075 ... ... ... 076 ... ... ... 077 ... ... ... 078 ... ... ... 079 ... ... ... +026: 080 ... ... ... 081 ... ... ... 082 ... ... ... 083 ... ... ... 084 ... ... ... 085 ... ... ... 086 ... ... ... 087 ... ... ... +027: 088 ... ... ... 089 ... ... ... 090 ... ... ... 091 ... ... ... 092 ... ... ... 093 ... ... ... 094 ... ... ... 095 ... ... ... +028: 096 ... ... ... 097 ... ... ... 098 ... ... ... 099 ... ... ... 100 ... ... ... 101 ... ... ... 102 ... ... ... 103 ... ... ... +029: 104 ... ... ... 105 ... ... ... 106 ... ... ... 107 ... ... ... 108 ... ... ... 109 ... ... ... 110 ... ... ... 111 ... ... ... +030: 112 ... ... ... 113 ... ... ... 114 ... ... ... 115 ... ... ... 116 ... ... ... 117 ... ... ... 118 ... ... ... 119 ... ... ... +031: 120 ... ... ... 121 ... ... ... 122 ... ... ... 123 ... ... ... 124 ... ... ... 125 ... ... ... 126 ... ... ... 127 ... ... ... +```