diff --git a/csrc/scheduler/transpose.cpp b/csrc/scheduler/transpose.cpp index 766213ca484..e4557c3fdfe 100644 --- a/csrc/scheduler/transpose.cpp +++ b/csrc/scheduler/transpose.cpp @@ -811,8 +811,13 @@ std::unique_ptr getTransposeHeuristics( grouped_inputs_outputs[1], max_unroll_factor); } - - tparams->lparams.bind(tparams->getThreadsPerBlock(), ParallelType::TIDx); + if (std::getenv("USE_TMA") != nullptr) { + tparams->use_tma_load = true; + tparams->use_tma_store = true; + } + if (!tparams->use_tma_load) { + tparams->lparams.bind(tparams->getThreadsPerBlock(), ParallelType::TIDx); + } if (isDebugDumpEnabled(DebugDumpOption::SchedulerDebug)) { debug() << "\n===== Transpose Stats ========\n" @@ -843,7 +848,221 @@ std::unique_ptr getTransposeHeuristics( return tparams; } +void scheduleTransposeTMA(Fusion* fusion, const TransposeParams* tparams) { + FusionGuard fg(fusion); + + // Make sure we don't have global memory set on intermediate tensors from + // fusion segmentation + scheduler_utils::clearMemorySpace(fusion); + + // maybe has_reduction for scheduling should be done on a per output tensor + // basis. + NVF_ERROR( + !ir_utils::hasAnyReductionOps(fusion), + "This scheduler only handles pointwise ops."); + + // Cache inputs + auto cached_inputs = scheduler_utils::cacheInputs(fusion, true); + + // Cache and fork outputs + auto cached_outputs = scheduler_utils::cacheAndForkOutputs(fusion, true); + + scheduler_utils::prepareForMemoryTypePromotion(fusion); + + std::vector input_tvs; + { + auto filtered_tvs = ir_utils::filterByType(fusion->inputs()); + // Remove hanging tensor views + for (auto tv : filtered_tvs) { + if (tv->uses().empty()) { + continue; + } + input_tvs.push_back(tv); + } + } + auto output_tvs = ir_utils::filterByType(fusion->outputs()); + + int64_t max_dims = 0; + for (auto inp : input_tvs) { + max_dims = std::max(scheduler_utils::nLogicalDims(inp), max_dims); + } + + for (auto out : output_tvs) { + max_dims = std::max(scheduler_utils::nLogicalDims(out), max_dims); + } + + // If everything is zero dim tensors, just return. + if (max_dims == 0) { + return; + } + + scheduler_tools::TransposeDomainMap domain_map(fusion); + auto grouped_inputs_outputs = domain_map.groupInputsOutputsByInnerDim(); + NVF_ERROR(grouped_inputs_outputs.size() >= 2); + + /* + * We need something similar to `cacheFork` for input tensors in group 2. We + * need this because we will want to propagate to the entire DAG except group + * 2 and its cached inputs, so we need to make sure the DAG is still connected + * if we remove group and its cached inputs. For example + * t0 + * | + * cache + * / \ + * t1 t2 + * if groups = {{t1, t2}, {t0}}, then removing {t0, cache} from the DAG will + * make it disconnected. + */ + std::vector input_smem_tvs; + std::vector output_smem_tvs; + TensorView* input_reference = nullptr; + std::unordered_set group2_and_cached_inputs( + grouped_inputs_outputs[1].begin(), grouped_inputs_outputs[1].end()); + for (auto tv : grouped_inputs_outputs[1]) { + if (tv->isFusionInput()) { + input_reference = tv; + auto existing_cache = ir_utils::consumerTvsOf(tv)[0]; + if (ir_utils::consumerTvsOf(existing_cache).size() > 1) { + auto new_cache = tv->cacheAfter(); + new_cache->setMemoryType(MemoryType::Shared); + input_smem_tvs.push_back(new_cache); + std::cout << "input_smem_tvs: " << new_cache->toString() << std::endl; + group2_and_cached_inputs.emplace(new_cache); + } else { + existing_cache->definition()->as()->setOpType( + LoadStoreOpType::CpAsyncBulkTensorTile); + existing_cache->setMemoryType(MemoryType::Shared); + input_smem_tvs.push_back(existing_cache); + std::cout << "input_smem_tvs: " << existing_cache->toString() + << std::endl; + group2_and_cached_inputs.emplace(existing_cache); + } + } + } + // set cached outputs of group 2 to shared memory + TensorView* output_reference = nullptr; + TensorView* output_reg_cache = nullptr; + for (const auto& [cached_output, output_idx] : cached_outputs) { + auto output = fusion->outputs()[output_idx]->as(); + output_reference = output; + if (true || group2_and_cached_inputs.count(output) > 0) { + output->definition()->as()->setOpType( + LoadStoreOpType::CpAsyncBulkTensorTile); + cached_output->setMemoryType(MemoryType::Shared); + output_smem_tvs.push_back(cached_output); + std::cout << "output_smem_tvs: " << cached_output->toString() + << std::endl; + output_reg_cache = cached_output->cacheBefore(); + } + } + + TensorView* reference1 = + domain_map.findReferenceFor(grouped_inputs_outputs[0]); + TensorView* reference2 = + domain_map.findReferenceFor(grouped_inputs_outputs[1]); + + std::cout << "reference1: " << reference1->toString() << std::endl; + std::cout << "reference2: " << reference2->toString() << std::endl; + + // output: [I0, I1] -> [I0/tile_0 * I1/tile_1, tile_0, tile_1] + // smem -> register is vectorized + // register -> smem is unrolled + int64_t tile_0 = 32, tile_1 = 32, n_warps = 4, unroll_vect = 4; + int64_t tiles_per_block = 4; + // int64_t n_serial_loop = tile_0 / unroll_vect / n_warps; + output_reference->split(1, tile_1); + output_reference->split(0, tile_0); + output_reference->reorder({{-2, 0}}); + output_reference->merge(0); + // [I0/tile_0 * I1/tile_1/tiles_per_block, tiles_per_block, tile_0, tile_1] + output_reference->split(0, tiles_per_block); + int64_t pos_bdimx = 0, tile0_pos = 2; + output_reference->axis(pos_bdimx)->parallelize(ParallelType::BIDx); + // output_reference->axis(pos_bdimx + 1)->parallelize(ParallelType::Unroll); + using Options = + scheduler_utils::BoundedDirectionalTransformPropagator::Options; + scheduler_utils::BoundedDirectionalTransformPropagator::backward( + output_reference, + -1, + {input_reference}, + Options{}.propagateParallelType()); + // For fusion output, we just use TMA to store the entire tile back to global + // memory. There is no need to further schedule the output tensor. + output_reference->axis(tile0_pos)->parallelize(ParallelType::Bulk); + output_reference->axis(tile0_pos + 1)->parallelize(ParallelType::Bulk); + + fusion->printMath(); + + for (auto output_smem_cache : output_smem_tvs) { + // [BIDx, 32', 32] + output_smem_cache->setAllocationDomain( + output_smem_cache->getLoopDomain(), true); + // [BDIx, tile_0, tile_1] -> [BDIx, tile_0/unroll_vect, unroll_vect, tile_1] + output_smem_cache->split(tile0_pos, unroll_vect); + // [BDIx, tile_0/unroll_vect/n_warps, n_warps, unroll_vect, tile_1] + output_smem_cache->split(tile0_pos, n_warps); + + scheduler_utils::BoundedDirectionalTransformPropagator::backward( + output_smem_cache, -1, {input_reference}); + output_smem_cache->axis(tile0_pos + 1)->parallelize(ParallelType::TIDy); + output_smem_cache->axis(tile0_pos + 3)->parallelize(ParallelType::TIDx); + scheduler_utils::BoundedDirectionalTransformPropagator::backward( + output_smem_cache, + -1, + {input_smem_tvs.at(0)}, + Options{}.propagateParallelType()); + output_smem_cache->axis(tile0_pos + 2)->parallelize(ParallelType::Unroll); + output_reg_cache->axis(tile0_pos + 2)->parallelize(ParallelType::Vectorize); + output_reg_cache->setAllocationDomain( + output_reg_cache->getLoopDomain(), true); + } + for (auto input_smem_cache : input_smem_tvs) { + // Schedule the memory format for 128 byte swizzle + // After backward propagation and reorder: + // [BIDx, tile_1, tile_0/unroll_vect/n_warps, n_warps, unroll_vect] + // = [BIDx, 32, 2, 4, 4] + int64_t tile1_pos = 2; + input_smem_cache->reorder({{-1, 2}}); + + std::cout << "input_smem_cache after reorder: " + << input_smem_cache->toString() << std::endl; + + // Split tile_1 (32) by 8 to create first dimension of size 8 + // [BIDx, 32, 2, 4, 4] -> [BIDx, 4, 8, 2, 4, 4] + input_smem_cache->split(tile1_pos, 8); + + // Merge the 2×4 to create second dimension of size 8 + // [BIDx, tilesPerCTA, 4, 8, 2, 4, 4] -> [BIDx, tilesPerCTA, 4, 8, 8, 4] + input_smem_cache->merge(tile1_pos + 2, tile1_pos + 3); + + std::cout << "input_smem_cache before swizzle: " + << input_smem_cache->toString() << std::endl; + + // Swizzle the two 8's at positions 2 and 3 + // [BIDx, tilesPerCTA, 4, 8, 8, 4] with XOR swizzle on dimensions 2 and 3 + input_smem_cache->swizzle(SwizzleType::XOR, tile1_pos + 1, tile1_pos + 2); + + // Set allocation domain to match the swizzled layout + input_smem_cache->setAllocationDomain( + input_smem_cache->getLoopDomain(), true); + + // Parallelize all dimensions as Bulk for TMA + input_smem_cache->axis(tile1_pos)->parallelize(ParallelType::Bulk); + input_smem_cache->axis(tile1_pos + 1)->parallelize(ParallelType::Bulk); + input_smem_cache->axis(tile1_pos + 2)->parallelize(ParallelType::Bulk); + input_smem_cache->axis(tile1_pos + 3)->parallelize(ParallelType::Bulk); + // [BIDx, Bulk, Bulk, Bulk, Bulk] + } + + inlineMost(); + + fusion->print(); +} + void scheduleTranspose(Fusion* fusion, const TransposeParams* tparams) { + if (tparams->use_tma_store || tparams->use_tma_load) { + return scheduleTransposeTMA(fusion, tparams); + } FusionGuard fg(fusion); // Make sure we don't have global memory set on intermediate tensors from diff --git a/csrc/scheduler/transpose_heuristic.h b/csrc/scheduler/transpose_heuristic.h index f4de63bdd60..29fc49d62cf 100644 --- a/csrc/scheduler/transpose_heuristic.h +++ b/csrc/scheduler/transpose_heuristic.h @@ -51,6 +51,16 @@ class TransposeParams : public HeuristicParams { // Tile size for the inner most dim of tensors in the second group int64_t tile_size2 = getDefaultTileSize(); + // Use Hopper+ TMA (cp.async.bulk.tensor) for group2 cached-input loads + // (global -> shared). This is intentionally conservative and only affects + // the shared-memory staging tensors used for the thread-binding swap. + bool use_tma_load = false; + + // Use Hopper+ TMA (cp.async.bulk.tensor) for group1 cached-output stores + // (shared -> global). When enabled with use_tma_load, implements the full + // bank-conflict-free transpose pattern with 128-byte swizzle. + bool use_tma_store = false; + using HeuristicParams::HeuristicParams; // Warning: Does not check launch parameters! @@ -65,7 +75,9 @@ class TransposeParams : public HeuristicParams { other->dims_merged_with_2 == dims_merged_with_2 && other->vectorize_factor1 == vectorize_factor1 && other->vectorize_factor2 == vectorize_factor2 && - other->tile_size1 == tile_size1 && other->tile_size2 == tile_size2; + other->tile_size1 == tile_size1 && other->tile_size2 == tile_size2 && + other->use_tma_load == use_tma_load && + other->use_tma_store == use_tma_store; return attr_equal; } @@ -76,6 +88,8 @@ class TransposeParams : public HeuristicParams { << " BlckX: " << lparams.bdimx() << "\n"; ss << " input tile size: " << tile_size1 << "\n"; ss << " output tile size: " << tile_size2 << "\n"; + ss << " use_tma_load: " << use_tma_load << "\n"; + ss << " use_tma_store: " << use_tma_store << "\n"; int64_t elements_per_tile = tile_size1 * tile_size2; ss << " elements per tile: " << elements_per_tile << "\n"; int64_t elements_per_thread = elements_per_tile / lparams.bdimx(); @@ -146,7 +160,9 @@ class TransposeParams : public HeuristicParams { vectorize_factor1, vectorize_factor2, tile_size1, - tile_size2); + tile_size2, + use_tma_load, + use_tma_store); } std::unique_ptr clone() const override { diff --git a/doc/dev/transpose_access_map.md b/doc/dev/transpose_access_map.md new file mode 100644 index 00000000000..66c61339e8d --- /dev/null +++ b/doc/dev/transpose_access_map.md @@ -0,0 +1,159 @@ +# Transpose 32x32 Access Maps (All Warps Combined) + +Cells show tids (3 digits); empty cells are `...`. + +**Summary** +- No bank conflicts: within each warp, the swizzle permutes lanes so each thread lands on a distinct bank for the shared accesses shown in `write to smem` and `regs (from smem read)`. +- Coalesced gmem: the input and output gmem plots show contiguous columns for each warp group, so global reads/writes are aligned and coalesced. + +## input gmem (read) + +``` + 000 001 002 003 004 005 006 007 008 009 010 011 012 013 014 015 016 017 018 019 020 021 022 023 024 025 026 027 028 029 030 031 +000: 000 ... ... ... 001 ... ... ... 002 ... ... ... 003 ... ... ... 004 ... ... ... 005 ... ... ... 006 ... ... ... 007 ... ... ... +001: 008 ... ... ... 009 ... ... ... 010 ... ... ... 011 ... ... ... 012 ... ... ... 013 ... ... ... 014 ... ... ... 015 ... ... ... +002: 016 ... ... ... 017 ... ... ... 018 ... ... ... 019 ... ... ... 020 ... ... ... 021 ... ... ... 022 ... ... ... 023 ... ... ... +003: 024 ... ... ... 025 ... ... ... 026 ... ... ... 027 ... ... ... 028 ... ... ... 029 ... ... ... 030 ... ... ... 031 ... ... ... +004: 033 ... ... ... 032 ... ... ... 035 ... ... ... 034 ... ... ... 037 ... ... ... 036 ... ... ... 039 ... ... ... 038 ... ... ... +005: 041 ... ... ... 040 ... ... ... 043 ... ... ... 042 ... ... ... 045 ... ... ... 044 ... ... ... 047 ... ... ... 046 ... ... ... +006: 049 ... ... ... 048 ... ... ... 051 ... ... ... 050 ... ... ... 053 ... ... ... 052 ... ... ... 055 ... ... ... 054 ... ... ... +007: 057 ... ... ... 056 ... ... ... 059 ... ... ... 058 ... ... ... 061 ... ... ... 060 ... ... ... 063 ... ... ... 062 ... ... ... +008: 066 ... ... ... 067 ... ... ... 064 ... ... ... 065 ... ... ... 070 ... ... ... 071 ... ... ... 068 ... ... ... 069 ... ... ... +009: 074 ... ... ... 075 ... ... ... 072 ... ... ... 073 ... ... ... 078 ... ... ... 079 ... ... ... 076 ... ... ... 077 ... ... ... +010: 082 ... ... ... 083 ... ... ... 080 ... ... ... 081 ... ... ... 086 ... ... ... 087 ... ... ... 084 ... ... ... 085 ... ... ... +011: 090 ... ... ... 091 ... ... ... 088 ... ... ... 089 ... ... ... 094 ... ... ... 095 ... ... ... 092 ... ... ... 093 ... ... ... +012: 099 ... ... ... 098 ... ... ... 097 ... ... ... 096 ... ... ... 103 ... ... ... 102 ... ... ... 101 ... ... ... 100 ... ... ... +013: 107 ... ... ... 106 ... ... ... 105 ... ... ... 104 ... ... ... 111 ... ... ... 110 ... ... ... 109 ... ... ... 108 ... ... ... +014: 115 ... ... ... 114 ... ... ... 113 ... ... ... 112 ... ... ... 119 ... ... ... 118 ... ... ... 117 ... ... ... 116 ... ... ... +015: 123 ... ... ... 122 ... ... ... 121 ... ... ... 120 ... ... ... 127 ... ... ... 126 ... ... ... 125 ... ... ... 124 ... ... ... +016: 004 ... ... ... 005 ... ... ... 006 ... ... ... 007 ... ... ... 000 ... ... ... 001 ... ... ... 002 ... ... ... 003 ... ... ... +017: 012 ... ... ... 013 ... ... ... 014 ... ... ... 015 ... ... ... 008 ... ... ... 009 ... ... ... 010 ... ... ... 011 ... ... ... +018: 020 ... ... ... 021 ... ... ... 022 ... ... ... 023 ... ... ... 016 ... ... ... 017 ... ... ... 018 ... ... ... 019 ... ... ... +019: 028 ... ... ... 029 ... ... ... 030 ... ... ... 031 ... ... ... 024 ... ... ... 025 ... ... ... 026 ... ... ... 027 ... ... ... +020: 037 ... ... ... 036 ... ... ... 039 ... ... ... 038 ... ... ... 033 ... ... ... 032 ... ... ... 035 ... ... ... 034 ... ... ... +021: 045 ... ... ... 044 ... ... ... 047 ... ... ... 046 ... ... ... 041 ... ... ... 040 ... ... ... 043 ... ... ... 042 ... ... ... +022: 053 ... ... ... 052 ... ... ... 055 ... ... ... 054 ... ... ... 049 ... ... ... 048 ... ... ... 051 ... ... ... 050 ... ... ... +023: 061 ... ... ... 060 ... ... ... 063 ... ... ... 062 ... ... ... 057 ... ... ... 056 ... ... ... 059 ... ... ... 058 ... ... ... +024: 070 ... ... ... 071 ... ... ... 068 ... ... ... 069 ... ... ... 066 ... ... ... 067 ... ... ... 064 ... ... ... 065 ... ... ... +025: 078 ... ... ... 079 ... ... ... 076 ... ... ... 077 ... ... ... 074 ... ... ... 075 ... ... ... 072 ... ... ... 073 ... ... ... +026: 086 ... ... ... 087 ... ... ... 084 ... ... ... 085 ... ... ... 082 ... ... ... 083 ... ... ... 080 ... ... ... 081 ... ... ... +027: 094 ... ... ... 095 ... ... ... 092 ... ... ... 093 ... ... ... 090 ... ... ... 091 ... ... ... 088 ... ... ... 089 ... ... ... +028: 103 ... ... ... 102 ... ... ... 101 ... ... ... 100 ... ... ... 099 ... ... ... 098 ... ... ... 097 ... ... ... 096 ... ... ... +029: 111 ... ... ... 110 ... ... ... 109 ... ... ... 108 ... ... ... 107 ... ... ... 106 ... ... ... 105 ... ... ... 104 ... ... ... +030: 119 ... ... ... 118 ... ... ... 117 ... ... ... 116 ... ... ... 115 ... ... ... 114 ... ... ... 113 ... ... ... 112 ... ... ... +031: 127 ... ... ... 126 ... ... ... 125 ... ... ... 124 ... ... ... 123 ... ... ... 122 ... ... ... 121 ... ... ... 120 ... ... ... +``` + +## write to smem + +``` + 000 001 002 003 004 005 006 007 008 009 010 011 012 013 014 015 016 017 018 019 020 021 022 023 024 025 026 027 028 029 030 031 +000: 000 ... ... ... 001 ... ... ... 002 ... ... ... 003 ... ... ... 004 ... ... ... 005 ... ... ... 006 ... ... ... 007 ... ... ... +001: 008 ... ... ... 009 ... ... ... 010 ... ... ... 011 ... ... ... 012 ... ... ... 013 ... ... ... 014 ... ... ... 015 ... ... ... +002: 016 ... ... ... 017 ... ... ... 018 ... ... ... 019 ... ... ... 020 ... ... ... 021 ... ... ... 022 ... ... ... 023 ... ... ... +003: 024 ... ... ... 025 ... ... ... 026 ... ... ... 027 ... ... ... 028 ... ... ... 029 ... ... ... 030 ... ... ... 031 ... ... ... +004: 032 ... ... ... 033 ... ... ... 034 ... ... ... 035 ... ... ... 036 ... ... ... 037 ... ... ... 038 ... ... ... 039 ... ... ... +005: 040 ... ... ... 041 ... ... ... 042 ... ... ... 043 ... ... ... 044 ... ... ... 045 ... ... ... 046 ... ... ... 047 ... ... ... +006: 048 ... ... ... 049 ... ... ... 050 ... ... ... 051 ... ... ... 052 ... ... ... 053 ... ... ... 054 ... ... ... 055 ... ... ... +007: 056 ... ... ... 057 ... ... ... 058 ... ... ... 059 ... ... ... 060 ... ... ... 061 ... ... ... 062 ... ... ... 063 ... ... ... +008: 064 ... ... ... 065 ... ... ... 066 ... ... ... 067 ... ... ... 068 ... ... ... 069 ... ... ... 070 ... ... ... 071 ... ... ... +009: 072 ... ... ... 073 ... ... ... 074 ... ... ... 075 ... ... ... 076 ... ... ... 077 ... ... ... 078 ... ... ... 079 ... ... ... +010: 080 ... ... ... 081 ... ... ... 082 ... ... ... 083 ... ... ... 084 ... ... ... 085 ... ... ... 086 ... ... ... 087 ... ... ... +011: 088 ... ... ... 089 ... ... ... 090 ... ... ... 091 ... ... ... 092 ... ... ... 093 ... ... ... 094 ... ... ... 095 ... ... ... +012: 096 ... ... ... 097 ... ... ... 098 ... ... ... 099 ... ... ... 100 ... ... ... 101 ... ... ... 102 ... ... ... 103 ... ... ... +013: 104 ... ... ... 105 ... ... ... 106 ... ... ... 107 ... ... ... 108 ... ... ... 109 ... ... ... 110 ... ... ... 111 ... ... ... +014: 112 ... ... ... 113 ... ... ... 114 ... ... ... 115 ... ... ... 116 ... ... ... 117 ... ... ... 118 ... ... ... 119 ... ... ... +015: 120 ... ... ... 121 ... ... ... 122 ... ... ... 123 ... ... ... 124 ... ... ... 125 ... ... ... 126 ... ... ... 127 ... ... ... +016: 000 ... ... ... 001 ... ... ... 002 ... ... ... 003 ... ... ... 004 ... ... ... 005 ... ... ... 006 ... ... ... 007 ... ... ... +017: 008 ... ... ... 009 ... ... ... 010 ... ... ... 011 ... ... ... 012 ... ... ... 013 ... ... ... 014 ... ... ... 015 ... ... ... +018: 016 ... ... ... 017 ... ... ... 018 ... ... ... 019 ... ... ... 020 ... ... ... 021 ... ... ... 022 ... ... ... 023 ... ... ... +019: 024 ... ... ... 025 ... ... ... 026 ... ... ... 027 ... ... ... 028 ... ... ... 029 ... ... ... 030 ... ... ... 031 ... ... ... +020: 032 ... ... ... 033 ... ... ... 034 ... ... ... 035 ... ... ... 036 ... ... ... 037 ... ... ... 038 ... ... ... 039 ... ... ... +021: 040 ... ... ... 041 ... ... ... 042 ... ... ... 043 ... ... ... 044 ... ... ... 045 ... ... ... 046 ... ... ... 047 ... ... ... +022: 048 ... ... ... 049 ... ... ... 050 ... ... ... 051 ... ... ... 052 ... ... ... 053 ... ... ... 054 ... ... ... 055 ... ... ... +023: 056 ... ... ... 057 ... ... ... 058 ... ... ... 059 ... ... ... 060 ... ... ... 061 ... ... ... 062 ... ... ... 063 ... ... ... +024: 064 ... ... ... 065 ... ... ... 066 ... ... ... 067 ... ... ... 068 ... ... ... 069 ... ... ... 070 ... ... ... 071 ... ... ... +025: 072 ... ... ... 073 ... ... ... 074 ... ... ... 075 ... ... ... 076 ... ... ... 077 ... ... ... 078 ... ... ... 079 ... ... ... +026: 080 ... ... ... 081 ... ... ... 082 ... ... ... 083 ... ... ... 084 ... ... ... 085 ... ... ... 086 ... ... ... 087 ... ... ... +027: 088 ... ... ... 089 ... ... ... 090 ... ... ... 091 ... ... ... 092 ... ... ... 093 ... ... ... 094 ... ... ... 095 ... ... ... +028: 096 ... ... ... 097 ... ... ... 098 ... ... ... 099 ... ... ... 100 ... ... ... 101 ... ... ... 102 ... ... ... 103 ... ... ... +029: 104 ... ... ... 105 ... ... ... 106 ... ... ... 107 ... ... ... 108 ... ... ... 109 ... ... ... 110 ... ... ... 111 ... ... ... +030: 112 ... ... ... 113 ... ... ... 114 ... ... ... 115 ... ... ... 116 ... ... ... 117 ... ... ... 118 ... ... ... 119 ... ... ... +031: 120 ... ... ... 121 ... ... ... 122 ... ... ... 123 ... ... ... 124 ... ... ... 125 ... ... ... 126 ... ... ... 127 ... ... ... +``` + +## regs (from smem read) + +``` + 000 001 002 003 004 005 006 007 008 009 010 011 012 013 014 015 016 017 018 019 020 021 022 023 024 025 026 027 028 029 030 031 +000: 000 008 016 024 032 040 048 056 064 072 080 088 096 104 112 120 000 008 016 024 032 040 048 056 064 072 080 088 096 104 112 120 +001: 000 008 016 024 032 040 048 056 064 072 080 088 096 104 112 120 000 008 016 024 032 040 048 056 064 072 080 088 096 104 112 120 +002: 000 008 016 024 032 040 048 056 064 072 080 088 096 104 112 120 000 008 016 024 032 040 048 056 064 072 080 088 096 104 112 120 +003: 000 008 016 024 032 040 048 056 064 072 080 088 096 104 112 120 000 008 016 024 032 040 048 056 064 072 080 088 096 104 112 120 +004: 033 041 049 057 001 009 017 025 097 105 113 121 065 073 081 089 033 041 049 057 001 009 017 025 097 105 113 121 065 073 081 089 +005: 033 041 049 057 001 009 017 025 097 105 113 121 065 073 081 089 033 041 049 057 001 009 017 025 097 105 113 121 065 073 081 089 +006: 033 041 049 057 001 009 017 025 097 105 113 121 065 073 081 089 033 041 049 057 001 009 017 025 097 105 113 121 065 073 081 089 +007: 033 041 049 057 001 009 017 025 097 105 113 121 065 073 081 089 033 041 049 057 001 009 017 025 097 105 113 121 065 073 081 089 +008: 066 074 082 090 098 106 114 122 002 010 018 026 034 042 050 058 066 074 082 090 098 106 114 122 002 010 018 026 034 042 050 058 +009: 066 074 082 090 098 106 114 122 002 010 018 026 034 042 050 058 066 074 082 090 098 106 114 122 002 010 018 026 034 042 050 058 +010: 066 074 082 090 098 106 114 122 002 010 018 026 034 042 050 058 066 074 082 090 098 106 114 122 002 010 018 026 034 042 050 058 +011: 066 074 082 090 098 106 114 122 002 010 018 026 034 042 050 058 066 074 082 090 098 106 114 122 002 010 018 026 034 042 050 058 +012: 099 107 115 123 067 075 083 091 035 043 051 059 003 011 019 027 099 107 115 123 067 075 083 091 035 043 051 059 003 011 019 027 +013: 099 107 115 123 067 075 083 091 035 043 051 059 003 011 019 027 099 107 115 123 067 075 083 091 035 043 051 059 003 011 019 027 +014: 099 107 115 123 067 075 083 091 035 043 051 059 003 011 019 027 099 107 115 123 067 075 083 091 035 043 051 059 003 011 019 027 +015: 099 107 115 123 067 075 083 091 035 043 051 059 003 011 019 027 099 107 115 123 067 075 083 091 035 043 051 059 003 011 019 027 +016: 004 012 020 028 036 044 052 060 068 076 084 092 100 108 116 124 004 012 020 028 036 044 052 060 068 076 084 092 100 108 116 124 +017: 004 012 020 028 036 044 052 060 068 076 084 092 100 108 116 124 004 012 020 028 036 044 052 060 068 076 084 092 100 108 116 124 +018: 004 012 020 028 036 044 052 060 068 076 084 092 100 108 116 124 004 012 020 028 036 044 052 060 068 076 084 092 100 108 116 124 +019: 004 012 020 028 036 044 052 060 068 076 084 092 100 108 116 124 004 012 020 028 036 044 052 060 068 076 084 092 100 108 116 124 +020: 037 045 053 061 005 013 021 029 101 109 117 125 069 077 085 093 037 045 053 061 005 013 021 029 101 109 117 125 069 077 085 093 +021: 037 045 053 061 005 013 021 029 101 109 117 125 069 077 085 093 037 045 053 061 005 013 021 029 101 109 117 125 069 077 085 093 +022: 037 045 053 061 005 013 021 029 101 109 117 125 069 077 085 093 037 045 053 061 005 013 021 029 101 109 117 125 069 077 085 093 +023: 037 045 053 061 005 013 021 029 101 109 117 125 069 077 085 093 037 045 053 061 005 013 021 029 101 109 117 125 069 077 085 093 +024: 070 078 086 094 102 110 118 126 006 014 022 030 038 046 054 062 070 078 086 094 102 110 118 126 006 014 022 030 038 046 054 062 +025: 070 078 086 094 102 110 118 126 006 014 022 030 038 046 054 062 070 078 086 094 102 110 118 126 006 014 022 030 038 046 054 062 +026: 070 078 086 094 102 110 118 126 006 014 022 030 038 046 054 062 070 078 086 094 102 110 118 126 006 014 022 030 038 046 054 062 +027: 070 078 086 094 102 110 118 126 006 014 022 030 038 046 054 062 070 078 086 094 102 110 118 126 006 014 022 030 038 046 054 062 +028: 103 111 119 127 071 079 087 095 039 047 055 063 007 015 023 031 103 111 119 127 071 079 087 095 039 047 055 063 007 015 023 031 +029: 103 111 119 127 071 079 087 095 039 047 055 063 007 015 023 031 103 111 119 127 071 079 087 095 039 047 055 063 007 015 023 031 +030: 103 111 119 127 071 079 087 095 039 047 055 063 007 015 023 031 103 111 119 127 071 079 087 095 039 047 055 063 007 015 023 031 +031: 103 111 119 127 071 079 087 095 039 047 055 063 007 015 023 031 103 111 119 127 071 079 087 095 039 047 055 063 007 015 023 031 +``` + +## output gmem (write) + +``` + 000 001 002 003 004 005 006 007 008 009 010 011 012 013 014 015 016 017 018 019 020 021 022 023 024 025 026 027 028 029 030 031 +000: 000 ... ... ... 001 ... ... ... 002 ... ... ... 003 ... ... ... 004 ... ... ... 005 ... ... ... 006 ... ... ... 007 ... ... ... +001: 008 ... ... ... 009 ... ... ... 010 ... ... ... 011 ... ... ... 012 ... ... ... 013 ... ... ... 014 ... ... ... 015 ... ... ... +002: 016 ... ... ... 017 ... ... ... 018 ... ... ... 019 ... ... ... 020 ... ... ... 021 ... ... ... 022 ... ... ... 023 ... ... ... +003: 024 ... ... ... 025 ... ... ... 026 ... ... ... 027 ... ... ... 028 ... ... ... 029 ... ... ... 030 ... ... ... 031 ... ... ... +004: 032 ... ... ... 033 ... ... ... 034 ... ... ... 035 ... ... ... 036 ... ... ... 037 ... ... ... 038 ... ... ... 039 ... ... ... +005: 040 ... ... ... 041 ... ... ... 042 ... ... ... 043 ... ... ... 044 ... ... ... 045 ... ... ... 046 ... ... ... 047 ... ... ... +006: 048 ... ... ... 049 ... ... ... 050 ... ... ... 051 ... ... ... 052 ... ... ... 053 ... ... ... 054 ... ... ... 055 ... ... ... +007: 056 ... ... ... 057 ... ... ... 058 ... ... ... 059 ... ... ... 060 ... ... ... 061 ... ... ... 062 ... ... ... 063 ... ... ... +008: 064 ... ... ... 065 ... ... ... 066 ... ... ... 067 ... ... ... 068 ... ... ... 069 ... ... ... 070 ... ... ... 071 ... ... ... +009: 072 ... ... ... 073 ... ... ... 074 ... ... ... 075 ... ... ... 076 ... ... ... 077 ... ... ... 078 ... ... ... 079 ... ... ... +010: 080 ... ... ... 081 ... ... ... 082 ... ... ... 083 ... ... ... 084 ... ... ... 085 ... ... ... 086 ... ... ... 087 ... ... ... +011: 088 ... ... ... 089 ... ... ... 090 ... ... ... 091 ... ... ... 092 ... ... ... 093 ... ... ... 094 ... ... ... 095 ... ... ... +012: 096 ... ... ... 097 ... ... ... 098 ... ... ... 099 ... ... ... 100 ... ... ... 101 ... ... ... 102 ... ... ... 103 ... ... ... +013: 104 ... ... ... 105 ... ... ... 106 ... ... ... 107 ... ... ... 108 ... ... ... 109 ... ... ... 110 ... ... ... 111 ... ... ... +014: 112 ... ... ... 113 ... ... ... 114 ... ... ... 115 ... ... ... 116 ... ... ... 117 ... ... ... 118 ... ... ... 119 ... ... ... +015: 120 ... ... ... 121 ... ... ... 122 ... ... ... 123 ... ... ... 124 ... ... ... 125 ... ... ... 126 ... ... ... 127 ... ... ... +016: 000 ... ... ... 001 ... ... ... 002 ... ... ... 003 ... ... ... 004 ... ... ... 005 ... ... ... 006 ... ... ... 007 ... ... ... +017: 008 ... ... ... 009 ... ... ... 010 ... ... ... 011 ... ... ... 012 ... ... ... 013 ... ... ... 014 ... ... ... 015 ... ... ... +018: 016 ... ... ... 017 ... ... ... 018 ... ... ... 019 ... ... ... 020 ... ... ... 021 ... ... ... 022 ... ... ... 023 ... ... ... +019: 024 ... ... ... 025 ... ... ... 026 ... ... ... 027 ... ... ... 028 ... ... ... 029 ... ... ... 030 ... ... ... 031 ... ... ... +020: 032 ... ... ... 033 ... ... ... 034 ... ... ... 035 ... ... ... 036 ... ... ... 037 ... ... ... 038 ... ... ... 039 ... ... ... +021: 040 ... ... ... 041 ... ... ... 042 ... ... ... 043 ... ... ... 044 ... ... ... 045 ... ... ... 046 ... ... ... 047 ... ... ... +022: 048 ... ... ... 049 ... ... ... 050 ... ... ... 051 ... ... ... 052 ... ... ... 053 ... ... ... 054 ... ... ... 055 ... ... ... +023: 056 ... ... ... 057 ... ... ... 058 ... ... ... 059 ... ... ... 060 ... ... ... 061 ... ... ... 062 ... ... ... 063 ... ... ... +024: 064 ... ... ... 065 ... ... ... 066 ... ... ... 067 ... ... ... 068 ... ... ... 069 ... ... ... 070 ... ... ... 071 ... ... ... +025: 072 ... ... ... 073 ... ... ... 074 ... ... ... 075 ... ... ... 076 ... ... ... 077 ... ... ... 078 ... ... ... 079 ... ... ... +026: 080 ... ... ... 081 ... ... ... 082 ... ... ... 083 ... ... ... 084 ... ... ... 085 ... ... ... 086 ... ... ... 087 ... ... ... +027: 088 ... ... ... 089 ... ... ... 090 ... ... ... 091 ... ... ... 092 ... ... ... 093 ... ... ... 094 ... ... ... 095 ... ... ... +028: 096 ... ... ... 097 ... ... ... 098 ... ... ... 099 ... ... ... 100 ... ... ... 101 ... ... ... 102 ... ... ... 103 ... ... ... +029: 104 ... ... ... 105 ... ... ... 106 ... ... ... 107 ... ... ... 108 ... ... ... 109 ... ... ... 110 ... ... ... 111 ... ... ... +030: 112 ... ... ... 113 ... ... ... 114 ... ... ... 115 ... ... ... 116 ... ... ... 117 ... ... ... 118 ... ... ... 119 ... ... ... +031: 120 ... ... ... 121 ... ... ... 122 ... ... ... 123 ... ... ... 124 ... ... ... 125 ... ... ... 126 ... ... ... 127 ... ... ... +``` diff --git a/tests/cpp/test_transpose.cpp b/tests/cpp/test_transpose.cpp index 3d0bb16b87a..218a7a1e27c 100644 --- a/tests/cpp/test_transpose.cpp +++ b/tests/cpp/test_transpose.cpp @@ -1409,4 +1409,181 @@ TEST_F(TransposeTest, DanglingBroadcastIssue4957) { testValidate(executor_cache.fusion(), outputs, {t0}, __LINE__, __FILE__); } +TEST_F(TransposeTest, Tma) { + auto fusion_ptr = std::make_unique(); + FusionGuard fg(fusion_ptr.get()); + Fusion& fusion = *fusion_ptr; + + auto dtype = DataType::BFloat16; + auto tv0 = makeContigConcreteTensor({262144, 5120}, dtype); + fusion.addInput(tv0); + auto tv1 = castOp(DataType::Float, tv0); + auto tv2 = add(tv1, tv1); + auto tv3 = transpose(tv2, 0, 1); + auto tv4 = mul(tv3, tv3); + auto tv5 = castOp(dtype, tv4); + fusion.addOutput(tv5); + + auto options = + at::TensorOptions().dtype(data_type_to_aten(dtype)).device(at::kCUDA, 0); + at::Tensor input0 = at::randn({262144, 5120}, options); + FusionExecutorCache executor_cache(std::move(fusion_ptr)); + auto outputs = executor_cache.runFusionWithInputs({input0}); + testValidate(executor_cache.fusion(), outputs, {input0}, __LINE__, __FILE__); +} + +TEST_F(TransposeTest, TmaTransposeSimple) { + auto fusion_ptr = std::make_unique(); + FusionGuard fg(fusion_ptr.get()); + Fusion& fusion = *fusion_ptr; + + auto dtype = DataType::Float; + auto tv0 = makeContigConcreteTensor({262144, 5120}, dtype); + fusion.addInput(tv0); + auto tv1 = transpose(tv0, 0, 1); + fusion.addOutput(tv1); + + auto options = + at::TensorOptions().dtype(data_type_to_aten(dtype)).device(at::kCUDA, 0); + at::Tensor input0 = at::randn({262144, 5120}, options); + bool auto_schedule = false; + if (std::getenv("USE_AUTO") != nullptr) { + auto_schedule = true; + } + if (auto_schedule) { + // H100, non-tma, 82%, tma-256, 71%, tma-128, 78%, tma-128-2tilesPerCTA, 80% + // non-tma, GB200, 62%, 2.18 ms, vect = 4, unroll = 2 + FusionExecutorCache executor_cache(std::move(fusion_ptr)); + auto outputs = executor_cache.runFusionWithInputs({input0}); + testValidate( + executor_cache.fusion(), outputs, {input0}, __LINE__, __FILE__); + } else { + auto input_cache = tv0->cacheAfter(); + auto output_cache = tv1->cacheBefore(); + input_cache->setMemoryType(MemoryType::Shared); + + int64_t tile_size1 = 32, tile_size2 = 32; + bool use_propagate_schedule = false; + + if (use_propagate_schedule) { + // Propagate-based schedule from a single reference. + auto reference1 = tv1; + // [i1, i2] -> [i1, i2/tile1, tile1] + reference1->split(1, tile_size1); + reference1->reorder({{2, -1}}); + reference1->split(0, tile_size2); + reference1->reorder({{1, -1}}); + reference1->merge(0); + + reference1->split(0, 1); + // [r.., merged_dim, 1, tile1, tile2] + + // parallelize non-tile dimensions + reference1->axis(1)->parallelize(ParallelType::Unswitch); + reference1->axis(0)->parallelize(ParallelType::BIDx); + // [r.., BIDx, Unswitch, tile1, tile2] + + // Propagate transformations so far to the entire DAG + TransformPropagator propagator(reference1); + MaxLogicalDomainInfoSpanningTree entire_dag(reference1); + entire_dag.traverse(&propagator); + + scheduler_utils::parallelizeAllLike( + reference1, + /*selected_tvs=*/{}, + /*selected_parallel_types=*/{}, + /*propagate_padding=*/true, + /*parallelize_inputs_on_did=*/true); + } else { + // Group-1 tensors: output side uses [y, x] layout. + for (auto tv : {output_cache, tv1}) { + // [y, x] -> [y/tile_size2, tile_size2, x/tile_size1, tile_size1] + tv->split(1, tile_size1); + tv->split(0, tile_size2); + // [x/tile_size1, y/tile_size2, tile_size1, tile_size2] + tv->reorder({{0, 1}, {1, 3}, {2, 0}, {3, 2}}); + std::cout << tv->toString() << std::endl; + // [x/tile_size1 * y/tile_size2, tile_size1, tile_size2] + tv->merge(0); + tv->split(0, 1); + tv->axis(1)->parallelize(ParallelType::Unswitch); + tv->axis(0)->parallelize(ParallelType::BIDx); + std::cout << tv->toString() << std::endl; + } + // Group-2 tensors: input side uses [x, y] layout. + for (auto tv : {tv0, input_cache}) { + // [x, y] -> [x/tile_size1, tile_size1, y/tile_size2, tile_size2] + tv->split(1, tile_size2); + tv->split(0, tile_size1); + // [x/tile_size1, y/tile_size2, tile_size1, tile_size2] + tv->reorder({{1, 2}, {2, 1}}); + std::cout << tv->toString() << std::endl; + // [x/tile_size1 * y/tile_size2, tile_size1, tile_size2] + tv->merge(0); + tv->split(0, 1); + tv->axis(1)->parallelize(ParallelType::Unswitch); + tv->axis(0)->parallelize(ParallelType::BIDx); + std::cout << tv->toString() << std::endl; + } + } + // Print to verify the loop domain matches the expected transform order. + std::cout << output_cache->toString() << std::endl; + ////////////////////////////// + // Step 3: Schedule group 2 // + ////////////////////////////// + int64_t pos = 2; + int64_t vectorize_factor = 4, threads_per_block = 128; + // schedule input cache + // [BIDx, Unswitch, tile_size1, tile_size2] + input_cache->split(3, vectorize_factor); + // [BIDx, Unswitch, tile_size1, tile_size2/vectorize_factor, + // vectorize_factor] + input_cache->split(2, vectorize_factor); + // [BIDx, Unswitch, tile_size1/vectorize_factor, vectorize_factor, + // tile_size2/vectorize_factor, vectorize_factor] + input_cache->swizzle(SwizzleType::XOR, 2, 4); + input_cache->merge(2); + input_cache->merge(2); + input_cache->split(2, threads_per_block); + // [BIDx, Unswitch, Unroll, TIDx, Vectorize] + input_cache->setAllocationDomain(input_cache->getLoopDomain(), true); + input_cache->axis(2)->parallelize(ParallelType::Unroll); + input_cache->axis(3)->parallelize(ParallelType::TIDx); + input_cache->axis(4)->parallelize(ParallelType::Vectorize); + // [BIDx, Bulk, Bulk, Bulk, Bulk] + // Vectorize/Unroll for group 2. Only the shared cache is vectorized. + + for (auto tv : {tv0}) { + tv->merge(pos); + tv->split(pos, vectorize_factor); + tv->split(pos, threads_per_block); + // [BIDx, Unswitch, Unroll, TIDx, Vectorize] + tv->axis(2)->parallelize(ParallelType::Unroll); + tv->axis(3)->parallelize(ParallelType::TIDx); + } + ////////////////////////////// + // Step 4: Schedule group 1 // + ////////////////////////////// + // Vectorize/Unroll for group 1. Only the output is vectorized. + for (auto tv : {output_cache, tv1}) { + tv->reorder({{-2, -1}}); + // [..., tile2, tile1] + tv->merge(pos); + tv->split(pos, vectorize_factor); + tv->split(pos, threads_per_block); + tv->axis(2)->parallelize(ParallelType::Unroll); + tv->axis(3)->parallelize(ParallelType::TIDx); + if (tv == tv1) { + tv->axis(4)->parallelize(ParallelType::Vectorize); + } + } + inlineMost(); + fusion.print(); + KernelExecutor ke; + ke.compile(&fusion, {input0}); + auto outputs = ke.run({input0}); + testValidate(&fusion, outputs, {input0}, __LINE__, __FILE__); + } +} + } // namespace nvfuser diff --git a/tests/cpp/test_tutorial.cpp b/tests/cpp/test_tutorial.cpp index 1acb7209f91..4d3e5fd5f76 100644 --- a/tests/cpp/test_tutorial.cpp +++ b/tests/cpp/test_tutorial.cpp @@ -1497,6 +1497,7 @@ TEST_F(Tutorial, TMABankConflictFreeTranspose) { output->axis(1)->parallelize(ParallelType::Bulk); output->axis(2)->parallelize(ParallelType::Bulk); // [BIDx, Bulk, Bulk] + fusion.print(); // output_smem_cache and output_reg_cache are scheduled in the same way. // We use each warp to load one column of input_smem_cache. We vectorize